# Mirror of https://github.com/vllm-project/vllm.git (synced 2025-12-06 06:53:12 +08:00); original file: 2266 lines, 69 KiB, Python.
# SPDX-License-Identifier: Apache-2.0
|
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
|
|
|
from typing import Any
|
|
|
|
import pytest
|
|
from mistral_common.exceptions import InvalidMessageStructureException
|
|
from mistral_common.tokens.tokenizers.base import SpecialTokenPolicy
|
|
|
|
from vllm.tokenizers.mistral import (
|
|
MistralTokenizer,
|
|
_prepare_apply_chat_template_tools_and_messages,
|
|
)
|
|
|
|
|
|
@pytest.mark.parametrize(
    "openai_request,expected_mistral_output",
    [
        # Tool declared without "parameters": an empty dict is filled in.
        (
            {
                "messages": [
                    {
                        "role": "user",
                        "content": "What is the current local date and time?",
                    }
                ],
                "tools": [
                    {
                        "type": "function",
                        "function": {
                            "description": "Fetch the current local date and time.",
                            "name": "get_current_time",
                        },
                    }
                ],
            },
            (
                [
                    {
                        "role": "user",
                        "content": "What is the current local date and time?",
                    }
                ],
                [
                    {
                        "type": "function",
                        "function": {
                            "description": "Fetch the current local date and time.",
                            "name": "get_current_time",
                            "parameters": {},
                        },
                    }
                ],
            ),
        ),
        # Tool already declaring empty "parameters": passed through as-is.
        (
            {
                "messages": [
                    {
                        "role": "user",
                        "content": "What is the current local date and time?",
                    }
                ],
                "tools": [
                    {
                        "type": "function",
                        "function": {
                            "description": "Fetch the current local date and time.",
                            "name": "get_current_time",
                            "parameters": {},
                        },
                    }
                ],
            },
            (
                [
                    {
                        "role": "user",
                        "content": "What is the current local date and time?",
                    }
                ],
                [
                    {
                        "type": "function",
                        "function": {
                            "description": "Fetch the current local date and time.",
                            "name": "get_current_time",
                            "parameters": {},
                        },
                    }
                ],
            ),
        ),
    ],
)
def test_prepare_apply_chat_template_tools_and_messages(
    openai_request, expected_mistral_output
):
    """The (messages, tools) pair prepared for mistral-common matches expectations.

    Messages pass through unchanged; a tool lacking a "parameters" entry
    gets an empty one.
    """
    messages = openai_request["messages"]
    tools = openai_request["tools"]
    prepared = _prepare_apply_chat_template_tools_and_messages(messages, tools)
    assert prepared == expected_mistral_output
|
|
|
|
|
|
# Tool use with list content and reasoning
@pytest.mark.parametrize(
    "openai_request,expected_mistral_output",
    [
        (
            {
                "messages": [
                    {
                        "role": "user",
                        "content": "What's the weather in Paris?",
                    },
                    {
                        "role": "assistant",
                        "reasoning": None,
                        "content": None,
                        "tool_calls": [
                            {
                                "id": "call123",
                                "type": "function",
                                "function": {
                                    "name": "get_weather",
                                    "arguments": '{"city": "Paris"}',
                                },
                            }
                        ],
                    },
                    {
                        "role": "tool",
                        "content": [{"type": "text", "text": "Rainy"}],
                        "name": "get_weather",
                        "tool_call_id": "call123",
                    },
                ],
                "tools": [
                    {
                        "type": "function",
                        "function": {
                            "name": "get_weather",
                            "description": "Gets the current weather in a city.",
                            "parameters": {
                                "type": "object",
                                "properties": {
                                    "city": {
                                        "type": "string",
                                        "description": "The city name",
                                    }
                                },
                                "required": ["city"],
                            },
                        },
                    }
                ],
            },
            (
                [
                    {
                        "role": "user",
                        "content": "What's the weather in Paris?",
                    },
                    # The null "reasoning" key is expected to be gone here.
                    {
                        "role": "assistant",
                        "content": None,
                        "tool_calls": [
                            {
                                "id": "call123",
                                "type": "function",
                                "function": {
                                    "name": "get_weather",
                                    "arguments": '{"city": "Paris"}',
                                },
                            }
                        ],
                    },
                    {
                        "role": "tool",
                        "content": [{"type": "text", "text": "Rainy"}],
                        "name": "get_weather",
                        "tool_call_id": "call123",
                    },
                ],
                [
                    {
                        "type": "function",
                        "function": {
                            "name": "get_weather",
                            "description": "Gets the current weather in a city.",
                            "parameters": {
                                "type": "object",
                                "properties": {
                                    "city": {
                                        "type": "string",
                                        "description": "The city name",
                                    }
                                },
                                "required": ["city"],
                            },
                        },
                    }
                ],
            ),
        )
    ],
)
def test_prepare_apply_chat_template_tools_and_messages_list_content(
    openai_request, expected_mistral_output
):
    """List-valued tool content is preserved; a null assistant "reasoning" is dropped."""
    messages = openai_request["messages"]
    tools = openai_request["tools"]
    prepared = _prepare_apply_chat_template_tools_and_messages(messages, tools)
    assert prepared == expected_mistral_output
|
|
|
|
|
|
def test_prepare_apply_chat_template_generation_prompt_and_continue():
    """Check how the two prompt flags constrain the role of the last message.

    ``add_generation_prompt`` is accepted after a user turn and rejected after
    an assistant turn. ``continue_final_message`` is accepted after an
    assistant turn and rejected after a user turn. Requesting both on a
    user-ending conversation is rejected.
    """
    no_tools: list[dict[str, Any]] = []

    # Generation prompt after an assistant turn -> ValueError.
    assistant_msgs = [{"role": "assistant", "content": "Hello"}]
    with pytest.raises(ValueError):
        _prepare_apply_chat_template_tools_and_messages(
            assistant_msgs, no_tools, add_generation_prompt=True
        )

    # Generation prompt after a user turn -> messages come back unchanged.
    user_msgs = [{"role": "user", "content": "Hello"}]
    prepared, _ = _prepare_apply_chat_template_tools_and_messages(
        user_msgs, no_tools, add_generation_prompt=True
    )
    assert prepared == [{"role": "user", "content": "Hello"}]

    # Both flags on the same user-ending conversation -> ValueError.
    with pytest.raises(ValueError):
        _prepare_apply_chat_template_tools_and_messages(
            user_msgs,
            no_tools,
            add_generation_prompt=True,
            continue_final_message=True,
        )

    # Continuing a trailing assistant turn -> messages come back unchanged.
    assistant_msgs = [{"role": "assistant", "content": "Hello"}]
    prepared, _ = _prepare_apply_chat_template_tools_and_messages(
        assistant_msgs,
        no_tools,
        add_generation_prompt=False,
        continue_final_message=True,
    )
    assert prepared == [{"role": "assistant", "content": "Hello"}]

    # Continuing when the conversation ends on the user -> ValueError.
    user_msgs = [{"role": "user", "content": "Hello"}]
    with pytest.raises(ValueError):
        _prepare_apply_chat_template_tools_and_messages(
            user_msgs,
            no_tools,
            add_generation_prompt=False,
            continue_final_message=True,
        )
|
|
|
|
|
|
@pytest.fixture(scope="module")
def mistral_tokenizer(request) -> MistralTokenizer:
    """Module-scoped tokenizer loaded from the model id passed via indirect parametrization."""
    model_id = request.param
    return MistralTokenizer.from_pretrained(model_id)
|
|
|
|
|
|
@pytest.mark.parametrize(
|
|
"mistral_tokenizer",
|
|
["mistralai/Mistral-7B-Instruct-v0.3", "mistralai/Magistral-Small-2509"],
|
|
indirect=True,
|
|
)
|
|
class TestMistralTokenizer:
|
|
def test_all_special_tokens(self, mistral_tokenizer: MistralTokenizer):
    """The ordered special-token list matches the expected layout per flavor."""
    if not mistral_tokenizer.is_tekken:
        expected = [
            "<s>",
            "</s>",
            "[INST]",
            "[/INST]",
            "[TOOL_CALLS]",
            "[AVAILABLE_TOOLS]",
            "[/AVAILABLE_TOOLS]",
            "[TOOL_RESULTS]",
            "[/TOOL_RESULTS]",
        ]
        # Remaining control slots are named by index.
        expected += [f"[control_{i}]" for i in range(8, 769)]
    else:
        expected = [
            "<unk>",
            "<s>",
            "</s>",
            "[INST]",
            "[/INST]",
            "[AVAILABLE_TOOLS]",
            "[/AVAILABLE_TOOLS]",
            "[TOOL_RESULTS]",
            "[/TOOL_RESULTS]",
            "[TOOL_CALLS]",
            "[IMG]",
            "<pad>",
            "[IMG_BREAK]",
            "[IMG_END]",
            "[PREFIX]",
            "[MIDDLE]",
            "[SUFFIX]",
            "[SYSTEM_PROMPT]",
            "[/SYSTEM_PROMPT]",
            "[TOOL_CONTENT]",
        ]
        # Named tokens are interleaved with generic <SPECIAL_i> placeholders.
        expected += [f"<SPECIAL_{i}>" for i in range(20, 32)]
        expected += ["[ARGS]", "[CALL_ID]", "[THINK]", "[/THINK]"]
        expected += [f"<SPECIAL_{i}>" for i in range(36, 1000)]
    assert mistral_tokenizer.all_special_tokens == expected
|
|
|
|
def test_get_vocab(self, mistral_tokenizer: MistralTokenizer):
    """The wrapper's vocab must equal the underlying transformers tokenizer's vocab.

    This was previously named ``get_vocab``. Without the ``test_`` prefix,
    pytest never collected it, so the check silently never ran.
    """
    assert (
        mistral_tokenizer.get_vocab()
        == mistral_tokenizer.transformers_tokenizer.get_vocab()
    )


# Backward-compatible alias for the old (non-collected) name.
get_vocab = test_get_vocab
|
|
|
|
def test_get_added_vocab(self, mistral_tokenizer: MistralTokenizer):
    """The added-vocab mapping is expected to be empty."""
    added = mistral_tokenizer.get_added_vocab()
    assert added == {}
|
|
|
|
def test_encode(self, mistral_tokenizer: MistralTokenizer):
    """encode() honours add_special_tokens, truncation and max_length.

    For encode(), ``max_length`` alone truncates the ids. Passing
    ``truncation=False`` disables that truncation.
    """
    if mistral_tokenizer.is_tekken:
        full = [1, 22177, 4304, 2662]
    else:
        full = [1, 23325, 2294, 1686]
    truncated = full[:-1]
    text = "Hello world !"

    # (keyword arguments, expected ids), checked in order.
    cases = [
        ({}, full),
        ({"max_length": 3}, truncated),
        ({"truncation": True, "max_length": 3}, truncated),
        ({"truncation": False, "max_length": 3}, full),
        ({"add_special_tokens": True}, full),
        ({"add_special_tokens": True, "max_length": 3}, truncated),
        (
            {"add_special_tokens": True, "truncation": False, "max_length": 3},
            full,
        ),
        # Without special tokens the leading BOS id is absent.
        ({"add_special_tokens": False}, full[1:]),
    ]
    for kwargs, expected in cases:
        assert mistral_tokenizer.encode(text, **kwargs) == expected, kwargs

    assert mistral_tokenizer.encode("", add_special_tokens=False) == []
|
|
|
|
def test_call(self, mistral_tokenizer: MistralTokenizer):
    """__call__ returns matching input_ids/attention_mask and rejects text pairs.

    Unlike encode(), ``max_length`` without ``truncation=True`` leaves the
    output untruncated here.
    """
    text = "Hello world !"
    if mistral_tokenizer.is_tekken:
        ids = [1, 22177, 4304, 2662]
    else:
        ids = [1, 23325, 2294, 1686]
    mask = [1] * len(ids)

    # (prompt, keyword arguments, expected ids, expected mask), checked in order.
    cases = [
        # 1: no special tokens
        (text, {"add_special_tokens": False}, ids[1:], mask[1:]),
        # 2: special tokens
        (text, {"add_special_tokens": True}, ids, mask),
        # 3: special tokens + truncation
        (
            text,
            {"add_special_tokens": True, "truncation": True, "max_length": 3},
            ids[:-1],
            mask[:-1],
        ),
        # 4: special tokens + no truncation + max length
        (text, {"add_special_tokens": True, "max_length": 3}, ids, mask),
        # 5: empty string
        ("", {"add_special_tokens": False}, [], []),
    ]
    for prompt, kwargs, want_ids, want_mask in cases:
        assert mistral_tokenizer(prompt, **kwargs) == {
            "attention_mask": want_mask,
            "input_ids": want_ids,
        }

    with pytest.raises(
        ValueError,
        match=r"`text_pair` is not supported by `MistralTokenizer.__call__`.",
    ):
        mistral_tokenizer(text, "invalid pair")
|
|
|
|
# Ids/text for "system prompt + user turn + weather tool", shared by the
# tools cases below. Index 0 of each expected pair is the non-tekken
# tokenizer, index 1 the tekken one. These helpers are deleted after the
# test definition.
_V3_TOOL_PROMPT_IDS = [
    1, 6, 1501, 7567, 1891, 2032, 1113, 3396, 1316, 1113, 3396, 2032,
    10598, 1629, 2032, 1113, 1295, 29498, 1537, 1991, 1316, 1113, 7286,
    2032, 1113, 2226, 29481, 1040, 2636, 8854, 1065, 1032, 3758, 9959,
    1113, 12206, 2032, 10598, 1891, 2032, 1113, 3582, 1316, 1113, 11491,
    2032, 10598, 19141, 2032, 10598, 1891, 2032, 1113, 2195, 1316, 1113,
    7286, 2032, 1113, 1782, 3758, 1909, 29507, 11549, 1113, 11661, 2032,
    8135, 19141, 3010, 1743, 10925, 7, 3, 1083, 1605, 1164, 16875, 781,
    781, 16998, 2294, 1686, 4,
]  # fmt: skip
_TEKKEN_TOOL_PROMPT_IDS = [
    1, 17, 1073, 1855, 1420, 26554, 18, 5, 1091, 19227, 4994, 2811, 1429,
    5165, 1897, 1429, 5165, 2811, 16753, 2391, 2811, 1429, 1689, 1095,
    45629, 1897, 1429, 14653, 2811, 1429, 1071, 3083, 1278, 3519, 17253,
    1294, 1261, 5970, 39249, 1429, 26204, 2811, 16753, 4994, 2811, 1429,
    6371, 1897, 1429, 48649, 2811, 16753, 29363, 2811, 16753, 4994, 2811,
    1429, 3607, 1897, 1429, 14653, 2811, 1429, 1784, 5970, 2564, 1034,
    47579, 1429, 15760, 2811, 12161, 29363, 4964, 2821, 27028, 6, 3, 22177,
    4304, 2662, 4,
]  # fmt: skip
_V3_TOOL_PROMPT_TEXT = (
    '<s>[AVAILABLE_TOOLS]▁[{"type":▁"function",▁"function":▁{"name":▁"get_weather",▁"description":▁"Gets▁the▁current▁weather▁in▁a▁city.",▁"parameters":▁{"type":▁"object",▁"properties":▁{"city":▁{"type":▁"string",▁"description":▁"The▁city▁name"}},▁"required":▁["city"]}}}][/AVAILABLE_TOOLS]'  # noqa: E501
    "[INST]▁I▁am▁an▁AI<0x0A><0x0A>Hello▁world▁![/INST]"
)
_TEKKEN_TOOL_PROMPT_TEXT = (
    "<s>[SYSTEM_PROMPT]I am an AI[/SYSTEM_PROMPT]"
    '[AVAILABLE_TOOLS][{"type": "function", "function": {"name": "get_weather", "description": "Gets the current weather in a city.", "parameters": {"type": "object", "properties": {"city": {"type": "string", "description": "The city name"}}, "required": ["city"]}}}][/AVAILABLE_TOOLS]'  # noqa: E501
    "[INST]Hello world ![/INST]"
)


def _weather_tool_spec():
    # Build fresh objects on each call so parametrized cases never share
    # mutable request state.
    return [
        {
            "type": "function",
            "function": {
                "name": "get_weather",
                "description": "Gets the current weather in a city.",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "city": {
                            "type": "string",
                            "description": "The city name",
                        }
                    },
                    "required": ["city"],
                },
            },
        }
    ]


@pytest.mark.parametrize(
    "openai_request,add_generation_prompt,continue_final_message,expected_output,decoded_expected_output",  # noqa: E501
    [
        # Single user turn.
        (
            {"messages": [{"role": "user", "content": "Hello world !"}]},
            True,
            False,
            ([1, 3, 23325, 2294, 1686, 4], [1, 3, 22177, 4304, 2662, 4]),
            ("<s>[INST]▁Hello▁world▁![/INST]", "<s>[INST]Hello world ![/INST]"),
        ),
        # System prompt + user turn.
        (
            {
                "messages": [
                    {"role": "system", "content": "I am an AI"},
                    {"role": "user", "content": "Hello world !"},
                ],
            },
            True,
            False,
            (
                [1, 3, 1083, 1605, 1164, 16875, 781, 781, 16998, 2294, 1686, 4],
                [1, 17, 1073, 1855, 1420, 26554, 18, 3, 22177, 4304, 2662, 4],
            ),
            (
                "<s>[INST]▁I▁am▁an▁AI<0x0A><0x0A>Hello▁world▁![/INST]",
                "<s>[SYSTEM_PROMPT]I am an AI[/SYSTEM_PROMPT][INST]Hello world ![/INST]",  # noqa: E501
            ),
        ),
        # System prompt + user turn + available tools.
        (
            {
                "messages": [
                    {"role": "system", "content": "I am an AI"},
                    {"role": "user", "content": "Hello world !"},
                ],
                "tools": _weather_tool_spec(),
            },
            True,
            False,
            (_V3_TOOL_PROMPT_IDS, _TEKKEN_TOOL_PROMPT_IDS),
            (_V3_TOOL_PROMPT_TEXT, _TEKKEN_TOOL_PROMPT_TEXT),
        ),
        # Same as above, followed by an assistant tool call and its result.
        (
            {
                "messages": [
                    {"role": "system", "content": "I am an AI"},
                    {"role": "user", "content": "Hello world !"},
                    {
                        "role": "assistant",
                        "content": "",
                        "tool_calls": [
                            {
                                "id": "123456789",
                                "type": "function",
                                "function": {
                                    "name": "get_weather",
                                    "arguments": '{"city": "Paris"}',
                                },
                            }
                        ],
                    },
                    {
                        "role": "tool",
                        "tool_call_id": "123456789",
                        "content": '{"temperature": 20, "unit": "celsius"}',
                    },
                ],
                "tools": _weather_tool_spec(),
            },
            True,
            False,
            (
                _V3_TOOL_PROMPT_IDS
                + [
                    5, 1501, 7567, 1629, 2032, 1113, 1295, 29498, 1537, 1991,
                    1316, 1113, 17452, 2032, 10598, 19141, 2032, 1113, 4684,
                    1046, 8474, 1113, 1081, 2032, 1113, 29508, 29518, 29538,
                    29549, 29550, 29552, 29555, 29551, 29542, 29507, 10925, 2,
                    8, 10598, 4557, 2032, 10598, 29475, 17329, 2032, 29473,
                    29518, 29502, 29493, 1113, 6074, 2032, 1113, 29485, 1958,
                    3938, 8474, 1113, 3613, 29498, 1081, 2032, 1113, 29508,
                    29518, 29538, 29549, 29550, 29552, 29555, 29551, 29542,
                    18163, 9,
                ],  # fmt: skip
                _TEKKEN_TOOL_PROMPT_IDS
                + [
                    9, 1689, 1095, 45629, 32, 19227, 29363, 2811, 1429, 42572,
                    46005, 2, 7, 19227, 113824, 2811, 1032, 1050, 1048, 1044,
                    1429, 8979, 2811, 1429, 1099, 79092, 46005, 8,
                ],  # fmt: skip
            ),
            (
                _V3_TOOL_PROMPT_TEXT
                + (
                    '[TOOL_CALLS]▁[{"name":▁"get_weather",▁"arguments":▁{"city":▁"Paris"},▁"id":▁"123456789"}]</s>'  # noqa: E501
                    '[TOOL_RESULTS]▁{"content":▁{"temperature":▁20,▁"unit":▁"celsius"},▁"call_id":▁"123456789"}[/TOOL_RESULTS]'  # noqa: E501
                ),
                _TEKKEN_TOOL_PROMPT_TEXT
                + (
                    '[TOOL_CALLS]get_weather[ARGS]{"city": "Paris"}</s>'
                    '[TOOL_RESULTS]{"temperature": 20, "unit": "celsius"}[/TOOL_RESULTS]'  # noqa: E501
                ),
            ),
        ),
        # Continue a partial assistant turn instead of opening a new one.
        (
            {
                "messages": [
                    {"role": "user", "content": "Hello world !"},
                    {"role": "assistant", "content": "Hello "},
                ],
            },
            False,
            True,
            (
                [1, 3, 23325, 2294, 1686, 4, 23325],
                [1, 3, 22177, 4304, 2662, 4, 22177, 2],
            ),
            (
                "<s>[INST]▁Hello▁world▁![/INST]▁Hello",
                "<s>[INST]Hello world ![/INST]Hello</s>",
            ),
        ),
    ],
)
def test_apply_chat_template(
    self,
    mistral_tokenizer: MistralTokenizer,
    openai_request: dict[str, Any],
    add_generation_prompt: bool,
    continue_final_message: bool,
    expected_output: tuple[list[int], list[int]],
    decoded_expected_output: tuple[str, str],
):
    """apply_chat_template yields the expected ids and decoded text.

    Each expected pair is indexed by ``is_tekken``: 0 for the non-tekken
    tokenizer, 1 for the tekken one. Decoding keeps special tokens.
    """
    token_ids = mistral_tokenizer.apply_chat_template(
        openai_request["messages"],
        tools=openai_request.get("tools", []),
        add_generation_prompt=add_generation_prompt,
        continue_final_message=continue_final_message,
    )
    text = mistral_tokenizer.tokenizer.decode(token_ids, SpecialTokenPolicy.KEEP)

    flavor = mistral_tokenizer.is_tekken
    assert token_ids == expected_output[flavor]
    assert text == decoded_expected_output[flavor]


# Drop the parametrization helpers so they do not linger as attributes.
del (
    _V3_TOOL_PROMPT_IDS,
    _TEKKEN_TOOL_PROMPT_IDS,
    _V3_TOOL_PROMPT_TEXT,
    _TEKKEN_TOOL_PROMPT_TEXT,
    _weather_tool_spec,
)
|
|
|
|
def test_apply_chat_template_error(self, mistral_tokenizer: MistralTokenizer):
    """apply_chat_template rejects prompt flags that contradict the last turn."""

    def user_then_assistant():
        # A fresh conversation that ends on a partial assistant turn.
        return [
            {"role": "user", "content": "Hello world !"},
            {"role": "assistant", "content": "Hello "},
        ]

    # continue_final_message after a user turn is rejected, with or without
    # a generation prompt (the same list object is reused for both calls).
    user_only = [{"role": "user", "content": "Hello world !"}]
    for gen_prompt in (True, False):
        with pytest.raises(ValueError):
            mistral_tokenizer.apply_chat_template(
                user_only,
                tools=[],
                add_generation_prompt=gen_prompt,
                continue_final_message=True,
            )

    # A generation prompt after an assistant turn is rejected.
    with pytest.raises(ValueError):
        mistral_tokenizer.apply_chat_template(
            user_then_assistant(),
            tools=[],
            add_generation_prompt=True,
            continue_final_message=False,
        )

    # With neither flag, ending on an assistant turn fails mistral-common's
    # message-structure validation.
    with pytest.raises(InvalidMessageStructureException):
        mistral_tokenizer.apply_chat_template(
            user_then_assistant(),
            tools=[],
            add_generation_prompt=False,
            continue_final_message=False,
        )
|
|
|
|
@pytest.mark.parametrize(
    "skip_special_tokens,expected_tokens",
    (
        (
            False,
            (
                "<s>[INST]▁Hello▁world▁![/INST]▁Hello</s>",
                "<s>[INST]Hello world ![/INST]Hello</s>",
            ),
        ),
        (True, ("Hello world ! Hello", "Hello world !Hello")),
    ),
)
def test_decode(
    self,
    mistral_tokenizer: MistralTokenizer,
    skip_special_tokens: bool,
    expected_tokens: tuple[str, str],
):
    """decode() renders ids as text, optionally dropping special tokens.

    The same ids are decoded twice in a row, and each call must produce the
    expected text. Expectations are indexed by ``is_tekken``.
    """
    ids_per_flavor = (
        [1, 3, 23325, 2294, 1686, 4, 23325, 2],
        [1, 3, 22177, 4304, 2662, 4, 22177, 2],
    )
    for _attempt in range(2):
        decoded = mistral_tokenizer.decode(
            ids_per_flavor[mistral_tokenizer.is_tekken],
            skip_special_tokens=skip_special_tokens,
        )
        assert decoded == expected_tokens[mistral_tokenizer.is_tekken]
|
|
|
|
def test_decode_empty(
    self,
    mistral_tokenizer: MistralTokenizer,
):
    """Decoding an empty id list gives an empty string."""
    decoded = mistral_tokenizer.decode([])
    assert decoded == ""
|
|
|
|
def test_decode_int(
    self,
    mistral_tokenizer: MistralTokenizer,
):
    """decode() also accepts a single int id; id 1 decodes to "<s>"."""
    bos_id = 1
    decoded = mistral_tokenizer.decode(bos_id, skip_special_tokens=False)
    assert decoded == "<s>"
|
|
|
|
def test_convert_tokens_to_string(self, mistral_tokenizer: MistralTokenizer):
    """Token strings join back into text with special control tokens removed.

    Of the special tokens in the input, only ``[TOOL_CALLS]`` survives in the
    expected output. An empty token list gives an empty string.
    """
    if mistral_tokenizer.is_tekken:
        tokens = [
            "<s>", "[SYSTEM_PROMPT]", "I", " am", " an", " AI",
            "[/SYSTEM_PROMPT]",
            "[AVAILABLE_TOOLS]", "[", '{"', "type", '":', ' "', "function",
            '",', ' "', "function", '":', ' {"', "name", '":', ' "', "get",
            "_", "weather", '",', ' "', "description", '":', ' "', "G", "ets",
            " the", " current", " weather", " in", " a", " city", '.",', ' "',
            "parameters", '":', ' {"', "type", '":', ' "', "object", '",',
            ' "', "properties", '":', ' {"', "city", '":', ' {"', "type",
            '":', ' "', "string", '",', ' "', "description", '":', ' "',
            "The", " city", " name", '"', "}},", ' "', "required", '":',
            ' ["', "city", '"]', "}}", "}]", "[/AVAILABLE_TOOLS]",
            "[INST]", "Hello", " world", " !", "[/INST]",
            "[TOOL_CALLS]", "get", "_", "weather", "[ARGS]", '{"', "city",
            '":', ' "', "Paris", '"}', "</s>",
            "[TOOL_RESULTS]", '{"', "temperature", '":', " ", "2", "0", ",",
            ' "', "unit", '":', ' "', "c", "elsius", '"}', "[/TOOL_RESULTS]",
        ]  # fmt: skip
        expected = 'I am an AI[{"type": "function", "function": {"name": "get_weather", "description": "Gets the current weather in a city.", "parameters": {"type": "object", "properties": {"city": {"type": "string", "description": "The city name"}}, "required": ["city"]}}}]Hello world ![TOOL_CALLS]get_weather{"city": "Paris"}{"temperature": 20, "unit": "celsius"}'  # noqa: E501
    else:
        tokens = [
            "<s>", "[AVAILABLE_TOOLS]", "▁[", '{"', "type", '":', '▁"',
            "function", '",', '▁"', "function", '":', '▁{"', "name", '":',
            '▁"', "get", "_", "we", "ather", '",', '▁"', "description", '":',
            '▁"', "Get", "s", "▁the", "▁current", "▁weather", "▁in", "▁a",
            "▁city", '.",', '▁"', "parameters", '":', '▁{"', "type", '":',
            '▁"', "object", '",', '▁"', "properties", '":', '▁{"', "city",
            '":', '▁{"', "type", '":', '▁"', "string", '",', '▁"',
            "description", '":', '▁"', "The", "▁city", "▁name", '"', "}},",
            '▁"', "required", '":', '▁["', "city", '"]', "}}", "}]",
            "[/AVAILABLE_TOOLS]",
            "[INST]", "▁I", "▁am", "▁an", "▁AI", "<0x0A>", "<0x0A>", "Hello",
            "▁world", "▁!", "[/INST]",
            "[TOOL_CALLS]", "▁[", '{"', "name", '":', '▁"', "get", "_", "we",
            "ather", '",', '▁"', "arguments", '":', '▁{"', "city", '":', '▁"',
            "Par", "is", '"},', '▁"', "id", '":', '▁"', *"123456789", '"',
            "}]", "</s>",
            "[TOOL_RESULTS]", '▁{"', "content", '":', '▁{"', "t", "emperature",
            '":', "▁", "2", "0", ",", '▁"', "unit", '":', '▁"', "c", "els",
            "ius", '"},', '▁"', "call", "_", "id", '":', '▁"', *"123456789",
            '"}', "[/TOOL_RESULTS]",
        ]  # fmt: skip
        expected = '[{"type": "function", "function": {"name": "get_weather", "description": "Gets the current weather in a city.", "parameters": {"type": "object", "properties": {"city": {"type": "string", "description": "The city name"}}, "required": ["city"]}}}] I am an AI\n\nHello world ![TOOL_CALLS][{"name": "get_weather", "arguments": {"city": "Paris"}, "id": "123456789"}] {"content": {"temperature": 20, "unit": "celsius"}, "call_id": "123456789"}'  # noqa: E501

    assert mistral_tokenizer.convert_tokens_to_string(tokens) == expected
    assert mistral_tokenizer.convert_tokens_to_string([]) == ""
|
|
|
|
@pytest.mark.parametrize(
|
|
"skip_special_tokens,tuple_expected_tokens",
|
|
(
|
|
(
|
|
True,
|
|
(
|
|
[
|
|
"▁[",
|
|
'{"',
|
|
"type",
|
|
'":',
|
|
'▁"',
|
|
"function",
|
|
'",',
|
|
'▁"',
|
|
"function",
|
|
'":',
|
|
'▁{"',
|
|
"name",
|
|
'":',
|
|
'▁"',
|
|
"get",
|
|
"_",
|
|
"we",
|
|
"ather",
|
|
'",',
|
|
'▁"',
|
|
"description",
|
|
'":',
|
|
'▁"',
|
|
"Get",
|
|
"s",
|
|
"▁the",
|
|
"▁current",
|
|
"▁weather",
|
|
"▁in",
|
|
"▁a",
|
|
"▁city",
|
|
'.",',
|
|
'▁"',
|
|
"parameters",
|
|
'":',
|
|
'▁{"',
|
|
"type",
|
|
'":',
|
|
'▁"',
|
|
"object",
|
|
'",',
|
|
'▁"',
|
|
"properties",
|
|
'":',
|
|
'▁{"',
|
|
"city",
|
|
'":',
|
|
'▁{"',
|
|
"type",
|
|
'":',
|
|
'▁"',
|
|
"string",
|
|
'",',
|
|
'▁"',
|
|
"description",
|
|
'":',
|
|
'▁"',
|
|
"The",
|
|
"▁city",
|
|
"▁name",
|
|
'"',
|
|
"}},",
|
|
'▁"',
|
|
"required",
|
|
'":',
|
|
'▁["',
|
|
"city",
|
|
'"]',
|
|
"}}",
|
|
"}]",
|
|
"▁I",
|
|
"▁am",
|
|
"▁an",
|
|
"▁AI",
|
|
"<0x0A>",
|
|
"<0x0A>",
|
|
"Hello",
|
|
"▁world",
|
|
"▁!",
|
|
"[TOOL_CALLS]",
|
|
"▁[",
|
|
'{"',
|
|
"name",
|
|
'":',
|
|
'▁"',
|
|
"get",
|
|
"_",
|
|
"we",
|
|
"ather",
|
|
'",',
|
|
'▁"',
|
|
"arguments",
|
|
'":',
|
|
'▁{"',
|
|
"city",
|
|
'":',
|
|
'▁"',
|
|
"Par",
|
|
"is",
|
|
'"},',
|
|
'▁"',
|
|
"id",
|
|
'":',
|
|
'▁"',
|
|
"1",
|
|
"2",
|
|
"3",
|
|
"4",
|
|
"5",
|
|
"6",
|
|
"7",
|
|
"8",
|
|
"9",
|
|
'"',
|
|
"}]",
|
|
'▁{"',
|
|
"content",
|
|
'":',
|
|
'▁{"',
|
|
"t",
|
|
"emperature",
|
|
'":',
|
|
"▁",
|
|
"2",
|
|
"0",
|
|
",",
|
|
'▁"',
|
|
"unit",
|
|
'":',
|
|
'▁"',
|
|
"c",
|
|
"els",
|
|
"ius",
|
|
'"},',
|
|
'▁"',
|
|
"call",
|
|
"_",
|
|
"id",
|
|
'":',
|
|
'▁"',
|
|
"1",
|
|
"2",
|
|
"3",
|
|
"4",
|
|
"5",
|
|
"6",
|
|
"7",
|
|
"8",
|
|
"9",
|
|
'"}',
|
|
],
|
|
[
|
|
"I",
|
|
" am",
|
|
" an",
|
|
" AI",
|
|
"[",
|
|
'{"',
|
|
"type",
|
|
'":',
|
|
' "',
|
|
"function",
|
|
'",',
|
|
' "',
|
|
"function",
|
|
'":',
|
|
' {"',
|
|
"name",
|
|
'":',
|
|
' "',
|
|
"get",
|
|
"_",
|
|
"weather",
|
|
'",',
|
|
' "',
|
|
"description",
|
|
'":',
|
|
' "',
|
|
"G",
|
|
"ets",
|
|
" the",
|
|
" current",
|
|
" weather",
|
|
" in",
|
|
" a",
|
|
" city",
|
|
'.",',
|
|
' "',
|
|
"parameters",
|
|
'":',
|
|
' {"',
|
|
"type",
|
|
'":',
|
|
' "',
|
|
"object",
|
|
'",',
|
|
' "',
|
|
"properties",
|
|
'":',
|
|
' {"',
|
|
"city",
|
|
'":',
|
|
' {"',
|
|
"type",
|
|
'":',
|
|
' "',
|
|
"string",
|
|
'",',
|
|
' "',
|
|
"description",
|
|
'":',
|
|
' "',
|
|
"The",
|
|
" city",
|
|
" name",
|
|
'"',
|
|
"}},",
|
|
' "',
|
|
"required",
|
|
'":',
|
|
' ["',
|
|
"city",
|
|
'"]',
|
|
"}}",
|
|
"}]",
|
|
"Hello",
|
|
" world",
|
|
" !",
|
|
"[TOOL_CALLS]",
|
|
"get",
|
|
"_",
|
|
"weather",
|
|
'{"',
|
|
"city",
|
|
'":',
|
|
' "',
|
|
"Paris",
|
|
'"}',
|
|
'{"',
|
|
"temperature",
|
|
'":',
|
|
" ",
|
|
"2",
|
|
"0",
|
|
",",
|
|
' "',
|
|
"unit",
|
|
'":',
|
|
' "',
|
|
"c",
|
|
"elsius",
|
|
'"}',
|
|
],
|
|
),
|
|
),
|
|
(
|
|
False,
|
|
(
|
|
[
|
|
"<s>",
|
|
"[AVAILABLE_TOOLS]",
|
|
"▁[",
|
|
'{"',
|
|
"type",
|
|
'":',
|
|
'▁"',
|
|
"function",
|
|
'",',
|
|
'▁"',
|
|
"function",
|
|
'":',
|
|
'▁{"',
|
|
"name",
|
|
'":',
|
|
'▁"',
|
|
"get",
|
|
"_",
|
|
"we",
|
|
"ather",
|
|
'",',
|
|
'▁"',
|
|
"description",
|
|
'":',
|
|
'▁"',
|
|
"Get",
|
|
"s",
|
|
"▁the",
|
|
"▁current",
|
|
"▁weather",
|
|
"▁in",
|
|
"▁a",
|
|
"▁city",
|
|
'.",',
|
|
'▁"',
|
|
"parameters",
|
|
'":',
|
|
'▁{"',
|
|
"type",
|
|
'":',
|
|
'▁"',
|
|
"object",
|
|
'",',
|
|
'▁"',
|
|
"properties",
|
|
'":',
|
|
'▁{"',
|
|
"city",
|
|
'":',
|
|
'▁{"',
|
|
"type",
|
|
'":',
|
|
'▁"',
|
|
"string",
|
|
'",',
|
|
'▁"',
|
|
"description",
|
|
'":',
|
|
'▁"',
|
|
"The",
|
|
"▁city",
|
|
"▁name",
|
|
'"',
|
|
"}},",
|
|
'▁"',
|
|
"required",
|
|
'":',
|
|
'▁["',
|
|
"city",
|
|
'"]',
|
|
"}}",
|
|
"}]",
|
|
"[/AVAILABLE_TOOLS]",
|
|
"[INST]",
|
|
"▁I",
|
|
"▁am",
|
|
"▁an",
|
|
"▁AI",
|
|
"<0x0A>",
|
|
"<0x0A>",
|
|
"Hello",
|
|
"▁world",
|
|
"▁!",
|
|
"[/INST]",
|
|
"[TOOL_CALLS]",
|
|
"▁[",
|
|
'{"',
|
|
"name",
|
|
'":',
|
|
'▁"',
|
|
"get",
|
|
"_",
|
|
"we",
|
|
"ather",
|
|
'",',
|
|
'▁"',
|
|
"arguments",
|
|
'":',
|
|
'▁{"',
|
|
"city",
|
|
'":',
|
|
'▁"',
|
|
"Par",
|
|
"is",
|
|
'"},',
|
|
'▁"',
|
|
"id",
|
|
'":',
|
|
'▁"',
|
|
"1",
|
|
"2",
|
|
"3",
|
|
"4",
|
|
"5",
|
|
"6",
|
|
"7",
|
|
"8",
|
|
"9",
|
|
'"',
|
|
"}]",
|
|
"</s>",
|
|
"[TOOL_RESULTS]",
|
|
'▁{"',
|
|
"content",
|
|
'":',
|
|
'▁{"',
|
|
"t",
|
|
"emperature",
|
|
'":',
|
|
"▁",
|
|
"2",
|
|
"0",
|
|
",",
|
|
'▁"',
|
|
"unit",
|
|
'":',
|
|
'▁"',
|
|
"c",
|
|
"els",
|
|
"ius",
|
|
'"},',
|
|
'▁"',
|
|
"call",
|
|
"_",
|
|
"id",
|
|
'":',
|
|
'▁"',
|
|
"1",
|
|
"2",
|
|
"3",
|
|
"4",
|
|
"5",
|
|
"6",
|
|
"7",
|
|
"8",
|
|
"9",
|
|
'"}',
|
|
"[/TOOL_RESULTS]",
|
|
],
|
|
[
|
|
"<s>",
|
|
"[SYSTEM_PROMPT]",
|
|
"I",
|
|
" am",
|
|
" an",
|
|
" AI",
|
|
"[/SYSTEM_PROMPT]",
|
|
"[AVAILABLE_TOOLS]",
|
|
"[",
|
|
'{"',
|
|
"type",
|
|
'":',
|
|
' "',
|
|
"function",
|
|
'",',
|
|
' "',
|
|
"function",
|
|
'":',
|
|
' {"',
|
|
"name",
|
|
'":',
|
|
' "',
|
|
"get",
|
|
"_",
|
|
"weather",
|
|
'",',
|
|
' "',
|
|
"description",
|
|
'":',
|
|
' "',
|
|
"G",
|
|
"ets",
|
|
" the",
|
|
" current",
|
|
" weather",
|
|
" in",
|
|
" a",
|
|
" city",
|
|
'.",',
|
|
' "',
|
|
"parameters",
|
|
'":',
|
|
' {"',
|
|
"type",
|
|
'":',
|
|
' "',
|
|
"object",
|
|
'",',
|
|
' "',
|
|
"properties",
|
|
'":',
|
|
' {"',
|
|
"city",
|
|
'":',
|
|
' {"',
|
|
"type",
|
|
'":',
|
|
' "',
|
|
"string",
|
|
'",',
|
|
' "',
|
|
"description",
|
|
'":',
|
|
' "',
|
|
"The",
|
|
" city",
|
|
" name",
|
|
'"',
|
|
"}},",
|
|
' "',
|
|
"required",
|
|
'":',
|
|
' ["',
|
|
"city",
|
|
'"]',
|
|
"}}",
|
|
"}]",
|
|
"[/AVAILABLE_TOOLS]",
|
|
"[INST]",
|
|
"Hello",
|
|
" world",
|
|
" !",
|
|
"[/INST]",
|
|
"[TOOL_CALLS]",
|
|
"get",
|
|
"_",
|
|
"weather",
|
|
"[ARGS]",
|
|
'{"',
|
|
"city",
|
|
'":',
|
|
' "',
|
|
"Paris",
|
|
'"}',
|
|
"</s>",
|
|
"[TOOL_RESULTS]",
|
|
'{"',
|
|
"temperature",
|
|
'":',
|
|
" ",
|
|
"2",
|
|
"0",
|
|
",",
|
|
' "',
|
|
"unit",
|
|
'":',
|
|
' "',
|
|
"c",
|
|
"elsius",
|
|
'"}',
|
|
"[/TOOL_RESULTS]",
|
|
],
|
|
),
|
|
),
|
|
),
|
|
)
|
|
def test_convert_ids_to_tokens(
    self,
    mistral_tokenizer: MistralTokenizer,
    skip_special_tokens: bool,
    tuple_expected_tokens: tuple[list[str], list[str]],
):
    """Check ``convert_ids_to_tokens`` on a fixed, hard-coded id sequence.

    Two id sequences are hard-coded below: one for non-Tekken tokenizers and
    one for Tekken tokenizers. ``mistral_tokenizer.is_tekken`` selects both
    the sequence that is decoded and the matching entry of
    ``tuple_expected_tokens`` (index 0 for non-Tekken, index 1 for Tekken).
    The decoded tokens must equal that expected entry exactly.

    The test also checks that converting an empty id list returns an empty
    list. That call omits ``skip_special_tokens``, so it uses the method's
    default.
    """
    # fmt: off
    non_tekken_ids = [
        1, 6, 1501, 7567, 1891, 2032, 1113, 3396, 1316, 1113,
        3396, 2032, 10598, 1629, 2032, 1113, 1295, 29498, 1537, 1991,
        1316, 1113, 7286, 2032, 1113, 2226, 29481, 1040, 2636, 8854,
        1065, 1032, 3758, 9959, 1113, 12206, 2032, 10598, 1891, 2032,
        1113, 3582, 1316, 1113, 11491, 2032, 10598, 19141, 2032, 10598,
        1891, 2032, 1113, 2195, 1316, 1113, 7286, 2032, 1113, 1782,
        3758, 1909, 29507, 11549, 1113, 11661, 2032, 8135, 19141, 3010,
        1743, 10925, 7, 3, 1083, 1605, 1164, 16875, 781, 781,
        16998, 2294, 1686, 4, 5, 1501, 7567, 1629, 2032, 1113,
        1295, 29498, 1537, 1991, 1316, 1113, 17452, 2032, 10598, 19141,
        2032, 1113, 4684, 1046, 8474, 1113, 1081, 2032, 1113, 29508,
        29518, 29538, 29549, 29550, 29552, 29555, 29551, 29542, 29507, 10925,
        2, 8, 10598, 4557, 2032, 10598, 29475, 17329, 2032, 29473,
        29518, 29502, 29493, 1113, 6074, 2032, 1113, 29485, 1958, 3938,
        8474, 1113, 3613, 29498, 1081, 2032, 1113, 29508, 29518, 29538,
        29549, 29550, 29552, 29555, 29551, 29542, 18163, 9,
    ]
    tekken_ids = [
        1, 17, 1073, 1855, 1420, 26554, 18, 5, 1091, 19227,
        4994, 2811, 1429, 5165, 1897, 1429, 5165, 2811, 16753, 2391,
        2811, 1429, 1689, 1095, 45629, 1897, 1429, 14653, 2811, 1429,
        1071, 3083, 1278, 3519, 17253, 1294, 1261, 5970, 39249, 1429,
        26204, 2811, 16753, 4994, 2811, 1429, 6371, 1897, 1429, 48649,
        2811, 16753, 29363, 2811, 16753, 4994, 2811, 1429, 3607, 1897,
        1429, 14653, 2811, 1429, 1784, 5970, 2564, 1034, 47579, 1429,
        15760, 2811, 12161, 29363, 4964, 2821, 27028, 6, 3, 22177,
        4304, 2662, 4, 9, 1689, 1095, 45629, 32, 19227, 29363,
        2811, 1429, 42572, 46005, 2, 7, 19227, 113824, 2811, 1032,
        1050, 1048, 1044, 1429, 8979, 2811, 1429, 1099, 79092, 46005,
        8,
    ]
    # fmt: on

    # ``is_tekken`` (assumed to be a bool) is used directly as a tuple index,
    # so the decoded ids and their expected tokens are always picked together.
    variant = mistral_tokenizer.is_tekken
    ids = (non_tekken_ids, tekken_ids)[variant]
    expected_tokens = tuple_expected_tokens[variant]

    actual_tokens = mistral_tokenizer.convert_ids_to_tokens(
        ids, skip_special_tokens=skip_special_tokens
    )
    assert actual_tokens == expected_tokens

    # Edge case: no ids in, no tokens out.
    assert mistral_tokenizer.convert_ids_to_tokens([]) == []
|