Files
vllm/tests/tokenizers_/test_hf.py
2025-11-30 14:59:47 +08:00

44 lines
1.5 KiB
Python

# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pickle
from copy import deepcopy
import pytest
from transformers import AutoTokenizer
from vllm.tokenizers import TokenizerLike
from vllm.tokenizers.hf import get_cached_tokenizer
@pytest.mark.parametrize("model_id", ["gpt2", "zai-org/chatglm3-6b"])
def test_cached_tokenizer(model_id: str):
reference_tokenizer = AutoTokenizer.from_pretrained(
model_id, trust_remote_code=True
)
reference_tokenizer.add_special_tokens({"cls_token": "<CLS>"})
reference_tokenizer.add_special_tokens({"additional_special_tokens": ["<SEP>"]})
cached_tokenizer = get_cached_tokenizer(deepcopy(reference_tokenizer))
_check_consistency(cached_tokenizer, reference_tokenizer)
pickled_tokenizer = pickle.dumps(cached_tokenizer)
unpickled_tokenizer = pickle.loads(pickled_tokenizer)
_check_consistency(unpickled_tokenizer, reference_tokenizer)
def _check_consistency(target: TokenizerLike, expected: TokenizerLike):
assert isinstance(target, type(expected))
# Cached attributes
assert target.all_special_ids == expected.all_special_ids
assert target.all_special_tokens == expected.all_special_tokens
assert target.get_vocab() == expected.get_vocab()
assert len(target) == len(expected)
# Other attributes
assert getattr(target, "padding_side", None) == getattr(
expected, "padding_side", None
)
assert target.encode("prompt") == expected.encode("prompt")