Mirror of https://github.com/vllm-project/vllm.git (synced 2025-12-06 06:53:12 +08:00)
Remove default values from InitVars so that they're not stored (#29859)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
@@ -108,7 +108,10 @@ def benchmark_batched_propose(args):
         device_config=DeviceConfig(device=current_platform.device_type),
         parallel_config=ParallelConfig(),
         load_config=LoadConfig(),
-        scheduler_config=SchedulerConfig(),
+        scheduler_config=SchedulerConfig(
+            max_model_len=model_config.max_model_len,
+            is_encoder_decoder=model_config.is_encoder_decoder,
+        ),
     )

     # monkey patch vllm.v1.worker.gpu_model_runner.get_pp_group

@@ -318,13 +318,18 @@ def test_attention_quant_pattern(
     torch.set_default_dtype(dtype)
     torch.manual_seed(42)

+    model_config = ModelConfig(
+        model=model_name,
+        max_model_len=2048,
+        dtype=dtype,
+    )
     vllm_config = VllmConfig(
-        model_config=ModelConfig(
-            model=model_name,
-            max_model_len=2048,
-            dtype=dtype,
+        model_config=model_config,
+        scheduler_config=SchedulerConfig(
+            max_num_seqs=1024,
+            max_model_len=model_config.max_model_len,
+            is_encoder_decoder=model_config.is_encoder_decoder,
         ),
-        scheduler_config=SchedulerConfig(max_num_seqs=1024),
         compilation_config=CompilationConfig(
             mode=CompilationMode.VLLM_COMPILE,
             custom_ops=custom_ops_list,

@@ -33,14 +33,16 @@ def test_worker_apply_lora(qwen3_lora_files):
         lora_requests, lora_mapping
     )

+    model_config = ModelConfig(
+        MODEL_PATH,
+        seed=0,
+        dtype="float16",
+        max_model_len=127,
+        enforce_eager=True,
+    )
+
     vllm_config = VllmConfig(
-        model_config=ModelConfig(
-            MODEL_PATH,
-            seed=0,
-            dtype="float16",
-            max_model_len=127,
-            enforce_eager=True,
-        ),
+        model_config=model_config,
         load_config=LoadConfig(
             download_dir=None,
             load_format="dummy",

@@ -50,7 +52,14 @@ def test_worker_apply_lora(qwen3_lora_files):
            tensor_parallel_size=1,
            data_parallel_size=1,
        ),
-        scheduler_config=SchedulerConfig("generate", 32, 32, 32),
+        scheduler_config=SchedulerConfig(
+            max_model_len=model_config.max_model_len,
+            is_encoder_decoder=model_config.is_encoder_decoder,
+            runner_type="generate",
+            max_num_batched_tokens=32,
+            max_num_seqs=32,
+            max_num_partial_prefills=32,
+        ),
        device_config=DeviceConfig("cuda"),
        cache_config=CacheConfig(
            block_size=16,

@@ -6,12 +6,14 @@ from dataclasses import MISSING, Field, asdict, dataclass, field
 from unittest.mock import patch

 import pytest
+from pydantic import ValidationError

 from vllm.compilation.backends import VllmBackend
 from vllm.config import (
     CompilationConfig,
     ModelConfig,
     PoolerConfig,
+    SchedulerConfig,
     VllmConfig,
     update_config,
 )

@@ -1095,3 +1097,14 @@ def test_vllm_config_explicit_overrides():
     # Other fields should still use defaults
     assert config.compilation_config.mode == CompilationMode.VLLM_COMPILE
     assert config.compilation_config.cudagraph_mode == CUDAGraphMode.FULL_AND_PIECEWISE
+
+
+def test_scheduler_config_init():
+    with pytest.raises(ValidationError):
+        # Positional InitVars missing
+        # (InitVars cannot have defaults otherwise they will become attributes)
+        SchedulerConfig()
+
+    with pytest.raises(AttributeError):
+        # InitVar does not become an attribute
+        print(SchedulerConfig.default_factory().max_model_len)

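A minimal standalone sketch (plain standard-library dataclasses, not vLLM code; class and field names are illustrative) of the mechanics the new test relies on: an InitVar declared with a plain default leaves that default behind as a class attribute, so attribute lookups silently fall back to the stale default, whereas an InitVar without a default is consumed by __post_init__ only and raises AttributeError when accessed on an instance.

from dataclasses import InitVar, dataclass


@dataclass
class WithDefault:
    # The default 8192 stays behind as a class attribute.
    max_model_len: InitVar[int] = 8192

    def __post_init__(self, max_model_len: int) -> None:
        self.half = max_model_len // 2


@dataclass
class WithoutDefault:
    # No default: init-only for real.
    max_model_len: InitVar[int]

    def __post_init__(self, max_model_len: int) -> None:
        self.half = max_model_len // 2


print(WithDefault(16).half)           # 8 -- the InitVar did reach __post_init__
print(WithDefault(16).max_model_len)  # 8192 -- stale class-level default, not 16
print(WithoutDefault(16).half)        # 8
try:
    WithoutDefault(16).max_model_len
except AttributeError:
    print("max_model_len is init-only and not stored")  # this branch runs

The vLLM configs go through pydantic dataclasses, but the commit message and the test above indicate the defaulted-InitVar pitfall is analogous there.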
@@ -185,6 +185,8 @@ def create_vllm_config(
         max_num_seqs=max_num_seqs,
         max_num_batched_tokens=max_num_batched_tokens,
         enable_chunked_prefill=enable_chunked_prefill,
+        max_model_len=model_config.max_model_len,
+        is_encoder_decoder=model_config.is_encoder_decoder,
     )

     device_config = DeviceConfig()

@@ -1128,7 +1128,11 @@ def test_estimate_max_model_len(model_id, max_model_len, want_estimated_max_len)
         dtype="float16",
         max_model_len=max_model_len,
     )
-    scheduler_config = SchedulerConfig(max_num_batched_tokens=32768)
+    scheduler_config = SchedulerConfig(
+        max_num_batched_tokens=32768,
+        max_model_len=model_config.max_model_len,
+        is_encoder_decoder=model_config.is_encoder_decoder,
+    )

     vllm_config = VllmConfig(
         model_config=model_config,

@@ -1163,7 +1167,10 @@ def test_get_max_concurrency_for_kv_cache_config():
         max_model_len=max_model_len,
     )
     scheduler_config = SchedulerConfig(
-        max_num_batched_tokens=1024, enable_chunked_prefill=True
+        max_num_batched_tokens=1024,
+        enable_chunked_prefill=True,
+        max_model_len=model_config.max_model_len,
+        is_encoder_decoder=model_config.is_encoder_decoder,
     )

     vllm_config = VllmConfig(

@@ -1508,6 +1508,12 @@ def create_scheduler_with_priority(
     Returns:
       {class}`Scheduler` instance with priority scheduling
     """
+    model_config = ModelConfig(
+        model=model,
+        trust_remote_code=True,
+        dtype="float16",
+        seed=42,
+    )
     if max_model_len is None:
         max_model_len = max_num_batched_tokens
     scheduler_config = SchedulerConfig(

@@ -1517,14 +1523,9 @@ def create_scheduler_with_priority(
         long_prefill_token_threshold=long_prefill_token_threshold,
         disable_chunked_mm_input=disable_chunked_mm_input,
         enable_chunked_prefill=True,
+        is_encoder_decoder=model_config.is_encoder_decoder,
         policy="priority",  # Enable priority scheduling
     )
-    model_config = ModelConfig(
-        model=model,
-        trust_remote_code=True,
-        dtype="float16",
-        seed=42,
-    )
     # Cache config, optionally force APC
     cache_config = CacheConfig(
         block_size=block_size,

@@ -69,6 +69,13 @@ def create_scheduler(
     Returns:
       {class}`Scheduler` instance
     """
+    model_config = ModelConfig(
+        model=model,
+        trust_remote_code=True,
+        dtype="float16",
+        seed=42,
+        skip_tokenizer_init=skip_tokenizer_init,
+    )
     if max_model_len is None:
         max_model_len = max_num_batched_tokens
     scheduler_config = SchedulerConfig(

@@ -79,13 +86,7 @@ def create_scheduler(
         disable_chunked_mm_input=disable_chunked_mm_input,
         enable_chunked_prefill=enable_chunked_prefill,
         async_scheduling=async_scheduling,
-    )
-    model_config = ModelConfig(
-        model=model,
-        trust_remote_code=True,
-        dtype="float16",
-        seed=42,
-        skip_tokenizer_init=skip_tokenizer_init,
+        is_encoder_decoder=model_config.is_encoder_decoder,
     )
     # Cache config, optionally force APC
     cache_config = CacheConfig(

@@ -40,7 +40,9 @@ def _create_vllm_config(
 ) -> MagicMock:
     mock_config = MagicMock(spec=VllmConfig)
     mock_config.compilation_config = compilation_config
-    mock_config.scheduler_config = SchedulerConfig(max_num_seqs=max_num_seqs)
+    mock_config.scheduler_config = SchedulerConfig.default_factory(
+        max_num_seqs=max_num_seqs,
+    )
     mock_config.parallel_config = ParallelConfig()
     mock_config.speculative_config = None  # No speculative decoding
     if not lora_config:

@@ -484,12 +484,6 @@ def test_encoder_instance_zero_kv_cache(
     vision encoder, so they don't need KV cache for text generation.
     """
     # Form vllm config
-    scheduler_config = SchedulerConfig(
-        max_num_seqs=10,
-        max_num_batched_tokens=512,
-        max_model_len=512,
-        disable_hybrid_kv_cache_manager=True,
-    )
     model_config = ModelConfig(
         model="llava-hf/llava-1.5-7b-hf",  # Multimodal model
         enforce_eager=True,

@@ -497,6 +491,13 @@ def test_encoder_instance_zero_kv_cache(
         dtype="float16",
         seed=42,
     )
+    scheduler_config = SchedulerConfig(
+        max_num_seqs=10,
+        max_num_batched_tokens=512,
+        max_model_len=512,
+        disable_hybrid_kv_cache_manager=True,
+        is_encoder_decoder=model_config.is_encoder_decoder,
+    )
     cache_config = CacheConfig(
         block_size=16,
         gpu_memory_utilization=gpu_memory_utilization,

@@ -92,18 +92,19 @@ def create_vllm_config(
     enable_permute_local_kv: bool = False,
 ) -> VllmConfig:
     """Initialize VllmConfig For Testing."""
-    scheduler_config = SchedulerConfig(
-        max_num_seqs=max_num_seqs,
-        max_num_batched_tokens=max_num_batched_tokens,
-        max_model_len=max_model_len,
-        enable_chunked_prefill=enable_chunked_prefill,
-    )
     model_config = ModelConfig(
         model=model,
         trust_remote_code=True,
         dtype="float16",
         seed=42,
     )
+    scheduler_config = SchedulerConfig(
+        max_num_seqs=max_num_seqs,
+        max_num_batched_tokens=max_num_batched_tokens,
+        max_model_len=max_model_len,
+        enable_chunked_prefill=enable_chunked_prefill,
+        is_encoder_decoder=model_config.is_encoder_decoder,
+    )
     # Cache config, optionally force APC
     cache_config = CacheConfig(
         block_size=block_size,

@@ -66,7 +66,10 @@ def _create_proposer(
         device_config=DeviceConfig(device=current_platform.device_type),
         parallel_config=ParallelConfig(),
         load_config=LoadConfig(),
-        scheduler_config=SchedulerConfig(),
+        scheduler_config=SchedulerConfig(
+            max_model_len=model_config.max_model_len,
+            is_encoder_decoder=model_config.is_encoder_decoder,
+        ),
     )

     return EagleProposer(vllm_config=vllm_config, device=current_platform.device_type)

@@ -51,7 +51,10 @@ def _create_mtp_proposer(num_speculative_tokens: int) -> EagleProposer:
         device_config=DeviceConfig(device=current_platform.device_type),
         parallel_config=ParallelConfig(),
         load_config=LoadConfig(),
-        scheduler_config=SchedulerConfig(),
+        scheduler_config=SchedulerConfig(
+            max_model_len=model_config.max_model_len,
+            is_encoder_decoder=model_config.is_encoder_decoder,
+        ),
     )

     return EagleProposer(vllm_config=vllm_config, device=current_platform.device_type)

@@ -26,16 +26,17 @@ from vllm.v1.worker.tpu_model_runner import (


 def get_vllm_config():
-    scheduler_config = SchedulerConfig(
-        max_num_seqs=10,
-        max_num_batched_tokens=512,
-        max_model_len=512,
-    )
     model_config = ModelConfig(
         model="facebook/opt-125m",
         dtype="bfloat16",  # TPUs typically use bfloat16
         seed=42,
     )
+    scheduler_config = SchedulerConfig(
+        max_num_seqs=10,
+        max_num_batched_tokens=512,
+        max_model_len=512,
+        is_encoder_decoder=model_config.is_encoder_decoder,
+    )
     cache_config = CacheConfig(
         block_size=16,
         gpu_memory_utilization=0.9,

@@ -79,16 +79,17 @@ def initialize_kv_cache(runner: GPUModelRunner):


 def get_vllm_config():
-    scheduler_config = SchedulerConfig(
-        max_num_seqs=10,
-        max_num_batched_tokens=512,
-        max_model_len=512,
-    )
     model_config = ModelConfig(
         model="facebook/opt-125m",
         dtype="float16",
         seed=42,
     )
+    scheduler_config = SchedulerConfig(
+        max_num_seqs=10,
+        max_num_batched_tokens=512,
+        max_model_len=512,
+        is_encoder_decoder=model_config.is_encoder_decoder,
+    )
     cache_config = CacheConfig(
         block_size=BLOCK_SIZE,
         gpu_memory_utilization=0.9,

@@ -784,14 +785,15 @@ def test_hybrid_attention_mamba_tensor_shapes(monkeypatch):
     initialize_model_parallel(tensor_model_parallel_size=1)
     torch.set_default_dtype(torch.float16)

+    model_config = ModelConfig(
+        model="ibm-granite/granite-4.0-tiny-preview",
+        dtype="float16",
+    )
     scheduler_config = SchedulerConfig(
         max_num_seqs=10,
         max_num_batched_tokens=512,
         max_model_len=512,
-    )
-    model_config = ModelConfig(
-        model="ibm-granite/granite-4.0-tiny-preview",
-        dtype="float16",
+        is_encoder_decoder=model_config.is_encoder_decoder,
     )
     cache_config = CacheConfig(
         block_size=BLOCK_SIZE,

@@ -28,6 +28,19 @@ SchedulerPolicy = Literal["fcfs", "priority"]
 class SchedulerConfig:
     """Scheduler configuration."""

+    max_model_len: InitVar[int]
+    """Maximum length of a sequence (including prompt and generated text).
+
+    Note: This is stored in the ModelConfig, and is used only here to
+    provide fallbacks and validate other attributes."""
+
+    is_encoder_decoder: InitVar[bool]
+    """True if the model is an encoder-decoder model.
+
+    Note: This is stored in the ModelConfig, and is used only here to
+    disable chunked prefill and prefix caching for encoder-decoder models.
+    """
+
     DEFAULT_MAX_NUM_BATCHED_TOKENS: ClassVar[int] = 2048
     DEFAULT_MAX_NUM_SEQS: ClassVar[int] = 128


@@ -73,19 +86,6 @@ class SchedulerConfig:
     is_multimodal_model: bool = False
     """True if the model is multimodal."""

-    max_model_len: InitVar[int] = 8192
-    """Maximum length of a sequence (including prompt and generated text).
-
-    Note: This is stored in the ModelConfig, and is used only here to
-    provide fallbacks and validate other attributes."""
-
-    is_encoder_decoder: InitVar[bool] = False
-    """True if the model is an encoder-decoder model.
-
-    Note: This is stored in the ModelConfig, and is used only here to
-    disable chunked prefill and prefix caching for encoder-decoder models.
-    """
-
     # TODO (ywang96): Make this configurable.
     max_num_encoder_input_tokens: int = Field(init=False)
     """Multimodal encoder compute budget, only used in V1.

@@ -141,6 +141,17 @@ class SchedulerConfig:
     while a larger value (e.g., 10) reduces host overhead and may increase throughput
     by batching multiple tokens before sending."""

+    @staticmethod
+    def default_factory(**kwargs):
+        """
+        Factory method to create `SchedulerConfig` with default values for `InitVar`s.
+        """
+        if "max_model_len" not in kwargs:
+            kwargs["max_model_len"] = 8192
+        if "is_encoder_decoder" not in kwargs:
+            kwargs["is_encoder_decoder"] = False
+        return SchedulerConfig(**kwargs)
+
     def get_scheduler_cls(self) -> type["SchedulerInterface"]:
         if self.scheduler_cls is None:
             if self.async_scheduling:

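A hedged usage sketch of the two call patterns the rest of this diff applies (names taken from the hunks above): regular call sites thread the init-only values through from an existing ModelConfig, while test helpers that don't care can fall back to the new default_factory wrapper.

# Pass the InitVars explicitly, sourced from ModelConfig:
scheduler_config = SchedulerConfig(
    max_num_batched_tokens=32768,
    max_model_len=model_config.max_model_len,
    is_encoder_decoder=model_config.is_encoder_decoder,
)

# Or take the factory defaults (8192 / False) when they are good enough:
scheduler_config = SchedulerConfig.default_factory(max_num_seqs=128)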
@@ -284,8 +295,3 @@ class SchedulerConfig:
         )

         return self
-
-    def __getattribute__(self, name: str) -> Any:
-        if name == "max_model_len" or name == "is_encoder_decoder":
-            raise AttributeError(f"{name} is an init-only parameter. ")
-        return object.__getattribute__(self, name)

@@ -170,7 +170,9 @@ class VllmConfig:
     """Cache configuration."""
     parallel_config: ParallelConfig = Field(default_factory=ParallelConfig)
     """Parallel configuration."""
-    scheduler_config: SchedulerConfig = Field(default_factory=SchedulerConfig)
+    scheduler_config: SchedulerConfig = Field(
+        default_factory=SchedulerConfig.default_factory,
+    )
     """Scheduler configuration."""
    device_config: DeviceConfig = Field(default_factory=DeviceConfig)
     """Device configuration."""

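Why the wrapper is needed here at all: Field(default_factory=...) expects a zero-argument callable, and SchedulerConfig itself no longer qualifies once its InitVars are required. A small standalone sketch under that assumption, with illustrative names (Inner/Outer), using pydantic v2 dataclasses like the vLLM configs appear to:

from pydantic import Field
from pydantic.dataclasses import dataclass


@dataclass
class Inner:
    size: int  # required, so a bare Inner() would fail validation

    @staticmethod
    def default_factory(**kwargs) -> "Inner":
        # Fill in the defaults only when the caller did not supply them.
        kwargs.setdefault("size", 8192)
        return Inner(**kwargs)


@dataclass
class Outer:
    # The factory, not the class, is the zero-argument callable.
    inner: Inner = Field(default_factory=Inner.default_factory)


print(Outer().inner.size)  # 8192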