[Frontend] Add vllm bench sweep to CLI (#27639)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Co-authored-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Authored by Cyrus Leung on 2025-10-29 20:59:48 +08:00; committed by GitHub
parent 9a0d2f0d92
commit ecca3fee76
19 changed files with 340 additions and 168 deletions

View File

@@ -5,4 +5,4 @@ nav:
- complete.md
- run-batch.md
- vllm bench:
- bench/*.md
- bench/**/*.md

View File

@@ -0,0 +1,9 @@
# vllm bench sweep plot
## JSON CLI Arguments
--8<-- "docs/cli/json_tip.inc.md"
## Options
--8<-- "docs/argparse/bench_sweep_plot.md"

View File

@@ -0,0 +1,9 @@
# vllm bench sweep serve
## JSON CLI Arguments
--8<-- "docs/cli/json_tip.inc.md"
## Options
--8<-- "docs/argparse/bench_sweep_serve.md"

View File

@@ -0,0 +1,9 @@
# vllm bench sweep serve_sla
## JSON CLI Arguments
--8<-- "docs/cli/json_tip.inc.md"
## Options
--8<-- "docs/argparse/bench_sweep_serve_sla.md"

View File

@@ -1061,7 +1061,7 @@ Follow these steps to run the script:
Example command:
```bash
python -m vllm.benchmarks.sweep.serve \
vllm bench sweep serve \
--serve-cmd 'vllm serve meta-llama/Llama-2-7b-chat-hf' \
--bench-cmd 'vllm bench serve --model meta-llama/Llama-2-7b-chat-hf --backend vllm --endpoint /v1/completions --dataset-name sharegpt --dataset-path benchmarks/ShareGPT_V3_unfiltered_cleaned_split.json' \
--serve-params benchmarks/serve_hparams.json \
@@ -1109,7 +1109,7 @@ For example, to ensure E2E latency within different target values for 99% of req
Example command:
```bash
python -m vllm.benchmarks.sweep.serve_sla \
vllm bench sweep serve_sla \
--serve-cmd 'vllm serve meta-llama/Llama-2-7b-chat-hf' \
--bench-cmd 'vllm bench serve --model meta-llama/Llama-2-7b-chat-hf --backend vllm --endpoint /v1/completions --dataset-name sharegpt --dataset-path benchmarks/ShareGPT_V3_unfiltered_cleaned_split.json' \
--serve-params benchmarks/serve_hparams.json \
@@ -1138,7 +1138,7 @@ The algorithm for adjusting the SLA variable is as follows:
Example command:
```bash
python -m vllm.benchmarks.sweep.plot benchmarks/results/<timestamp> \
vllm bench sweep plot benchmarks/results/<timestamp> \
--var-x max_concurrency \
--row-by random_input_len \
--col-by random_output_len \
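
The `--serve-params` file referenced in the commands above is a JSON description of the server settings to sweep over. As a hedged illustration only — the key names below are assumptions for this example, and the authoritative schema lives in the `ParameterSweep` helpers imported by `serve.py`, not shown in this diff — such a file could be produced like this:

```python
# Hypothetical sketch only: one plausible shape for serve_hparams.json -- a
# list of server-argument combinations to sweep. The key names below are
# assumptions for illustration, not taken from this diff.
import json

serve_hparams = [
    {"max_num_seqs": 32, "max_num_batched_tokens": 2048},
    {"max_num_seqs": 64, "max_num_batched_tokens": 8192},
]

with open("benchmarks/serve_hparams.json", "w") as f:
    json.dump(serve_hparams, f, indent=2)
```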

View File

@@ -56,15 +56,20 @@ def auto_mock(module, attr, max_mocks=50):
)
latency = auto_mock("vllm.benchmarks", "latency")
serve = auto_mock("vllm.benchmarks", "serve")
throughput = auto_mock("vllm.benchmarks", "throughput")
bench_latency = auto_mock("vllm.benchmarks", "latency")
bench_serve = auto_mock("vllm.benchmarks", "serve")
bench_sweep_plot = auto_mock("vllm.benchmarks.sweep.plot", "SweepPlotArgs")
bench_sweep_serve = auto_mock("vllm.benchmarks.sweep.serve", "SweepServeArgs")
bench_sweep_serve_sla = auto_mock(
"vllm.benchmarks.sweep.serve_sla", "SweepServeSLAArgs"
)
bench_throughput = auto_mock("vllm.benchmarks", "throughput")
AsyncEngineArgs = auto_mock("vllm.engine.arg_utils", "AsyncEngineArgs")
EngineArgs = auto_mock("vllm.engine.arg_utils", "EngineArgs")
ChatCommand = auto_mock("vllm.entrypoints.cli.openai", "ChatCommand")
CompleteCommand = auto_mock("vllm.entrypoints.cli.openai", "CompleteCommand")
cli_args = auto_mock("vllm.entrypoints.openai", "cli_args")
run_batch = auto_mock("vllm.entrypoints.openai", "run_batch")
openai_cli_args = auto_mock("vllm.entrypoints.openai", "cli_args")
openai_run_batch = auto_mock("vllm.entrypoints.openai", "run_batch")
FlexibleArgumentParser = auto_mock(
"vllm.utils.argparse_utils", "FlexibleArgumentParser"
)
@@ -114,6 +119,9 @@ class MarkdownFormatter(HelpFormatter):
self._markdown_output.append(f"{action.help}\n\n")
if (default := action.default) != SUPPRESS:
# Make empty string defaults visible
if default == "":
default = '""'
self._markdown_output.append(f"Default: `{default}`\n\n")
def format_help(self):
@@ -150,17 +158,23 @@ def on_startup(command: Literal["build", "gh-deploy", "serve"], dirty: bool):
# Create parsers to document
parsers = {
# Engine args
"engine_args": create_parser(EngineArgs.add_cli_args),
"async_engine_args": create_parser(
AsyncEngineArgs.add_cli_args, async_args_only=True
),
"serve": create_parser(cli_args.make_arg_parser),
# CLI
"serve": create_parser(openai_cli_args.make_arg_parser),
"chat": create_parser(ChatCommand.add_cli_args),
"complete": create_parser(CompleteCommand.add_cli_args),
"bench_latency": create_parser(latency.add_cli_args),
"bench_throughput": create_parser(throughput.add_cli_args),
"bench_serve": create_parser(serve.add_cli_args),
"run-batch": create_parser(run_batch.make_arg_parser),
"run-batch": create_parser(openai_run_batch.make_arg_parser),
# Benchmark CLI
"bench_latency": create_parser(bench_latency.add_cli_args),
"bench_serve": create_parser(bench_serve.add_cli_args),
"bench_sweep_plot": create_parser(bench_sweep_plot.add_cli_args),
"bench_sweep_serve": create_parser(bench_sweep_serve.add_cli_args),
"bench_sweep_serve_sla": create_parser(bench_sweep_serve_sla.add_cli_args),
"bench_throughput": create_parser(bench_throughput.add_cli_args),
}
# Generate documentation for each parser
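
Presumably each key in `parsers` ends up as a generated `docs/argparse/<key>.md` fragment, which is what the `--8<--` includes in the new `bench/sweep` doc stubs pull in; a quick sanity check of that assumed mapping:

```python
# Assumed mapping (implied by the new doc stubs above, not shown in this hunk):
# every parser key is rendered to docs/argparse/<key>.md by the hook.
for key in ("bench_sweep_plot", "bench_sweep_serve", "bench_sweep_serve_sla"):
    print(f"{key} -> docs/argparse/{key}.md")
```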

View File

@@ -709,7 +709,7 @@ setup(
ext_modules=ext_modules,
install_requires=get_requirements(),
extras_require={
"bench": ["pandas", "datasets"],
"bench": ["pandas", "matplotlib", "seaborn", "datasets"],
"tensorizer": ["tensorizer==2.10.1"],
"fastsafetensors": ["fastsafetensors >= 0.1.10"],
"runai": ["runai-model-streamer[s3,gcs] >= 0.14.0"],

View File

@@ -141,7 +141,7 @@ def attempt_to_make_names_unique(entries_and_traces):
"""
def group_trace_by_operations(trace_df: pd.DataFrame) -> pd.DataFrame:
def group_trace_by_operations(trace_df: "pd.DataFrame") -> "pd.DataFrame":
def is_rms_norm(op_name: str):
if "rms_norm_kernel" in op_name:
return True
@@ -370,12 +370,12 @@ def group_trace_by_operations(trace_df: pd.DataFrame) -> pd.DataFrame:
def plot_trace_df(
traces_df: pd.DataFrame,
traces_df: "pd.DataFrame",
plot_metric: str,
plot_title: str,
output: Path | None = None,
):
def get_phase_description(traces_df: pd.DataFrame, phase: str) -> str:
def get_phase_description(traces_df: "pd.DataFrame", phase: str) -> str:
phase_df = traces_df.query(f'phase == "{phase}"')
descs = phase_df["phase_desc"].to_list()
assert all([desc == descs[0] for desc in descs])
@@ -438,7 +438,7 @@ def main(
top_k: int,
json_nodes_to_fold: list[str],
):
def prepare_data(profile_json: dict, step_keys: list[str]) -> pd.DataFrame:
def prepare_data(profile_json: dict, step_keys: list[str]) -> "pd.DataFrame":
def get_entries_and_traces(key: str):
entries_and_traces: list[tuple[Any, Any]] = []
for root in profile_json[key]["summary_stats"]:
@@ -449,8 +449,8 @@ def main(
return entries_and_traces
def keep_only_top_entries(
df: pd.DataFrame, metric: str, top_k: int = 9
) -> pd.DataFrame:
df: "pd.DataFrame", metric: str, top_k: int = 9
) -> "pd.DataFrame":
df.loc[df.nsmallest(len(df) - top_k + 1, metric).index, ["name"]] = "others"
return df

View File

@@ -0,0 +1,38 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import argparse
from vllm.entrypoints.utils import VLLM_SUBCMD_PARSER_EPILOG
from .plot import SweepPlotArgs
from .plot import main as plot_main
from .serve import SweepServeArgs
from .serve import main as serve_main
from .serve_sla import SweepServeSLAArgs
from .serve_sla import main as serve_sla_main
SUBCOMMANDS = (
(SweepServeArgs, serve_main),
(SweepServeSLAArgs, serve_sla_main),
(SweepPlotArgs, plot_main),
)
def add_cli_args(parser: argparse.ArgumentParser):
subparsers = parser.add_subparsers(required=True, dest="sweep_type")
for cmd, entrypoint in SUBCOMMANDS:
cmd_subparser = subparsers.add_parser(
cmd.parser_name,
description=cmd.parser_help,
usage=f"vllm bench sweep {cmd.parser_name} [options]",
)
cmd_subparser.set_defaults(dispatch_function=entrypoint)
cmd.add_cli_args(cmd_subparser)
cmd_subparser.epilog = VLLM_SUBCMD_PARSER_EPILOG.format(
subcmd=f"sweep {cmd.parser_name}"
)
def main(args: argparse.Namespace):
args.dispatch_function(args)
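
As a minimal sketch of how this dispatcher is meant to be driven (assuming a standalone parser; the real `vllm` CLI wires it up through `BenchmarkSweepSubcommand`), the subparsers select the entrypoint via `args.dispatch_function`:

```python
# Minimal sketch, assuming a standalone argparse parser; the path is a placeholder.
import argparse

from vllm.benchmarks.sweep.cli import add_cli_args, main

parser = argparse.ArgumentParser(prog="vllm bench sweep")
add_cli_args(parser)

# "plot" selects the sub-parser; OUTPUT_DIR must point at an existing results
# directory produced by a previous `vllm bench sweep serve` run.
args = parser.parse_args(["plot", "benchmarks/results/example", "--dry-run"])
main(args)  # dispatches to the plot entrypoint via args.dispatch_function
```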

View File

@@ -8,16 +8,24 @@ from dataclasses import dataclass
from functools import partial
from pathlib import Path
from types import TracebackType
from typing import ClassVar
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
from typing_extensions import Self, override
from vllm.utils.collection_utils import full_groupby
from vllm.utils.import_utils import PlaceholderModule
from .utils import sanitize_filename
try:
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
except ImportError:
plt = PlaceholderModule("matplotlib").placeholder_attr("pyplot")
pd = PlaceholderModule("pandas")
sns = PlaceholderModule("seaborn")
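
The deferred-import pattern above, together with the switch to string `"pd.DataFrame"` annotations below, keeps `import vllm.benchmarks.sweep.plot` working when the optional plotting extras are absent. A rough sketch of the idea, assuming `PlaceholderModule` raises an informative error only when an attribute of the missing module is actually accessed:

```python
# Sketch of the optional-dependency pattern; PlaceholderModule's exact
# behaviour is assumed here, not taken from this diff.
from vllm.utils.import_utils import PlaceholderModule

try:
    import pandas as pd
except ImportError:
    pd = PlaceholderModule("pandas")

def to_frame(records: list[dict]) -> "pd.DataFrame":
    # The string annotation avoids evaluating `pd` at import time; the call
    # below is what fails (with a clear message) if pandas is not installed.
    return pd.DataFrame(records)
```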
@dataclass
class PlotFilterBase(ABC):
@@ -40,7 +48,7 @@ class PlotFilterBase(ABC):
)
@abstractmethod
def apply(self, df: pd.DataFrame) -> pd.DataFrame:
def apply(self, df: "pd.DataFrame") -> "pd.DataFrame":
"""Applies this filter to a DataFrame."""
raise NotImplementedError
@@ -48,7 +56,7 @@ class PlotFilterBase(ABC):
@dataclass
class PlotEqualTo(PlotFilterBase):
@override
def apply(self, df: pd.DataFrame) -> pd.DataFrame:
def apply(self, df: "pd.DataFrame") -> "pd.DataFrame":
try:
target = float(self.target)
except ValueError:
@@ -60,28 +68,28 @@ class PlotEqualTo(PlotFilterBase):
@dataclass
class PlotLessThan(PlotFilterBase):
@override
def apply(self, df: pd.DataFrame) -> pd.DataFrame:
def apply(self, df: "pd.DataFrame") -> "pd.DataFrame":
return df[df[self.var] < float(self.target)]
@dataclass
class PlotLessThanOrEqualTo(PlotFilterBase):
@override
def apply(self, df: pd.DataFrame) -> pd.DataFrame:
def apply(self, df: "pd.DataFrame") -> "pd.DataFrame":
return df[df[self.var] <= float(self.target)]
@dataclass
class PlotGreaterThan(PlotFilterBase):
@override
def apply(self, df: pd.DataFrame) -> pd.DataFrame:
def apply(self, df: "pd.DataFrame") -> "pd.DataFrame":
return df[df[self.var] > float(self.target)]
@dataclass
class PlotGreaterThanOrEqualTo(PlotFilterBase):
@override
def apply(self, df: pd.DataFrame) -> pd.DataFrame:
def apply(self, df: "pd.DataFrame") -> "pd.DataFrame":
return df[df[self.var] >= float(self.target)]
@@ -103,7 +111,7 @@ class PlotFilters(list[PlotFilterBase]):
return cls(PlotFilterBase.parse_str(e) for e in s.split(","))
def apply(self, df: pd.DataFrame) -> pd.DataFrame:
def apply(self, df: "pd.DataFrame") -> "pd.DataFrame":
for item in self:
df = item.apply(df)
@@ -127,7 +135,7 @@ class PlotBinner:
f"Valid operators are: {sorted(PLOT_BINNERS)}",
)
def apply(self, df: pd.DataFrame) -> pd.DataFrame:
def apply(self, df: "pd.DataFrame") -> "pd.DataFrame":
"""Applies this binner to a DataFrame."""
df = df.copy()
df[self.var] = df[self.var] // self.bin_size * self.bin_size
@@ -147,7 +155,7 @@ class PlotBinners(list[PlotBinner]):
return cls(PlotBinner.parse_str(e) for e in s.split(","))
def apply(self, df: pd.DataFrame) -> pd.DataFrame:
def apply(self, df: "pd.DataFrame") -> "pd.DataFrame":
for item in self:
df = item.apply(df)
@@ -396,135 +404,177 @@ def plot(
)
def add_cli_args(parser: argparse.ArgumentParser):
parser.add_argument(
"OUTPUT_DIR",
type=str,
default="results",
help="The directory containing the results to plot, "
"i.e., the `--output-dir` argument to the parameter sweep script.",
)
parser.add_argument(
"--fig-dir",
type=str,
default="",
help="The directory to save the figures, relative to `OUTPUT_DIR`. "
"By default, the same directory is used.",
)
parser.add_argument(
"--fig-by",
type=str,
default="",
help="A comma-separated list of variables, such that a separate figure "
"is created for each combination of these variables.",
)
parser.add_argument(
"--row-by",
type=str,
default="",
help="A comma-separated list of variables, such that a separate row "
"is created for each combination of these variables.",
)
parser.add_argument(
"--col-by",
type=str,
default="",
help="A comma-separated list of variables, such that a separate column "
"is created for each combination of these variables.",
)
parser.add_argument(
"--curve-by",
type=str,
default=None,
help="A comma-separated list of variables, such that a separate curve "
"is created for each combination of these variables.",
)
parser.add_argument(
"--var-x",
type=str,
default="request_throughput",
help="The variable for the x-axis.",
)
parser.add_argument(
"--var-y",
type=str,
default="p99_e2el_ms",
help="The variable for the y-axis",
)
parser.add_argument(
"--filter-by",
type=str,
default="",
help="A comma-separated list of statements indicating values to filter by. "
"This is useful to remove outliers. "
"Example: `max_concurrency<1000,max_num_batched_tokens<=4096` means "
"plot only the points where `max_concurrency` is less than 1000 and "
"`max_num_batched_tokens` is no greater than 4096.",
)
parser.add_argument(
"--bin-by",
type=str,
default="",
help="A comma-separated list of statements indicating values to bin by. "
"This is useful to avoid plotting points that are too close together. "
"Example: `request_throughput%1` means "
"use a bin size of 1 for the `request_throughput` variable.",
)
parser.add_argument(
"--scale-x",
type=str,
default=None,
help="The scale to use for the x-axis. "
"Currently only accepts string values such as 'log' and 'sqrt'. "
"See also: https://seaborn.pydata.org/generated/seaborn.objects.Plot.scale.html",
)
parser.add_argument(
"--scale-y",
type=str,
default=None,
help="The scale to use for the y-axis. "
"Currently only accepts string values such as 'log' and 'sqrt'. "
"See also: https://seaborn.pydata.org/generated/seaborn.objects.Plot.scale.html",
)
parser.add_argument(
"--dry-run",
action="store_true",
help="If set, prints the information about each figure to plot, "
"then exits without drawing them.",
)
@dataclass
class SweepPlotArgs:
output_dir: Path
fig_dir: Path
fig_by: list[str]
row_by: list[str]
col_by: list[str]
curve_by: list[str]
var_x: str
var_y: str
filter_by: PlotFilters
bin_by: PlotBinners
scale_x: str | None
scale_y: str | None
dry_run: bool
parser_name: ClassVar[str] = "plot"
parser_help: ClassVar[str] = "Plot performance curves from parameter sweep results."
@classmethod
def from_cli_args(cls, args: argparse.Namespace):
output_dir = Path(args.OUTPUT_DIR)
if not output_dir.exists():
raise ValueError(f"No parameter sweep results under {output_dir}")
curve_by = [] if not args.curve_by else args.curve_by.split(",")
row_by = [] if not args.row_by else args.row_by.split(",")
col_by = [] if not args.col_by else args.col_by.split(",")
fig_by = [] if not args.fig_by else args.fig_by.split(",")
return cls(
output_dir=output_dir,
fig_dir=output_dir / args.fig_dir,
fig_by=fig_by,
row_by=row_by,
col_by=col_by,
curve_by=curve_by,
var_x=args.var_x,
var_y=args.var_y,
filter_by=PlotFilters.parse_str(args.filter_by),
bin_by=PlotBinners.parse_str(args.bin_by),
scale_x=args.scale_x,
scale_y=args.scale_y,
dry_run=args.dry_run,
)
@classmethod
def add_cli_args(cls, parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
parser.add_argument(
"OUTPUT_DIR",
type=str,
default="results",
help="The directory containing the results to plot, "
"i.e., the `--output-dir` argument to the parameter sweep script.",
)
parser.add_argument(
"--fig-dir",
type=str,
default="",
help="The directory to save the figures, relative to `OUTPUT_DIR`. "
"By default, the same directory is used.",
)
parser.add_argument(
"--fig-by",
type=str,
default="",
help="A comma-separated list of variables, such that a separate figure "
"is created for each combination of these variables.",
)
parser.add_argument(
"--row-by",
type=str,
default="",
help="A comma-separated list of variables, such that a separate row "
"is created for each combination of these variables.",
)
parser.add_argument(
"--col-by",
type=str,
default="",
help="A comma-separated list of variables, such that a separate column "
"is created for each combination of these variables.",
)
parser.add_argument(
"--curve-by",
type=str,
default=None,
help="A comma-separated list of variables, such that a separate curve "
"is created for each combination of these variables.",
)
parser.add_argument(
"--var-x",
type=str,
default="request_throughput",
help="The variable for the x-axis.",
)
parser.add_argument(
"--var-y",
type=str,
default="p99_e2el_ms",
help="The variable for the y-axis",
)
parser.add_argument(
"--filter-by",
type=str,
default="",
help="A comma-separated list of statements indicating values to filter by. "
"This is useful to remove outliers. "
"Example: `max_concurrency<1000,max_num_batched_tokens<=4096` means "
"plot only the points where `max_concurrency` is less than 1000 and "
"`max_num_batched_tokens` is no greater than 4096.",
)
parser.add_argument(
"--bin-by",
type=str,
default="",
help="A comma-separated list of statements indicating values to bin by. "
"This is useful to avoid plotting points that are too close together. "
"Example: `request_throughput%%1` means "
"use a bin size of 1 for the `request_throughput` variable.",
)
parser.add_argument(
"--scale-x",
type=str,
default=None,
help="The scale to use for the x-axis. "
"Currently only accepts string values such as 'log' and 'sqrt'. "
"See also: https://seaborn.pydata.org/generated/seaborn.objects.Plot.scale.html",
)
parser.add_argument(
"--scale-y",
type=str,
default=None,
help="The scale to use for the y-axis. "
"Currently only accepts string values such as 'log' and 'sqrt'. "
"See also: https://seaborn.pydata.org/generated/seaborn.objects.Plot.scale.html",
)
parser.add_argument(
"--dry-run",
action="store_true",
help="If set, prints the information about each figure to plot, "
"then exits without drawing them.",
)
return parser
def main(args: argparse.Namespace):
output_dir = Path(args.OUTPUT_DIR)
if not output_dir.exists():
raise ValueError(f"No parameter sweep results under {output_dir}")
curve_by = [] if not args.curve_by else args.curve_by.split(",")
row_by = [] if not args.row_by else args.row_by.split(",")
col_by = [] if not args.col_by else args.col_by.split(",")
fig_by = [] if not args.fig_by else args.fig_by.split(",")
plot(
output_dir=output_dir,
fig_dir=output_dir / args.fig_dir,
fig_by=fig_by,
row_by=row_by,
col_by=col_by,
curve_by=curve_by,
def run_main(args: SweepPlotArgs):
return plot(
output_dir=args.output_dir,
fig_dir=args.fig_dir,
fig_by=args.fig_by,
row_by=args.row_by,
col_by=args.col_by,
curve_by=args.curve_by,
var_x=args.var_x,
var_y=args.var_y,
filter_by=PlotFilters.parse_str(args.filter_by),
bin_by=PlotBinners.parse_str(args.bin_by),
filter_by=args.filter_by,
bin_by=args.bin_by,
scale_x=args.scale_x,
scale_y=args.scale_y,
dry_run=args.dry_run,
)
def main(args: argparse.Namespace):
run_main(SweepPlotArgs.from_cli_args(args))
if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="Plot performance curves from parameter sweep results."
)
add_cli_args(parser)
parser = argparse.ArgumentParser(description=SweepPlotArgs.parser_help)
SweepPlotArgs.add_cli_args(parser)
main(parser.parse_args())
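
Since the plotting options now live on a dataclass, the same entrypoint can be driven from Python without going through argparse. A sketch that mirrors the documented example command; the timestamped results path is a placeholder and must already exist:

```python
# Sketch only: drive the plot entrypoint programmatically. Constructing the
# dataclass directly skips the existence check done in from_cli_args.
from pathlib import Path

from vllm.benchmarks.sweep.plot import (
    PlotBinners,
    PlotFilters,
    SweepPlotArgs,
    run_main,
)

results = Path("benchmarks/results/2025-10-29_20-59-48")  # placeholder path
args = SweepPlotArgs(
    output_dir=results,
    fig_dir=results,
    fig_by=[],
    row_by=["random_input_len"],
    col_by=["random_output_len"],
    curve_by=[],
    var_x="max_concurrency",
    var_y="p99_e2el_ms",
    filter_by=PlotFilters(),  # no filtering
    bin_by=PlotBinners(),     # no binning
    scale_x=None,
    scale_y=None,
    dry_run=True,             # print figure info instead of drawing
)
run_main(args)
```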

View File

@@ -7,13 +7,19 @@ import shlex
from dataclasses import dataclass
from datetime import datetime
from pathlib import Path
from typing import ClassVar
import pandas as pd
from vllm.utils.import_utils import PlaceholderModule
from .param_sweep import ParameterSweep, ParameterSweepItem
from .server import ServerProcess
from .utils import sanitize_filename
try:
import pandas as pd
except ImportError:
pd = PlaceholderModule("pandas")
@contextlib.contextmanager
def run_server(
@@ -257,6 +263,9 @@ class SweepServeArgs:
dry_run: bool
resume: str | None
parser_name: ClassVar[str] = "serve"
parser_help: ClassVar[str] = "Run vLLM server benchmark under multiple settings."
@classmethod
def from_cli_args(cls, args: argparse.Namespace):
serve_cmd = shlex.split(args.serve_cmd)
@@ -401,9 +410,7 @@ def main(args: argparse.Namespace):
if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="Run vLLM server benchmark under multiple settings."
)
parser = argparse.ArgumentParser(description=SweepServeArgs.parser_help)
SweepServeArgs.add_cli_args(parser)
main(parser.parse_args())
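
For reference, `SweepServeArgs.from_cli_args` turns the `--serve-cmd`/`--bench-cmd` strings into argv lists with `shlex.split`, e.g.:

```python
# Illustration of the shlex.split call in SweepServeArgs.from_cli_args above.
import shlex

serve_cmd = shlex.split("vllm serve meta-llama/Llama-2-7b-chat-hf")
print(serve_cmd)  # ['vllm', 'serve', 'meta-llama/Llama-2-7b-chat-hf']
```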

View File

@@ -7,17 +7,23 @@ import math
from dataclasses import asdict, dataclass
from datetime import datetime
from pathlib import Path
from typing import Literal, get_args
from typing import ClassVar, Literal, get_args
import pandas as pd
from typing_extensions import assert_never
from vllm.utils.import_utils import PlaceholderModule
from .param_sweep import ParameterSweep, ParameterSweepItem
from .serve import SweepServeArgs, run_benchmark, run_server
from .server import ServerProcess
from .sla_sweep import SLASweep, SLASweepItem
from .utils import sanitize_filename
try:
import pandas as pd
except ImportError:
pd = PlaceholderModule("pandas")
def _get_sla_base_path(
output_dir: Path,
@@ -399,6 +405,9 @@ class SweepServeSLAArgs(SweepServeArgs):
sla_params: SLASweep
sla_variable: SLAVariable
parser_name: ClassVar[str] = "serve_sla"
parser_help: ClassVar[str] = "Tune a variable to meet SLAs under multiple settings."
@classmethod
def from_cli_args(cls, args: argparse.Namespace):
# NOTE: Don't use super() as `from_cli_args` calls `cls()`
@@ -419,7 +428,8 @@ class SweepServeSLAArgs(SweepServeArgs):
def add_cli_args(cls, parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
parser = super().add_cli_args(parser)
parser.add_argument(
sla_group = parser.add_argument_group("sla options")
sla_group.add_argument(
"--sla-params",
type=str,
required=True,
@@ -431,7 +441,7 @@ class SweepServeSLAArgs(SweepServeArgs):
"the maximum `sla_variable` that satisfies the constraints for "
"each combination of `serve_params`, `bench_params`, and `sla_params`.",
)
parser.add_argument(
sla_group.add_argument(
"--sla-variable",
type=str,
choices=get_args(SLAVariable),
@@ -476,9 +486,7 @@ def main(args: argparse.Namespace):
if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="Tune a variable to meet SLAs under multiple settings."
)
parser = argparse.ArgumentParser(description=SweepServeSLAArgs.parser_help)
SweepServeSLAArgs.add_cli_args(parser)
main(parser.parse_args())
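
The new `sla options` argument group only affects how `--help` is rendered; a standalone sketch of the pattern (flag set trimmed for brevity):

```python
# Standalone sketch of argparse argument groups: the SLA flags are declared on
# the group so --help lists them under a separate "sla options" heading.
import argparse

parser = argparse.ArgumentParser(prog="vllm bench sweep serve_sla")
sla_group = parser.add_argument_group("sla options")
sla_group.add_argument("--sla-params", type=str, required=True)
sla_group.add_argument("--sla-variable", type=str)
parser.print_help()
```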

View File

@@ -2,10 +2,12 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from vllm.entrypoints.cli.benchmark.latency import BenchmarkLatencySubcommand
from vllm.entrypoints.cli.benchmark.serve import BenchmarkServingSubcommand
from vllm.entrypoints.cli.benchmark.sweep import BenchmarkSweepSubcommand
from vllm.entrypoints.cli.benchmark.throughput import BenchmarkThroughputSubcommand
__all__: list[str] = [
"BenchmarkLatencySubcommand",
"BenchmarkServingSubcommand",
"BenchmarkSweepSubcommand",
"BenchmarkThroughputSubcommand",
]

View File

@@ -6,7 +6,7 @@ from vllm.entrypoints.cli.types import CLISubcommand
class BenchmarkSubcommandBase(CLISubcommand):
"""The base class of subcommands for vllm bench."""
"""The base class of subcommands for `vllm bench`."""
help: str

View File

@@ -7,7 +7,7 @@ from vllm.entrypoints.cli.benchmark.base import BenchmarkSubcommandBase
class BenchmarkLatencySubcommand(BenchmarkSubcommandBase):
"""The `latency` subcommand for vllm bench."""
"""The `latency` subcommand for `vllm bench`."""
name = "latency"
help = "Benchmark the latency of a single batch of requests."

View File

@@ -7,7 +7,7 @@ from vllm.entrypoints.cli.benchmark.base import BenchmarkSubcommandBase
class BenchmarkServingSubcommand(BenchmarkSubcommandBase):
"""The `serve` subcommand for vllm bench."""
"""The `serve` subcommand for `vllm bench`."""
name = "serve"
help = "Benchmark the online serving throughput."

View File

@@ -0,0 +1,21 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import argparse
from vllm.benchmarks.sweep.cli import add_cli_args, main
from vllm.entrypoints.cli.benchmark.base import BenchmarkSubcommandBase
class BenchmarkSweepSubcommand(BenchmarkSubcommandBase):
"""The `sweep` subcommand for `vllm bench`."""
name = "sweep"
help = "Benchmark for a parameter sweep."
@classmethod
def add_cli_args(cls, parser: argparse.ArgumentParser) -> None:
add_cli_args(parser)
@staticmethod
def cmd(args: argparse.Namespace) -> None:
main(args)

View File

@@ -7,7 +7,7 @@ from vllm.entrypoints.cli.benchmark.base import BenchmarkSubcommandBase
class BenchmarkThroughputSubcommand(BenchmarkSubcommandBase):
"""The `throughput` subcommand for vllm bench."""
"""The `throughput` subcommand for `vllm bench`."""
name = "throughput"
help = "Benchmark offline inference throughput."

View File

@@ -7,7 +7,6 @@ from collections.abc import Callable
from dataclasses import asdict, dataclass, field
from typing import Any, Optional, TypeAlias
import pandas as pd
from torch._C._autograd import DeviceType, _KinetoEvent, _ProfilerResult
from torch._C._profiler import _EventType, _ExperimentalConfig, _ProfilerEvent
from torch.autograd.profiler import FunctionEvent
@@ -21,6 +20,12 @@ from vllm.profiler.utils import (
event_torch_op_stack_trace,
indent_string,
)
from vllm.utils.import_utils import PlaceholderModule
try:
import pandas as pd
except ImportError:
pd = PlaceholderModule("pandas")
@dataclass