Refactor example prompts fixture (#29854)

Signed-off-by: nwaughac@gmail.com
Author: Chukwuma Nwaugha
Date: 2025-12-05 06:44:32 +00:00
Committed by: GitHub
Parent: d698bb382d
Commit: 6e865b6a83


@@ -27,7 +27,7 @@ import threading
 from collections.abc import Generator
 from contextlib import nullcontext
 from enum import Enum
-from typing import Any, Callable, TypedDict, TypeVar, cast
+from typing import Any, Callable, TypedDict, TypeVar, cast, TYPE_CHECKING
 
 import numpy as np
 import pytest
@@ -67,6 +67,11 @@ from vllm.transformers_utils.utils import maybe_model_redirect
 from vllm.utils.collection_utils import is_list_of
 from vllm.utils.torch_utils import set_default_torch_num_threads
 
+if TYPE_CHECKING:
+    from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
+    from transformers.generation.utils import GenerateOutput
+
+
 logger = init_logger(__name__)
 
 _TEST_DIR = os.path.dirname(__file__)
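
Note on the guard: typing.TYPE_CHECKING is False at runtime and True only while a static type checker analyzes the module, so imports placed under it cost nothing at import time and cannot introduce import cycles. A minimal sketch of the pattern (the token_count helper is illustrative, not part of this change):

    from typing import TYPE_CHECKING

    if TYPE_CHECKING:
        # Evaluated only by static type checkers (mypy, pyright);
        # never executed at runtime.
        from transformers import PreTrainedTokenizer

    def token_count(tokenizer: "PreTrainedTokenizer", text: str) -> int:
        # The annotation is quoted, so it is never evaluated at runtime
        # and the guarded import above is sufficient.
        return len(tokenizer.encode(text))

Adding `from __future__ import annotations` would make all annotations lazy and let the quotes be dropped.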
@@ -202,10 +207,7 @@ def dynamo_reset():
 
 @pytest.fixture
 def example_prompts() -> list[str]:
-    prompts = []
-    for filename in _TEST_PROMPTS:
-        prompts += _read_prompts(filename)
-    return prompts
+    return [prompt for filename in _TEST_PROMPTS for prompt in _read_prompts(filename)]
 
 
 @pytest.fixture
@@ -224,10 +226,7 @@ class DecoderPromptType(Enum):
 
 @pytest.fixture
 def example_long_prompts() -> list[str]:
-    prompts = []
-    for filename in _LONG_PROMPTS:
-        prompts += _read_prompts(filename)
-    return prompts
+    return [prompt for filename in _LONG_PROMPTS for prompt in _read_prompts(filename)]
 
 
 @pytest.fixture(scope="session")
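
Both fixture refactors swap an accumulator loop for a flattening comprehension; the comprehension's clauses read left to right exactly like the nested loop they replace. A self-contained check of the equivalence (the file table and reader below are stand-ins, not the real _TEST_PROMPTS / _read_prompts):

    from itertools import chain

    # Hypothetical stand-ins for _TEST_PROMPTS / _read_prompts.
    _FILES = {"a.txt": ["p1", "p2"], "b.txt": ["p3"]}

    def _read_prompts(filename: str) -> list[str]:
        return _FILES[filename]

    # Old fixture body: accumulator loop.
    prompts: list[str] = []
    for filename in _FILES:
        prompts += _read_prompts(filename)

    # New fixture body: nested comprehension, same order and contents.
    flat = [p for filename in _FILES for p in _read_prompts(filename)]
    assert flat == prompts == ["p1", "p2", "p3"]

    # itertools.chain.from_iterable would be an equivalent alternative.
    assert list(chain.from_iterable(map(_read_prompts, _FILES))) == flat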
@@ -353,10 +352,13 @@ class HfRunner:
                 trust_remote_code=trust_remote_code,
             )
         else:
-            model = auto_cls.from_pretrained(
-                model_name,
-                trust_remote_code=trust_remote_code,
-                **model_kwargs,
+            model = cast(
+                nn.Module,
+                auto_cls.from_pretrained(
+                    model_name,
+                    trust_remote_code=trust_remote_code,
+                    **model_kwargs,
+                ),
             )
 
         # in case some unquantized custom models are not in same dtype
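
The cast here is purely for the type checker: auto_cls.from_pretrained(...) is effectively untyped, and wrapping it in cast(nn.Module, ...) pins the result without any runtime conversion or check. A stdlib-only sketch of what typing.cast does (load_model is hypothetical):

    from typing import Any, cast

    def load_model(name: str) -> Any:
        # Hypothetical loader typed as Any, the way from_pretrained
        # appears to a type checker.
        return {"name": name}

    # cast() performs no conversion or validation at runtime; it only
    # tells the type checker to treat the value as the stated type.
    model = cast(dict, load_model("demo"))
    assert model == {"name": "demo"}  # the object itself is untouched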
@@ -374,10 +376,12 @@ class HfRunner:
         self.model = model
 
         if not skip_tokenizer_init:
-            self.tokenizer = AutoTokenizer.from_pretrained(
-                model_name,
-                dtype=dtype,
-                trust_remote_code=trust_remote_code,
+            self.tokenizer: "PreTrainedTokenizer | PreTrainedTokenizerFast" = (
+                AutoTokenizer.from_pretrained(
+                    model_name,
+                    dtype=dtype,
+                    trust_remote_code=trust_remote_code,
+                )
             )
 
         # don't put this import at the top level
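
The quoted union annotation gives self.tokenizer a precise type even though the tokenizer classes are only imported under TYPE_CHECKING; the string annotation is never evaluated at runtime. A sketch of the same shape (load_tokenizer is a stand-in for AutoTokenizer.from_pretrained, which is loosely typed):

    from typing import TYPE_CHECKING, Any

    if TYPE_CHECKING:
        from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast

    def load_tokenizer(name: str) -> Any:
        ...  # stand-in; returns whatever the real loader would

    class Runner:
        def __init__(self, name: str) -> None:
            # Quoted annotation: resolved only by the type checker, so
            # the TYPE_CHECKING-only imports above are enough.
            self.tokenizer: "PreTrainedTokenizer | PreTrainedTokenizerFast" = (
                load_tokenizer(name)
            )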
@@ -495,7 +499,7 @@ class HfRunner:
 
         outputs: list[tuple[list[list[int]], list[str]]] = []
         for inputs in all_inputs:
-            output_ids = self.model.generate(
+            output_ids: torch.Tensor = self.model.generate(
                 **self.wrap_device(inputs),
                 use_cache=True,
                 **kwargs,
@@ -505,8 +509,7 @@ class HfRunner:
                 skip_special_tokens=True,
                 clean_up_tokenization_spaces=False,
             )
-            output_ids = output_ids.cpu().tolist()
-            outputs.append((output_ids, output_str))
+            outputs.append((output_ids.cpu().tolist(), output_str))
         return outputs
 
     def generate_greedy(
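
Two related cleanups in these hunks: output_ids is annotated as torch.Tensor (generate() returns a plain tensor of token ids when return_dict_in_generate is not set), and the .cpu().tolist() conversion is inlined so the name is never rebound from Tensor to list, which strict type checkers flag. A minimal sketch, assuming torch is available:

    import torch

    def decode_ids(output_ids: torch.Tensor) -> list[list[int]]:
        # Rebinding the same name to a new type, as the old code did
        # (output_ids = output_ids.cpu().tolist()), is legal Python but
        # trips strict type checkers; inlining keeps the name a Tensor.
        return output_ids.cpu().tolist()

    print(decode_ids(torch.tensor([[1, 2], [3, 4]])))  # [[1, 2], [3, 4]]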
@@ -574,7 +577,7 @@ class HfRunner:
 
         all_logprobs: list[list[torch.Tensor]] = []
         for inputs in all_inputs:
-            output = self.model.generate(
+            output: "GenerateOutput" = self.model.generate(
                 **self.wrap_device(inputs),
                 use_cache=True,
                 do_sample=False,
@@ -656,7 +659,7 @@ class HfRunner:
 
         all_output_strs: list[str] = []
         for inputs in all_inputs:
-            output = self.model.generate(
+            output: "GenerateOutput" = self.model.generate(
                 **self.wrap_device(inputs),
                 use_cache=True,
                 do_sample=False,
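
The quoted "GenerateOutput" annotation records that these two call sites take the structured branch of generate()'s return type, which is produced when return_dict_in_generate=True is passed (the flags are outside the hunk context shown here, so this is an inference from the annotation). A sketch of that call shape, assuming a transformers model and tokenized inputs:

    from typing import TYPE_CHECKING

    if TYPE_CHECKING:
        from transformers.generation.utils import GenerateOutput

    def greedy_with_scores(model, inputs) -> "GenerateOutput":
        # With return_dict_in_generate=True, generate() returns a
        # GenerateOutput carrying .sequences plus per-step .scores,
        # instead of a plain tensor of token ids.
        output: "GenerateOutput" = model.generate(
            **inputs,
            do_sample=False,
            output_scores=True,
            return_dict_in_generate=True,
        )
        return output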