Refactor example prompts fixture (#29854)
Signed-off-by: nwaughac@gmail.com
@@ -27,7 +27,7 @@ import threading
 from collections.abc import Generator
 from contextlib import nullcontext
 from enum import Enum
-from typing import Any, Callable, TypedDict, TypeVar, cast
+from typing import Any, Callable, TypedDict, TypeVar, cast, TYPE_CHECKING

 import numpy as np
 import pytest
@@ -67,6 +67,11 @@ from vllm.transformers_utils.utils import maybe_model_redirect
 from vllm.utils.collection_utils import is_list_of
 from vllm.utils.torch_utils import set_default_torch_num_threads

+if TYPE_CHECKING:
+    from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
+    from transformers.generation.utils import GenerateOutput
+
+
 logger = init_logger(__name__)

 _TEST_DIR = os.path.dirname(__file__)
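The hunk above uses the standard `typing.TYPE_CHECKING` idiom: the transformers names are imported only while a static type checker runs, so they can appear in annotations without being imported at runtime. A minimal self-contained sketch of the idiom, with a hypothetical `describe` function standing in for the conftest code (runnable even without transformers installed):

    from typing import TYPE_CHECKING

    if TYPE_CHECKING:
        # Evaluated by mypy/pyright only; never executed at runtime.
        from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast

    def describe(tok: "PreTrainedTokenizer | PreTrainedTokenizerFast") -> str:
        # The quoted annotation is not evaluated at runtime, so the
        # names guarded above never need to exist when this runs.
        return type(tok).__name__

    print(describe.__annotations__["tok"])

The quoted annotations used throughout the rest of this diff (for the tokenizer and for `GenerateOutput`) rely on exactly this deferral.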
@@ -202,10 +207,7 @@ def dynamo_reset():

 @pytest.fixture
 def example_prompts() -> list[str]:
-    prompts = []
-    for filename in _TEST_PROMPTS:
-        prompts += _read_prompts(filename)
-    return prompts
+    return [prompt for filename in _TEST_PROMPTS for prompt in _read_prompts(filename)]


 @pytest.fixture
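The fixture body collapses an accumulate-in-a-loop pattern into a single flattening comprehension; behavior is unchanged. A small standalone sketch of the equivalence (the `_read_prompts` stub here is a stand-in, not the real helper):

    def _read_prompts(filename: str) -> list[str]:
        # Stand-in for the real helper, which reads prompts from a file.
        return [f"{filename}:prompt-{i}" for i in range(2)]

    _TEST_PROMPTS = ["a.txt", "b.txt"]

    # Loop-and-append version (the old code):
    prompts: list[str] = []
    for filename in _TEST_PROMPTS:
        prompts += _read_prompts(filename)

    # Flattening comprehension (the new code) yields the same list:
    flat = [p for filename in _TEST_PROMPTS for p in _read_prompts(filename)]
    assert flat == prompts

An `itertools.chain.from_iterable` call over the per-file lists would be an equivalent alternative.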
@@ -224,10 +226,7 @@ class DecoderPromptType(Enum):

 @pytest.fixture
 def example_long_prompts() -> list[str]:
-    prompts = []
-    for filename in _LONG_PROMPTS:
-        prompts += _read_prompts(filename)
-    return prompts
+    return [prompt for filename in _LONG_PROMPTS for prompt in _read_prompts(filename)]


 @pytest.fixture(scope="session")
@@ -353,10 +352,13 @@ class HfRunner:
                 trust_remote_code=trust_remote_code,
             )
         else:
-            model = auto_cls.from_pretrained(
-                model_name,
-                trust_remote_code=trust_remote_code,
-                **model_kwargs,
+            model = cast(
+                nn.Module,
+                auto_cls.from_pretrained(
+                    model_name,
+                    trust_remote_code=trust_remote_code,
+                    **model_kwargs,
+                ),
             )

         # in case some unquantized custom models are not in same dtype
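The `typing.cast` here only narrows the static type: `from_pretrained` is loosely typed, and the cast tells the checker to treat the result as an `nn.Module` with no runtime effect at all. A minimal sketch of that behavior, assuming only that torch is installed (the factory function is hypothetical):

    from typing import Any, cast

    import torch.nn as nn

    def loosely_typed_factory() -> Any:
        # Stands in for auto_cls.from_pretrained, whose return type is
        # too broad for the checker to reason about.
        return nn.Linear(4, 4)

    # cast() is an identity function at runtime; it only informs the checker.
    model = cast(nn.Module, loosely_typed_factory())
    print(model.training)  # attribute access now type-checks as nn.Module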
@@ -374,10 +376,12 @@ class HfRunner:
         self.model = model

         if not skip_tokenizer_init:
-            self.tokenizer = AutoTokenizer.from_pretrained(
-                model_name,
-                dtype=dtype,
-                trust_remote_code=trust_remote_code,
+            self.tokenizer: "PreTrainedTokenizer | PreTrainedTokenizerFast" = (
+                AutoTokenizer.from_pretrained(
+                    model_name,
+                    dtype=dtype,
+                    trust_remote_code=trust_remote_code,
+                )
             )

         # don't put this import at the top level
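The outer parentheses around the right-hand side are what let this annotated assignment wrap across lines; the quoted union type refers to the names imported under TYPE_CHECKING earlier in the diff. A trimmed illustration with a hypothetical loader:

    from dataclasses import dataclass

    @dataclass
    class Tokenizer:
        name: str

    def load_tokenizer(name: str, trust_remote_code: bool = False) -> Tokenizer:
        # Hypothetical stand-in for AutoTokenizer.from_pretrained.
        return Tokenizer(name)

    # Parentheses around the value allow the annotated assignment to
    # span several lines without backslash continuations:
    tokenizer: "Tokenizer" = (
        load_tokenizer(
            "facebook/opt-125m",
            trust_remote_code=True,
        )
    )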
@@ -495,7 +499,7 @@ class HfRunner:

         outputs: list[tuple[list[list[int]], list[str]]] = []
         for inputs in all_inputs:
-            output_ids = self.model.generate(
+            output_ids: torch.Tensor = self.model.generate(
                 **self.wrap_device(inputs),
                 use_cache=True,
                 **kwargs,
@@ -505,8 +509,7 @@ class HfRunner:
                 skip_special_tokens=True,
                 clean_up_tokenization_spaces=False,
             )
-            output_ids = output_ids.cpu().tolist()
-            outputs.append((output_ids, output_str))
+            outputs.append((output_ids.cpu().tolist(), output_str))
         return outputs

     def generate_greedy(
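Merging the two statements avoids rebinding `output_ids`: it is annotated as `torch.Tensor` a few lines up, so reassigning the `.cpu().tolist()` result (a plain list) to the same name would contradict the annotation and trip a type checker. A compact before/after sketch, assuming only torch and with the decoded string hard-coded:

    import torch

    outputs: list[tuple[list[list[int]], str]] = []
    output_ids: torch.Tensor = torch.tensor([[1, 2, 3]])

    # Before: rebinding output_ids to a list contradicts its Tensor annotation.
    # output_ids = output_ids.cpu().tolist()   # checker error: list vs Tensor

    # After: convert inline so the Tensor-typed name is never rebound.
    outputs.append((output_ids.cpu().tolist(), "decoded text"))
    print(outputs)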
@@ -574,7 +577,7 @@ class HfRunner:

         all_logprobs: list[list[torch.Tensor]] = []
         for inputs in all_inputs:
-            output = self.model.generate(
+            output: "GenerateOutput" = self.model.generate(
                 **self.wrap_device(inputs),
                 use_cache=True,
                 do_sample=False,
@@ -656,7 +659,7 @@ class HfRunner:
         all_output_strs: list[str] = []

         for inputs in all_inputs:
-            output = self.model.generate(
+            output: "GenerateOutput" = self.model.generate(
                 **self.wrap_device(inputs),
                 use_cache=True,
                 do_sample=False,