[Misc] Update TokenizerLike interface and move get_cached_tokenizer (#29730)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Author: Cyrus Leung
Date: 2025-11-30 14:59:47 +08:00
Committed by: GitHub
Parent: 9381b5cde0
Commit: 2afcec4dec
15 changed files with 260 additions and 174 deletions
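Note (not part of the commit): for downstream code, the user-visible effect of the move is the import path. A minimal sketch, assuming a vLLM build that contains this commit:

# New location introduced by this commit:
from vllm.tokenizers.hf import get_cached_tokenizer

# The old path keeps working until v0.13, but now goes through the module-level
# __getattr__ shim added below and emits a DeprecationWarning:
# from vllm.transformers_utils.tokenizer import get_cached_tokenizer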

View File

@@ -61,8 +61,8 @@ steps:
   - pytest -v -s -m 'not cpu_test' multimodal
   - pytest -v -s utils_

-- label: Async Engine, Inputs, Utils, Worker, Config Test (CPU) # 4 mins
-  timeout_in_minutes: 10
+- label: Async Engine, Inputs, Utils, Worker, Config Test (CPU) # 15min
+  timeout_in_minutes: 20
   mirror_hardwares: [amdexperimental, amdproduction]
   agent_pool: mi325_1
   # grade: Blocking
@@ -72,6 +72,7 @@ steps:
   - tests/test_outputs.py
   - tests/multimodal
   - tests/standalone_tests/lazy_imports.py
+  - tests/tokenizers_
   - tests/transformers_utils
   - tests/config
   no_gpu: true
@@ -80,6 +81,7 @@ steps:
   - pytest -v -s test_inputs.py
   - pytest -v -s test_outputs.py
   - pytest -v -s -m 'cpu_test' multimodal
+  - pytest -v -s tokenizers_
   - pytest -v -s transformers_utils
   - pytest -v -s config
@@ -308,23 +310,20 @@ steps:
   - pytest -v -s test_regression.py
   working_dir: "/vllm-workspace/tests" # optional

-- label: Engine Test # 25min
-  timeout_in_minutes: 40
+- label: Engine Test # 9min
+  timeout_in_minutes: 15
   mirror_hardwares: [amdexperimental, amdproduction]
   agent_pool: mi325_1
   # grade: Blocking
   source_file_dependencies:
   - vllm/
   - tests/engine
-  - tests/tokenizers_
   - tests/test_sequence
   - tests/test_config
   - tests/test_logger
   - tests/test_vllm_port
   commands:
   - pytest -v -s engine test_sequence.py test_config.py test_logger.py test_vllm_port.py
-  # OOM in the CI unless we run this separately
-  - pytest -v -s tokenizers_

 - label: V1 Test e2e + engine # 30min
   timeout_in_minutes: 45

View File

@@ -57,14 +57,15 @@ steps:
   - pytest -v -s -m 'not cpu_test' multimodal
   - pytest -v -s utils_

-- label: Async Engine, Inputs, Utils, Worker, Config Test (CPU) # 4 mins
-  timeout_in_minutes: 10
+- label: Async Engine, Inputs, Utils, Worker, Config Test (CPU) # 15min
+  timeout_in_minutes: 20
   source_file_dependencies:
   - vllm/
   - tests/test_inputs.py
   - tests/test_outputs.py
   - tests/multimodal
   - tests/standalone_tests/lazy_imports.py
+  - tests/tokenizers_
   - tests/transformers_utils
   - tests/config
   no_gpu: true
@@ -73,6 +74,7 @@ steps:
   - pytest -v -s test_inputs.py
   - pytest -v -s test_outputs.py
   - pytest -v -s -m 'cpu_test' multimodal
+  - pytest -v -s tokenizers_
   - pytest -v -s transformers_utils
   - pytest -v -s config
@@ -276,21 +278,18 @@ steps:
   - pytest -v -s test_regression.py
   working_dir: "/vllm-workspace/tests" # optional

-- label: Engine Test # 25min
-  timeout_in_minutes: 40
+- label: Engine Test # 9min
+  timeout_in_minutes: 15
   mirror_hardwares: [amdexperimental]
   source_file_dependencies:
   - vllm/
   - tests/engine
-  - tests/tokenizers_
   - tests/test_sequence
   - tests/test_config
   - tests/test_logger
   - tests/test_vllm_port
   commands:
   - pytest -v -s engine test_sequence.py test_config.py test_logger.py test_vllm_port.py
-  # OOM in the CI unless we run this separately
-  - pytest -v -s tokenizers_

 - label: V1 Test e2e + engine # 30min
   timeout_in_minutes: 45

View File

@@ -21,7 +21,7 @@ Let's say we want to serve the popular Qwen model by running `vllm serve Qwen/Qw
 Beyond that, there are two more things vLLM depends on Hugging Face for.

-1. **Tokenizer**: vLLM uses the tokenizer from Hugging Face to tokenize the input text. The tokenizer is loaded using [AutoTokenizer.from_pretrained](https://huggingface.co/docs/transformers/en/model_doc/auto#transformers.AutoTokenizer.from_pretrained) with the `model` argument as the model name and the `--revision` argument as the revision. It is also possible to use a tokenizer from another model by specifying the `--tokenizer` argument in the `vllm serve` command. Other relevant arguments are `--tokenizer-revision` and `--tokenizer-mode`. Please check Hugging Face's documentation for the meaning of these arguments. This part of the logic can be found in the [get_tokenizer](https://github.com/vllm-project/vllm/blob/127c07480ecea15e4c2990820c457807ff78a057/vllm/transformers_utils/tokenizer.py#L87) function. After obtaining the tokenizer, notably, vLLM will cache some expensive attributes of the tokenizer in [get_cached_tokenizer](https://github.com/vllm-project/vllm/blob/127c07480ecea15e4c2990820c457807ff78a057/vllm/transformers_utils/tokenizer.py#L24).
+1. **Tokenizer**: vLLM uses the tokenizer from Hugging Face to tokenize the input text. The tokenizer is loaded using [AutoTokenizer.from_pretrained](https://huggingface.co/docs/transformers/en/model_doc/auto#transformers.AutoTokenizer.from_pretrained) with the `model` argument as the model name and the `--revision` argument as the revision. It is also possible to use a tokenizer from another model by specifying the `--tokenizer` argument in the `vllm serve` command. Other relevant arguments are `--tokenizer-revision` and `--tokenizer-mode`. Please check Hugging Face's documentation for the meaning of these arguments. This part of the logic can be found in the [get_tokenizer](https://github.com/vllm-project/vllm/blob/127c07480ecea15e4c2990820c457807ff78a057/vllm/transformers_utils/tokenizer.py#L87) function. After obtaining the tokenizer, notably, vLLM will cache some expensive attributes of the tokenizer in [vllm.tokenizers.hf.get_cached_tokenizer][].
 2. **Model weight**: vLLM downloads the model weight from the Hugging Face model hub using the `model` argument as the model name and the `--revision` argument as the revision. vLLM provides the argument `--load-format` to control what files to download from the model hub. By default, it will try to load the weights in the safetensors format and fall back to the PyTorch bin format if the safetensors format is not available. We can also pass `--load-format dummy` to skip downloading the weights.
     - It is recommended to use the safetensors format, as it is efficient for loading in distributed inference and also safe from arbitrary code execution. See the [documentation](https://huggingface.co/docs/safetensors/en/index) for more information on the safetensors format. This part of the logic can be found [here](https://github.com/vllm-project/vllm/blob/10b67d865d92e376956345becafc249d4c3c0ab7/vllm/model_executor/model_loader/loader.py#L385). Please note that:
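As a rough illustration of the caching step described above (not part of the diff; "gpt2" is just the repo id used in the tests), the proxy returned by the relocated helper precomputes the attributes that transformers would otherwise recompute on every access:

from transformers import AutoTokenizer

from vllm.tokenizers.hf import get_cached_tokenizer

hf_tokenizer = AutoTokenizer.from_pretrained("gpt2")
cached = get_cached_tokenizer(hf_tokenizer)

# Served from values captured once at wrap time rather than recomputed
# by transformers on each access:
print(cached.all_special_ids)
print(cached.all_special_tokens)
print(len(cached), cached.max_token_id)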

View File

@@ -7,7 +7,7 @@ import pytest
 from transformers import AutoTokenizer

 from vllm.tokenizers import TokenizerLike
-from vllm.transformers_utils.tokenizer import get_cached_tokenizer
+from vllm.tokenizers.hf import get_cached_tokenizer


 @pytest.mark.parametrize("model_id", ["gpt2", "zai-org/chatglm3-6b"])

View File

@@ -356,8 +356,8 @@ class TestMistralTokenizer:
         )
         attn_mask = [1 for _ in range(len(token_ids))]

-        # Test 1: default
-        assert mistral_tokenizer("Hello world !") == {
+        # Test 1: no special tokens
+        assert mistral_tokenizer("Hello world !", add_special_tokens=False) == {
             "attention_mask": attn_mask[1:],
             "input_ids": token_ids[1:],
         }
@@ -381,7 +381,7 @@ class TestMistralTokenizer:
"input_ids": token_ids, "input_ids": token_ids,
} }
# Test 5: empty string # Test 5: empty string
assert mistral_tokenizer("") == { assert mistral_tokenizer("", add_special_tokens=False) == {
"attention_mask": [], "attention_mask": [],
"input_ids": [], "input_ids": [],
} }

View File

@@ -17,20 +17,26 @@ class TestTokenizer(TokenizerLike):
     def eos_token_id(self) -> int:
         return 1

+    @property
+    def pad_token_id(self) -> int:
+        return 2
+
+    @property
+    def is_fast(self) -> bool:
+        return True
+

 def test_customized_tokenizer():
-    TokenizerRegistry.register(
-        "test_tokenizer",
-        __name__,
-        TestTokenizer.__name__,
-    )
+    TokenizerRegistry.register("test_tokenizer", __name__, TestTokenizer.__name__)

     tokenizer = TokenizerRegistry.get_tokenizer("test_tokenizer")
     assert isinstance(tokenizer, TestTokenizer)
     assert tokenizer.bos_token_id == 0
     assert tokenizer.eos_token_id == 1
+    assert tokenizer.pad_token_id == 2

     tokenizer = get_tokenizer("test_tokenizer", tokenizer_mode="custom")
     assert isinstance(tokenizer, TestTokenizer)
     assert tokenizer.bos_token_id == 0
     assert tokenizer.eos_token_id == 1
+    assert tokenizer.pad_token_id == 2

View File

@@ -27,7 +27,7 @@ ALLOWED_FILES = {
"vllm/distributed/device_communicators/shm_broadcast.py", "vllm/distributed/device_communicators/shm_broadcast.py",
"vllm/distributed/device_communicators/shm_object_storage.py", "vllm/distributed/device_communicators/shm_object_storage.py",
"vllm/utils/hashing.py", "vllm/utils/hashing.py",
"tests/tokenizers_/test_cached_tokenizer.py", "tests/tokenizers_/test_hf.py",
"tests/utils_/test_hashing.py", "tests/utils_/test_hashing.py",
"benchmarks/kernels/graph_machete_bench.py", "benchmarks/kernels/graph_machete_bench.py",
"benchmarks/kernels/benchmark_lora.py", "benchmarks/kernels/benchmark_lora.py",

View File

@@ -72,7 +72,7 @@ from vllm.pooling_params import PoolingParams
 from vllm.sampling_params import BeamSearchParams, RequestOutputKind, SamplingParams
 from vllm.tasks import PoolingTask
 from vllm.tokenizers import MistralTokenizer, TokenizerLike
-from vllm.transformers_utils.tokenizer import get_cached_tokenizer
+from vllm.tokenizers.hf import get_cached_tokenizer
 from vllm.usage.usage_lib import UsageContext
 from vllm.utils.collection_utils import as_iter, is_list_of
 from vllm.utils.counter import Counter

View File

@@ -51,8 +51,8 @@ def _cosine_similarity(
     for emb_1, emb_2 in zip(embed_1, embed_2):
         pair_score = scorer(emb_1.outputs.data, emb_2.outputs.data)

-        padding = []
-        if (pad_token_id := getattr(tokenizer, "pad_token_id", None)) is not None:
+        padding: list[int] = []
+        if (pad_token_id := tokenizer.pad_token_id) is not None:
             padding = [pad_token_id]

         tokens = emb_1.prompt_token_ids + padding + emb_2.prompt_token_ids

View File

@@ -1,8 +1,9 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project

+from .hf import HfTokenizer
 from .mistral import MistralTokenizer
 from .protocol import TokenizerLike
 from .registry import TokenizerRegistry

-__all__ = ["TokenizerLike", "MistralTokenizer", "TokenizerRegistry"]
+__all__ = ["TokenizerLike", "HfTokenizer", "MistralTokenizer", "TokenizerRegistry"]

vllm/tokenizers/hf.py (new file, 122 lines added)
View File

@@ -0,0 +1,122 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import contextlib
import copy
from pathlib import Path
from typing import TYPE_CHECKING

from transformers import AutoTokenizer

from vllm.transformers_utils.config import get_sentence_transformer_tokenizer_config

from .protocol import TokenizerLike

if TYPE_CHECKING:
    from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast


def get_cached_tokenizer(
    tokenizer: "PreTrainedTokenizer | PreTrainedTokenizerFast",
) -> TokenizerLike:
    """
    By default, transformers will recompute multiple tokenizer properties
    each time they are called, leading to a significant slowdown.
    This proxy caches these properties for faster access.
    """
    cached_tokenizer = copy.copy(tokenizer)

    tokenizer_all_special_ids = tokenizer.all_special_ids
    tokenizer_all_special_tokens = tokenizer.all_special_tokens
    tokenizer_vocab = tokenizer.get_vocab()
    tokenizer_len = len(tokenizer)

    max_token_id = max(tokenizer_vocab.values())
    # Some tokenizers (e.g., QwenTokenizer) have special tokens that
    # are added and included in the implementation of the vocab_size
    # property, but not in get_vocab(); if there is an implementation
    # of vocab size, we should take the greater value.
    if hasattr(tokenizer, "vocab_size"):
        with contextlib.suppress(NotImplementedError):
            max_token_id = max(max_token_id, tokenizer.vocab_size)

    class CachedTokenizer(tokenizer.__class__):  # type: ignore
        @property
        def all_special_ids(self) -> list[int]:
            return tokenizer_all_special_ids

        @property
        def all_special_tokens(self) -> list[str]:
            return tokenizer_all_special_tokens

        @property
        def max_token_id(self) -> int:
            return max_token_id

        def get_vocab(self) -> dict[str, int]:
            return tokenizer_vocab

        def __len__(self) -> int:
            return tokenizer_len

        def __reduce__(self):
            return get_cached_tokenizer, (tokenizer,)

    CachedTokenizer.__name__ = f"Cached{tokenizer.__class__.__name__}"

    cached_tokenizer.__class__ = CachedTokenizer
    return cached_tokenizer  # type: ignore


class HfTokenizer(TokenizerLike):
    @classmethod
    def from_pretrained(
        cls,
        path_or_repo_id: str | Path,
        *args,
        trust_remote_code: bool = False,
        revision: str | None = None,
        download_dir: str | None = None,
        **kwargs,
    ) -> "TokenizerLike":
        try:
            tokenizer = AutoTokenizer.from_pretrained(
                path_or_repo_id,
                *args,
                trust_remote_code=trust_remote_code,
                revision=revision,
                cache_dir=download_dir,
                **kwargs,
            )
        except ValueError as e:
            # If the error pertains to the tokenizer class not existing or not
            # currently being imported,
            # suggest using the --trust-remote-code flag.
            if not trust_remote_code and (
                "does not exist or is not currently imported." in str(e)
                or "requires you to execute the tokenizer file" in str(e)
            ):
                err_msg = (
                    "Failed to load the tokenizer. If the tokenizer "
                    "is a custom tokenizer not yet available in the "
                    "HuggingFace transformers library, consider "
                    "setting `trust_remote_code=True` in LLM or using "
                    "the `--trust-remote-code` flag in the CLI."
                )
                raise RuntimeError(err_msg) from e
            else:
                raise e

        # The special_tokens in tokenizer should also be
        # controlled by do_lower_case in encoder_config
        encoder_config = get_sentence_transformer_tokenizer_config(
            path_or_repo_id, revision
        )
        if isinstance(encoder_config, dict) and encoder_config.get(
            "do_lower_case", False
        ):
            special_tokens_map = {
                k: v.lower() for k, v in tokenizer.special_tokens_map.items()
            }
            tokenizer.add_special_tokens(special_tokens_map)

        return get_cached_tokenizer(tokenizer)
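One detail worth noting in the new file: `CachedTokenizer` is created dynamically inside `get_cached_tokenizer`, so it would not normally survive pickling. The `__reduce__` hook sidesteps that by re-running the wrapper on unpickling. A small sketch of the behaviour (illustrative only; "gpt2" is an arbitrary example):

import pickle

from transformers import AutoTokenizer

from vllm.tokenizers.hf import get_cached_tokenizer

cached = get_cached_tokenizer(AutoTokenizer.from_pretrained("gpt2"))

# __reduce__ returns (get_cached_tokenizer, (tokenizer,)), so unpickling
# rebuilds the proxy instead of trying to pickle the dynamic class itself.
restored = pickle.loads(pickle.dumps(cached))
assert restored.get_vocab() == cached.get_vocab()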

View File

@@ -1,6 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project

+from pathlib import Path
 from typing import TYPE_CHECKING, Any, cast

 from vllm.logger import init_logger
@@ -12,6 +12,7 @@ if TYPE_CHECKING:
         ChatCompletionRequest as MistralChatCompletionRequest,
     )
     from mistral_common.tokens.tokenizers.tekken import Tekkenizer
+    from transformers import BatchEncoding
     from transformers.tokenization_mistral_common import (
         MistralCommonTokenizer as TransformersMistralTokenizer,
     )
@@ -165,7 +166,35 @@ def _tekken_token_to_id(tokenizer: "Tekkenizer", t: str | bytes) -> int:
 class MistralTokenizer(TokenizerLike):
+    @classmethod
+    def from_pretrained(
+        cls,
+        path_or_repo_id: str | Path,
+        *args,
+        trust_remote_code: bool = False,
+        revision: str | None = None,
+        download_dir: str | None = None,
+        **kwargs,
+    ) -> "MistralTokenizer":
+        from mistral_common.protocol.instruct.validator import ValidationMode
+        from transformers.tokenization_mistral_common import (
+            MistralCommonTokenizer as TransformersMistralTokenizer,
+        )
+
+        tokenizer = TransformersMistralTokenizer.from_pretrained(
+            path_or_repo_id,
+            *args,
+            mode=ValidationMode.test,
+            cache_dir=download_dir,
+            revision="main" if revision is None else revision,
+            **kwargs,
+        )
+        return cls(tokenizer)
+
     def __init__(self, tokenizer: "TransformersMistralTokenizer") -> None:
+        super().__init__()
+
         from mistral_common.protocol.instruct.validator import ValidationMode
         from mistral_common.tokens.tokenizers.sentencepiece import (
             SentencePieceTokenizer,
@@ -211,22 +240,6 @@ class MistralTokenizer(TokenizerLike):
         self._vocab = self.tokenizer._vocab
         self._max_token_id = self.vocab_size - 1

-    @classmethod
-    def from_pretrained(
-        cls, path_or_repo_id: str, *, revision: str | None = None
-    ) -> "MistralTokenizer":
-        from mistral_common.protocol.instruct.validator import ValidationMode
-        from transformers.tokenization_mistral_common import (
-            MistralCommonTokenizer as TransformersMistralTokenizer,
-        )
-
-        str_revision = "main" if revision is None else revision
-
-        return cls(
-            TransformersMistralTokenizer.from_pretrained(
-                path_or_repo_id, revision=str_revision, mode=ValidationMode.test
-            )
-        )
-
     def _get_special_token_ids(self) -> list[int]:
         from mistral_common.tokens.tokenizers.sentencepiece import (
             SentencePieceTokenizer,
@@ -271,6 +284,10 @@ class MistralTokenizer(TokenizerLike):
     def eos_token_id(self) -> int:
         return self.tokenizer.eos_id

+    @property
+    def pad_token_id(self) -> int:
+        return self.tokenizer.pad_id
+
     @property
     def is_fast(self) -> bool:
         return True
@@ -298,12 +315,12 @@ class MistralTokenizer(TokenizerLike):
     def __call__(
         self,
-        text: str | list[str] | list[int],
+        text: str | list[str],
         text_pair: str | None = None,
-        add_special_tokens: bool = False,
+        add_special_tokens: bool = True,
         truncation: bool = False,
         max_length: int | None = None,
-    ):
+    ) -> "BatchEncoding":
         if text_pair is not None:
             raise ValueError(
                 "`text_pair` is not supported by `MistralTokenizer.__call__`."
@@ -342,13 +359,11 @@ class MistralTokenizer(TokenizerLike):
         text: str,
         truncation: bool | None = None,
         max_length: int | None = None,
-        add_special_tokens: bool | None = None,
+        add_special_tokens: bool = True,
     ) -> list[int]:
         # TODO(juliendenize): once https://github.com/huggingface/transformers/pull/41962
         # is in, directly call self.transformers_tokenizer.encode(...).
-        encoded = self.tokenizer.encode(
-            text, bos=add_special_tokens is not False, eos=False
-        )
+        encoded = self.tokenizer.encode(text, bos=add_special_tokens, eos=False)

         if truncation is not False and max_length is not None:
             return encoded[:max_length]
@@ -383,7 +398,7 @@ class MistralTokenizer(TokenizerLike):
             return_dict=False,
         )

-    def decode(self, ids: list[int] | int, skip_special_tokens: bool = True) -> str:
+    def decode(self, ids: list[int] | int, skip_special_tokens: bool = False) -> str:
         # TODO(juliendenize): once https://github.com/huggingface/transformers/pull/41962
         # is in, directly call self.transformers_tokenizer.decode(...).
         if isinstance(ids, int):
@@ -455,7 +470,7 @@ class MistralTokenizer(TokenizerLike):
     def convert_ids_to_tokens(
         self,
         ids: list[int],
-        skip_special_tokens: bool = True,
+        skip_special_tokens: bool = False,
     ) -> list[str]:
         from mistral_common.tokens.tokenizers.base import (
             SpecialTokenPolicy,

View File

@@ -1,11 +1,11 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project

+from pathlib import Path
 from typing import TYPE_CHECKING, Any, Protocol

-from typing_extensions import Self
-
 if TYPE_CHECKING:
+    from transformers import BatchEncoding
+
     from vllm.entrypoints.chat_utils import ChatCompletionMessageParam
@@ -13,11 +13,13 @@ class TokenizerLike(Protocol):
     @classmethod
     def from_pretrained(
         cls,
-        pretrained_model_name_or_path: str,
-        /,
-        *,
+        path_or_repo_id: str | Path,
+        *args,
+        trust_remote_code: bool = False,
         revision: str | None = None,
-    ) -> Self:
+        download_dir: str | None = None,
+        **kwargs,
+    ) -> "TokenizerLike":
         raise NotImplementedError

     @property
@@ -36,6 +38,10 @@ class TokenizerLike(Protocol):
     def eos_token_id(self) -> int:
         raise NotImplementedError

+    @property
+    def pad_token_id(self) -> int:
+        raise NotImplementedError
+
     @property
     def is_fast(self) -> bool:
         raise NotImplementedError
@@ -60,12 +66,12 @@ class TokenizerLike(Protocol):
     def __call__(
         self,
-        text: str | list[str] | list[int],
+        text: str | list[str],
         text_pair: str | None = None,
-        add_special_tokens: bool = False,
+        add_special_tokens: bool = True,
         truncation: bool = False,
         max_length: int | None = None,
-    ):
+    ) -> "BatchEncoding":
         raise NotImplementedError

     def get_vocab(self) -> dict[str, int]:
@@ -79,7 +85,7 @@ class TokenizerLike(Protocol):
         text: str,
         truncation: bool | None = None,
         max_length: int | None = None,
-        add_special_tokens: bool | None = None,
+        add_special_tokens: bool = True,
     ) -> list[int]:
         raise NotImplementedError
@@ -94,12 +100,12 @@ class TokenizerLike(Protocol):
     def convert_tokens_to_string(self, tokens: list[str]) -> str:
         raise NotImplementedError

-    def decode(self, ids: list[int] | int, skip_special_tokens: bool = True) -> str:
+    def decode(self, ids: list[int] | int, skip_special_tokens: bool = False) -> str:
         raise NotImplementedError

     def convert_ids_to_tokens(
         self,
         ids: list[int],
-        skip_special_tokens: bool = True,
+        skip_special_tokens: bool = False,
     ) -> list[str]:
         raise NotImplementedError

View File

@@ -1,8 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project

-import contextlib
-import copy
 import importlib.util
 import os
 import warnings
@@ -11,14 +9,17 @@ from pathlib import Path
 from typing import TYPE_CHECKING, Any

 import huggingface_hub
-from transformers import AutoTokenizer, PreTrainedTokenizerBase
 from typing_extensions import assert_never

 from vllm import envs
 from vllm.logger import init_logger
-from vllm.tokenizers import MistralTokenizer, TokenizerLike, TokenizerRegistry
+from vllm.tokenizers import (
+    HfTokenizer,
+    MistralTokenizer,
+    TokenizerLike,
+    TokenizerRegistry,
+)

-from .config import get_sentence_transformer_tokenizer_config
 from .gguf_utils import get_gguf_file_path_from_hf
 from .repo_utils import list_filtered_repo_files
 from .utils import check_gguf_file, is_gguf, is_remote_gguf, split_remote_gguf
@@ -41,6 +42,18 @@ def __getattr__(name: str):
         )
         return TokenizerLike

+    if name == "get_cached_tokenizer":
+        from vllm.tokenizers.hf import get_cached_tokenizer
+
+        warnings.warn(
+            "`vllm.transformers_utils.tokenizer.get_cached_tokenizer` "
+            "has been moved to `vllm.tokenizers.hf.get_cached_tokenizer`. "
+            "The old name will be removed in v0.13.",
+            DeprecationWarning,
+            stacklevel=2,
+        )
+        return get_cached_tokenizer
+
     raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
@@ -58,10 +71,12 @@ def decode_tokens(
     `skip_special_tokens=None` means to use the backend's default
     settings.
     """
-    if skip_special_tokens is not None:
-        return tokenizer.decode(token_ids, skip_special_tokens=skip_special_tokens)
+    kw_args: dict[str, Any] = {}
+    if skip_special_tokens is not None:
+        kw_args["skip_special_tokens"] = skip_special_tokens

-    return tokenizer.decode(token_ids)
+    return tokenizer.decode(token_ids, **kw_args)


 def encode_tokens(
@@ -93,56 +108,6 @@ def encode_tokens(
     return tokenizer.encode(text, **kw_args)


-def get_cached_tokenizer(tokenizer: TokenizerLike) -> TokenizerLike:
-    """
-    By default, transformers will recompute multiple tokenizer properties
-    each time they are called, leading to a significant slowdown.
-    This proxy caches these properties for faster access.
-    """
-    cached_tokenizer = copy.copy(tokenizer)
-
-    tokenizer_all_special_ids = tokenizer.all_special_ids
-    tokenizer_all_special_tokens = tokenizer.all_special_tokens
-    tokenizer_vocab = tokenizer.get_vocab()
-    tokenizer_len = len(tokenizer)
-
-    max_token_id = max(tokenizer_vocab.values())
-    # Some tokenizers (e.g., QwenTokenizer) have special tokens that
-    # are added and included in the implementation of the vocab_size
-    # property, but not in get_vocab(); if there is an implementation
-    # of vocab size, we should take the greater value.
-    if hasattr(tokenizer, "vocab_size"):
-        with contextlib.suppress(NotImplementedError):
-            max_token_id = max(max_token_id, tokenizer.vocab_size)
-
-    class CachedTokenizer(tokenizer.__class__):  # type: ignore
-        @property
-        def all_special_ids(self) -> list[int]:
-            return tokenizer_all_special_ids
-
-        @property
-        def all_special_tokens(self) -> list[str]:
-            return tokenizer_all_special_tokens
-
-        @property
-        def max_token_id(self) -> int:
-            return max_token_id
-
-        def get_vocab(self) -> dict[str, int]:
-            return tokenizer_vocab
-
-        def __len__(self) -> int:
-            return tokenizer_len
-
-        def __reduce__(self):
-            return get_cached_tokenizer, (tokenizer,)
-
-    CachedTokenizer.__name__ = f"Cached{tokenizer.__class__.__name__}"
-
-    cached_tokenizer.__class__ = CachedTokenizer
-    return cached_tokenizer
-
-
 def get_tokenizer(
     tokenizer_name: str | Path,
     *args,
@@ -217,66 +182,39 @@ def get_tokenizer(
     if tokenizer_mode == "mistral":
         logger.debug_once(f"Loading MistralTokenizer from {tokenizer_name}")
         tokenizer = MistralTokenizer.from_pretrained(
-            str(tokenizer_name), revision=revision
+            tokenizer_name,
+            *args,
+            trust_remote_code=trust_remote_code,
+            revision=revision,
+            download_dir=download_dir,
+            **kwargs,
         )
     elif tokenizer_mode == "custom":
         logger.debug_once(f"Loading CustomTokenizer from {tokenizer_name}")
         tokenizer = TokenizerRegistry.get_tokenizer(
             str(tokenizer_name),
             *args,
+            trust_remote_code=trust_remote_code,
             revision=revision,
             download_dir=download_dir,
             **kwargs,
         )
     else:
-        try:
-            logger.debug_once(f"Loading AutoTokenizer from {tokenizer_name}")
-            tokenizer = AutoTokenizer.from_pretrained(
-                tokenizer_name,
-                *args,
-                trust_remote_code=trust_remote_code,
-                revision=revision,
-                **kwargs,
-            )
-        except ValueError as e:
-            # If the error pertains to the tokenizer class not existing or not
-            # currently being imported,
-            # suggest using the --trust-remote-code flag.
-            if not trust_remote_code and (
-                "does not exist or is not currently imported." in str(e)
-                or "requires you to execute the tokenizer file" in str(e)
-            ):
-                err_msg = (
-                    "Failed to load the tokenizer. If the tokenizer "
-                    "is a custom tokenizer not yet available in the "
-                    "HuggingFace transformers library, consider "
-                    "setting `trust_remote_code=True` in LLM or using "
-                    "the `--trust-remote-code` flag in the CLI."
-                )
-                raise RuntimeError(err_msg) from e
-            else:
-                raise e
-
-        # The special_tokens in tokenizer should also be
-        # controlled by do_lower_case in encoder_config
-        encoder_config = get_sentence_transformer_tokenizer_config(
-            tokenizer_name, revision
-        )
-        if isinstance(encoder_config, dict) and encoder_config.get(
-            "do_lower_case", False
-        ):
-            assert isinstance(tokenizer, PreTrainedTokenizerBase)
-            special_tokens_map = {
-                k: v.lower() for k, v in tokenizer.special_tokens_map.items()
-            }
-            tokenizer.add_special_tokens(special_tokens_map)
+        logger.debug_once(f"Loading HfTokenizer from {tokenizer_name}")
+        tokenizer = HfTokenizer.from_pretrained(
+            tokenizer_name,
+            *args,
+            trust_remote_code=trust_remote_code,
+            revision=revision,
+            download_dir=download_dir,
+            **kwargs,
+        )

     if not tokenizer.is_fast:
         logger.warning(
             "Using a slow tokenizer. This might cause a significant "
             "slowdown. Consider using a fast tokenizer instead."
         )
-    tokenizer = get_cached_tokenizer(tokenizer)

     return tokenizer
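After this hunk, `get_tokenizer` simply dispatches on `tokenizer_mode` and leaves error handling and caching to each backend's `from_pretrained`. A hedged usage sketch (the Mistral repo id is illustrative, and the custom path assumes a tokenizer was registered first, as in the registry test above):

from vllm.transformers_utils.tokenizer import get_tokenizer

# Default mode -> HfTokenizer.from_pretrained, which wraps the result with
# get_cached_tokenizer internally.
tok = get_tokenizer("gpt2")

# Mistral mode -> MistralTokenizer.from_pretrained.
mistral_tok = get_tokenizer(
    "mistralai/Mistral-7B-Instruct-v0.3", tokenizer_mode="mistral"
)

# Custom mode -> TokenizerRegistry.get_tokenizer (requires prior registration).
custom_tok = get_tokenizer("test_tokenizer", tokenizer_mode="custom")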

View File

@@ -9,8 +9,8 @@ from tokenizers.decoders import DecodeStream
 from transformers import PreTrainedTokenizerFast

 from vllm.logger import init_logger
+from vllm.tokenizers import TokenizerLike
 from vllm.tokenizers.detokenizer_utils import (
-    TokenizerLike,
     convert_prompt_ids_to_tokens,
     detokenize_incrementally,
 )