Mirror of https://github.com/vllm-project/vllm.git (synced 2025-12-06 06:53:12 +08:00)

[Frontend] Add vllm bench sweep to CLI (#27639)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Co-authored-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
@@ -5,4 +5,4 @@ nav:
   - complete.md
   - run-batch.md
   - vllm bench:
-    - bench/*.md
+    - bench/**/*.md
docs/cli/bench/sweep/plot.md (new file)
@@ -0,0 +1,9 @@
+# vllm bench sweep plot
+
+## JSON CLI Arguments
+
+--8<-- "docs/cli/json_tip.inc.md"
+
+## Options
+
+--8<-- "docs/argparse/bench_sweep_plot.md"
docs/cli/bench/sweep/serve.md (new file)
@@ -0,0 +1,9 @@
+# vllm bench sweep serve
+
+## JSON CLI Arguments
+
+--8<-- "docs/cli/json_tip.inc.md"
+
+## Options
+
+--8<-- "docs/argparse/bench_sweep_serve.md"
docs/cli/bench/sweep/serve_sla.md (new file)
@@ -0,0 +1,9 @@
+# vllm bench sweep serve_sla
+
+## JSON CLI Arguments
+
+--8<-- "docs/cli/json_tip.inc.md"
+
+## Options
+
+--8<-- "docs/argparse/bench_sweep_serve_sla.md"
@@ -1061,7 +1061,7 @@ Follow these steps to run the script:
 Example command:

 ```bash
-python -m vllm.benchmarks.sweep.serve \
+vllm bench sweep serve \
     --serve-cmd 'vllm serve meta-llama/Llama-2-7b-chat-hf' \
     --bench-cmd 'vllm bench serve --model meta-llama/Llama-2-7b-chat-hf --backend vllm --endpoint /v1/completions --dataset-name sharegpt --dataset-path benchmarks/ShareGPT_V3_unfiltered_cleaned_split.json' \
     --serve-params benchmarks/serve_hparams.json \
@@ -1109,7 +1109,7 @@ For example, to ensure E2E latency within different target values for 99% of requests:
 Example command:

 ```bash
-python -m vllm.benchmarks.sweep.serve_sla \
+vllm bench sweep serve_sla \
     --serve-cmd 'vllm serve meta-llama/Llama-2-7b-chat-hf' \
     --bench-cmd 'vllm bench serve --model meta-llama/Llama-2-7b-chat-hf --backend vllm --endpoint /v1/completions --dataset-name sharegpt --dataset-path benchmarks/ShareGPT_V3_unfiltered_cleaned_split.json' \
     --serve-params benchmarks/serve_hparams.json \
@@ -1138,7 +1138,7 @@ The algorithm for adjusting the SLA variable is as follows:
 Example command:

 ```bash
-python -m vllm.benchmarks.sweep.plot benchmarks/results/<timestamp> \
+vllm bench sweep plot benchmarks/results/<timestamp> \
     --var-x max_concurrency \
     --row-by random_input_len \
     --col-by random_output_len \
@@ -56,15 +56,20 @@ def auto_mock(module, attr, max_mocks=50):
    )


-latency = auto_mock("vllm.benchmarks", "latency")
-serve = auto_mock("vllm.benchmarks", "serve")
-throughput = auto_mock("vllm.benchmarks", "throughput")
+bench_latency = auto_mock("vllm.benchmarks", "latency")
+bench_serve = auto_mock("vllm.benchmarks", "serve")
+bench_sweep_plot = auto_mock("vllm.benchmarks.sweep.plot", "SweepPlotArgs")
+bench_sweep_serve = auto_mock("vllm.benchmarks.sweep.serve", "SweepServeArgs")
+bench_sweep_serve_sla = auto_mock(
+    "vllm.benchmarks.sweep.serve_sla", "SweepServeSLAArgs"
+)
+bench_throughput = auto_mock("vllm.benchmarks", "throughput")
 AsyncEngineArgs = auto_mock("vllm.engine.arg_utils", "AsyncEngineArgs")
 EngineArgs = auto_mock("vllm.engine.arg_utils", "EngineArgs")
 ChatCommand = auto_mock("vllm.entrypoints.cli.openai", "ChatCommand")
 CompleteCommand = auto_mock("vllm.entrypoints.cli.openai", "CompleteCommand")
-cli_args = auto_mock("vllm.entrypoints.openai", "cli_args")
-run_batch = auto_mock("vllm.entrypoints.openai", "run_batch")
+openai_cli_args = auto_mock("vllm.entrypoints.openai", "cli_args")
+openai_run_batch = auto_mock("vllm.entrypoints.openai", "run_batch")
 FlexibleArgumentParser = auto_mock(
     "vllm.utils.argparse_utils", "FlexibleArgumentParser"
 )
@@ -114,6 +119,9 @@ class MarkdownFormatter(HelpFormatter):
             self._markdown_output.append(f"{action.help}\n\n")

             if (default := action.default) != SUPPRESS:
+                # Make empty string defaults visible
+                if default == "":
+                    default = '""'
                 self._markdown_output.append(f"Default: `{default}`\n\n")

     def format_help(self):
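A small, self-contained sketch of the rule these added lines implement (not the hook code itself, just an illustration of how an argparse default would be rendered into Markdown):

```python
# Standalone sketch: suppressed defaults are skipped, empty strings are
# quoted so they stay visible in the generated documentation.
from argparse import SUPPRESS


def render_default(default) -> str | None:
    if default == SUPPRESS:
        return None
    if default == "":
        default = '""'  # make empty string defaults visible
    return f"Default: `{default}`"


assert render_default(SUPPRESS) is None
assert render_default("") == 'Default: `""`'
assert render_default(0) == "Default: `0`"
```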
@@ -150,17 +158,23 @@ def on_startup(command: Literal["build", "gh-deploy", "serve"], dirty: bool):

     # Create parsers to document
     parsers = {
         # Engine args
         "engine_args": create_parser(EngineArgs.add_cli_args),
         "async_engine_args": create_parser(
             AsyncEngineArgs.add_cli_args, async_args_only=True
         ),
-        "serve": create_parser(cli_args.make_arg_parser),
+        # CLI
+        "serve": create_parser(openai_cli_args.make_arg_parser),
         "chat": create_parser(ChatCommand.add_cli_args),
         "complete": create_parser(CompleteCommand.add_cli_args),
-        "bench_latency": create_parser(latency.add_cli_args),
-        "bench_throughput": create_parser(throughput.add_cli_args),
-        "bench_serve": create_parser(serve.add_cli_args),
-        "run-batch": create_parser(run_batch.make_arg_parser),
+        "run-batch": create_parser(openai_run_batch.make_arg_parser),
+        # Benchmark CLI
+        "bench_latency": create_parser(bench_latency.add_cli_args),
+        "bench_serve": create_parser(bench_serve.add_cli_args),
+        "bench_sweep_plot": create_parser(bench_sweep_plot.add_cli_args),
+        "bench_sweep_serve": create_parser(bench_sweep_serve.add_cli_args),
+        "bench_sweep_serve_sla": create_parser(bench_sweep_serve_sla.add_cli_args),
+        "bench_throughput": create_parser(bench_throughput.add_cli_args),
     }

     # Generate documentation for each parser
setup.py
@@ -709,7 +709,7 @@ setup(
     ext_modules=ext_modules,
     install_requires=get_requirements(),
     extras_require={
-        "bench": ["pandas", "datasets"],
+        "bench": ["pandas", "matplotlib", "seaborn", "datasets"],
         "tensorizer": ["tensorizer==2.10.1"],
         "fastsafetensors": ["fastsafetensors >= 0.1.10"],
         "runai": ["runai-model-streamer[s3,gcs] >= 0.14.0"],
@@ -141,7 +141,7 @@ def attempt_to_make_names_unique(entries_and_traces):
     """


-def group_trace_by_operations(trace_df: pd.DataFrame) -> pd.DataFrame:
+def group_trace_by_operations(trace_df: "pd.DataFrame") -> "pd.DataFrame":
     def is_rms_norm(op_name: str):
         if "rms_norm_kernel" in op_name:
             return True
@@ -370,12 +370,12 @@ def group_trace_by_operations(trace_df: pd.DataFrame) -> pd.DataFrame:


 def plot_trace_df(
-    traces_df: pd.DataFrame,
+    traces_df: "pd.DataFrame",
     plot_metric: str,
     plot_title: str,
     output: Path | None = None,
 ):
-    def get_phase_description(traces_df: pd.DataFrame, phase: str) -> str:
+    def get_phase_description(traces_df: "pd.DataFrame", phase: str) -> str:
         phase_df = traces_df.query(f'phase == "{phase}"')
         descs = phase_df["phase_desc"].to_list()
         assert all([desc == descs[0] for desc in descs])
@@ -438,7 +438,7 @@ def main(
     top_k: int,
     json_nodes_to_fold: list[str],
 ):
-    def prepare_data(profile_json: dict, step_keys: list[str]) -> pd.DataFrame:
+    def prepare_data(profile_json: dict, step_keys: list[str]) -> "pd.DataFrame":
         def get_entries_and_traces(key: str):
             entries_and_traces: list[tuple[Any, Any]] = []
             for root in profile_json[key]["summary_stats"]:
@@ -449,8 +449,8 @@ def main(
             return entries_and_traces

         def keep_only_top_entries(
-            df: pd.DataFrame, metric: str, top_k: int = 9
-        ) -> pd.DataFrame:
+            df: "pd.DataFrame", metric: str, top_k: int = 9
+        ) -> "pd.DataFrame":
             df.loc[df.nsmallest(len(df) - top_k + 1, metric).index, ["name"]] = "others"
             return df

vllm/benchmarks/sweep/cli.py (new file)
@@ -0,0 +1,38 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+import argparse
+
+from vllm.entrypoints.utils import VLLM_SUBCMD_PARSER_EPILOG
+
+from .plot import SweepPlotArgs
+from .plot import main as plot_main
+from .serve import SweepServeArgs
+from .serve import main as serve_main
+from .serve_sla import SweepServeSLAArgs
+from .serve_sla import main as serve_sla_main
+
+SUBCOMMANDS = (
+    (SweepServeArgs, serve_main),
+    (SweepServeSLAArgs, serve_sla_main),
+    (SweepPlotArgs, plot_main),
+)
+
+
+def add_cli_args(parser: argparse.ArgumentParser):
+    subparsers = parser.add_subparsers(required=True, dest="sweep_type")
+
+    for cmd, entrypoint in SUBCOMMANDS:
+        cmd_subparser = subparsers.add_parser(
+            cmd.parser_name,
+            description=cmd.parser_help,
+            usage=f"vllm bench sweep {cmd.parser_name} [options]",
+        )
+        cmd_subparser.set_defaults(dispatch_function=entrypoint)
+        cmd.add_cli_args(cmd_subparser)
+        cmd_subparser.epilog = VLLM_SUBCMD_PARSER_EPILOG.format(
+            subcmd=f"sweep {cmd.parser_name}"
+        )
+
+
+def main(args: argparse.Namespace):
+    args.dispatch_function(args)
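A minimal sketch of how this dispatcher is driven (assuming it is run from a directory where sweep results already exist under `./results`): each `SUBCOMMANDS` entry registers a subparser and stores its entrypoint as `dispatch_function`, which `main` then invokes.

```python
# Standalone sketch: build a parser, parse a "plot" invocation, dispatch it.
import argparse

from vllm.benchmarks.sweep.cli import add_cli_args, main

parser = argparse.ArgumentParser(prog="vllm bench sweep")
add_cli_args(parser)

# Equivalent to: vllm bench sweep plot results --dry-run
args = parser.parse_args(["plot", "results", "--dry-run"])
main(args)  # calls args.dispatch_function(args), i.e. the plot entrypoint
```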
@@ -8,16 +8,24 @@ from dataclasses import dataclass
 from functools import partial
 from pathlib import Path
 from types import TracebackType
+from typing import ClassVar

-import matplotlib.pyplot as plt
-import pandas as pd
-import seaborn as sns
 from typing_extensions import Self, override

 from vllm.utils.collection_utils import full_groupby
+from vllm.utils.import_utils import PlaceholderModule

 from .utils import sanitize_filename

+try:
+    import matplotlib.pyplot as plt
+    import pandas as pd
+    import seaborn as sns
+except ImportError:
+    plt = PlaceholderModule("matplotlib").placeholder_attr("pyplot")
+    pd = PlaceholderModule("pandas")
+    sns = PlaceholderModule("seaborn")


 @dataclass
 class PlotFilterBase(ABC):
@@ -40,7 +48,7 @@ class PlotFilterBase(ABC):
        )

     @abstractmethod
-    def apply(self, df: pd.DataFrame) -> pd.DataFrame:
+    def apply(self, df: "pd.DataFrame") -> "pd.DataFrame":
         """Applies this filter to a DataFrame."""
         raise NotImplementedError

@@ -48,7 +56,7 @@ class PlotFilterBase(ABC):
 @dataclass
 class PlotEqualTo(PlotFilterBase):
     @override
-    def apply(self, df: pd.DataFrame) -> pd.DataFrame:
+    def apply(self, df: "pd.DataFrame") -> "pd.DataFrame":
         try:
             target = float(self.target)
         except ValueError:
@@ -60,28 +68,28 @@ class PlotEqualTo(PlotFilterBase):
 @dataclass
 class PlotLessThan(PlotFilterBase):
     @override
-    def apply(self, df: pd.DataFrame) -> pd.DataFrame:
+    def apply(self, df: "pd.DataFrame") -> "pd.DataFrame":
         return df[df[self.var] < float(self.target)]


 @dataclass
 class PlotLessThanOrEqualTo(PlotFilterBase):
     @override
-    def apply(self, df: pd.DataFrame) -> pd.DataFrame:
+    def apply(self, df: "pd.DataFrame") -> "pd.DataFrame":
         return df[df[self.var] <= float(self.target)]


 @dataclass
 class PlotGreaterThan(PlotFilterBase):
     @override
-    def apply(self, df: pd.DataFrame) -> pd.DataFrame:
+    def apply(self, df: "pd.DataFrame") -> "pd.DataFrame":
         return df[df[self.var] > float(self.target)]


 @dataclass
 class PlotGreaterThanOrEqualTo(PlotFilterBase):
     @override
-    def apply(self, df: pd.DataFrame) -> pd.DataFrame:
+    def apply(self, df: "pd.DataFrame") -> "pd.DataFrame":
         return df[df[self.var] >= float(self.target)]


@@ -103,7 +111,7 @@ class PlotFilters(list[PlotFilterBase]):

         return cls(PlotFilterBase.parse_str(e) for e in s.split(","))

-    def apply(self, df: pd.DataFrame) -> pd.DataFrame:
+    def apply(self, df: "pd.DataFrame") -> "pd.DataFrame":
         for item in self:
             df = item.apply(df)

@@ -127,7 +135,7 @@ class PlotBinner:
            f"Valid operators are: {sorted(PLOT_BINNERS)}",
        )

-    def apply(self, df: pd.DataFrame) -> pd.DataFrame:
+    def apply(self, df: "pd.DataFrame") -> "pd.DataFrame":
         """Applies this binner to a DataFrame."""
         df = df.copy()
         df[self.var] = df[self.var] // self.bin_size * self.bin_size
@@ -147,7 +155,7 @@ class PlotBinners(list[PlotBinner]):

         return cls(PlotBinner.parse_str(e) for e in s.split(","))

-    def apply(self, df: pd.DataFrame) -> pd.DataFrame:
+    def apply(self, df: "pd.DataFrame") -> "pd.DataFrame":
         for item in self:
             df = item.apply(df)

@@ -396,135 +404,177 @@ def plot(
    )


-def add_cli_args(parser: argparse.ArgumentParser):
-    parser.add_argument(
-        "OUTPUT_DIR",
-        type=str,
-        default="results",
-        help="The directory containing the results to plot, "
-        "i.e., the `--output-dir` argument to the parameter sweep script.",
-    )
-    parser.add_argument(
-        "--fig-dir",
-        type=str,
-        default="",
-        help="The directory to save the figures, relative to `OUTPUT_DIR`. "
-        "By default, the same directory is used.",
-    )
-    parser.add_argument(
-        "--fig-by",
-        type=str,
-        default="",
-        help="A comma-separated list of variables, such that a separate figure "
-        "is created for each combination of these variables.",
-    )
-    parser.add_argument(
-        "--row-by",
-        type=str,
-        default="",
-        help="A comma-separated list of variables, such that a separate row "
-        "is created for each combination of these variables.",
-    )
-    parser.add_argument(
-        "--col-by",
-        type=str,
-        default="",
-        help="A comma-separated list of variables, such that a separate column "
-        "is created for each combination of these variables.",
-    )
-    parser.add_argument(
-        "--curve-by",
-        type=str,
-        default=None,
-        help="A comma-separated list of variables, such that a separate curve "
-        "is created for each combination of these variables.",
-    )
-    parser.add_argument(
-        "--var-x",
-        type=str,
-        default="request_throughput",
-        help="The variable for the x-axis.",
-    )
-    parser.add_argument(
-        "--var-y",
-        type=str,
-        default="p99_e2el_ms",
-        help="The variable for the y-axis",
-    )
-    parser.add_argument(
-        "--filter-by",
-        type=str,
-        default="",
-        help="A comma-separated list of statements indicating values to filter by. "
-        "This is useful to remove outliers. "
-        "Example: `max_concurrency<1000,max_num_batched_tokens<=4096` means "
-        "plot only the points where `max_concurrency` is less than 1000 and "
-        "`max_num_batched_tokens` is no greater than 4096.",
-    )
-    parser.add_argument(
-        "--bin-by",
-        type=str,
-        default="",
-        help="A comma-separated list of statements indicating values to bin by. "
-        "This is useful to avoid plotting points that are too close together. "
-        "Example: `request_throughput%1` means "
-        "use a bin size of 1 for the `request_throughput` variable.",
-    )
-    parser.add_argument(
-        "--scale-x",
-        type=str,
-        default=None,
-        help="The scale to use for the x-axis. "
-        "Currently only accepts string values such as 'log' and 'sqrt'. "
-        "See also: https://seaborn.pydata.org/generated/seaborn.objects.Plot.scale.html",
-    )
-    parser.add_argument(
-        "--scale-y",
-        type=str,
-        default=None,
-        help="The scale to use for the y-axis. "
-        "Currently only accepts string values such as 'log' and 'sqrt'. "
-        "See also: https://seaborn.pydata.org/generated/seaborn.objects.Plot.scale.html",
-    )
-    parser.add_argument(
-        "--dry-run",
-        action="store_true",
-        help="If set, prints the information about each figure to plot, "
-        "then exits without drawing them.",
-    )
+@dataclass
+class SweepPlotArgs:
+    output_dir: Path
+    fig_dir: Path
+    fig_by: list[str]
+    row_by: list[str]
+    col_by: list[str]
+    curve_by: list[str]
+    var_x: str
+    var_y: str
+    filter_by: PlotFilters
+    bin_by: PlotBinners
+    scale_x: str | None
+    scale_y: str | None
+    dry_run: bool
+
+    parser_name: ClassVar[str] = "plot"
+    parser_help: ClassVar[str] = "Plot performance curves from parameter sweep results."
+
+    @classmethod
+    def from_cli_args(cls, args: argparse.Namespace):
+        output_dir = Path(args.OUTPUT_DIR)
+        if not output_dir.exists():
+            raise ValueError(f"No parameter sweep results under {output_dir}")
+
+        curve_by = [] if not args.curve_by else args.curve_by.split(",")
+        row_by = [] if not args.row_by else args.row_by.split(",")
+        col_by = [] if not args.col_by else args.col_by.split(",")
+        fig_by = [] if not args.fig_by else args.fig_by.split(",")
+
+        return cls(
+            output_dir=output_dir,
+            fig_dir=output_dir / args.fig_dir,
+            fig_by=fig_by,
+            row_by=row_by,
+            col_by=col_by,
+            curve_by=curve_by,
+            var_x=args.var_x,
+            var_y=args.var_y,
+            filter_by=PlotFilters.parse_str(args.filter_by),
+            bin_by=PlotBinners.parse_str(args.bin_by),
+            scale_x=args.scale_x,
+            scale_y=args.scale_y,
+            dry_run=args.dry_run,
+        )
+
+    @classmethod
+    def add_cli_args(cls, parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
+        parser.add_argument(
+            "OUTPUT_DIR",
+            type=str,
+            default="results",
+            help="The directory containing the results to plot, "
+            "i.e., the `--output-dir` argument to the parameter sweep script.",
+        )
+        parser.add_argument(
+            "--fig-dir",
+            type=str,
+            default="",
+            help="The directory to save the figures, relative to `OUTPUT_DIR`. "
+            "By default, the same directory is used.",
+        )
+        parser.add_argument(
+            "--fig-by",
+            type=str,
+            default="",
+            help="A comma-separated list of variables, such that a separate figure "
+            "is created for each combination of these variables.",
+        )
+        parser.add_argument(
+            "--row-by",
+            type=str,
+            default="",
+            help="A comma-separated list of variables, such that a separate row "
+            "is created for each combination of these variables.",
+        )
+        parser.add_argument(
+            "--col-by",
+            type=str,
+            default="",
+            help="A comma-separated list of variables, such that a separate column "
+            "is created for each combination of these variables.",
+        )
+        parser.add_argument(
+            "--curve-by",
+            type=str,
+            default=None,
+            help="A comma-separated list of variables, such that a separate curve "
+            "is created for each combination of these variables.",
+        )
+        parser.add_argument(
+            "--var-x",
+            type=str,
+            default="request_throughput",
+            help="The variable for the x-axis.",
+        )
+        parser.add_argument(
+            "--var-y",
+            type=str,
+            default="p99_e2el_ms",
+            help="The variable for the y-axis",
+        )
+        parser.add_argument(
+            "--filter-by",
+            type=str,
+            default="",
+            help="A comma-separated list of statements indicating values to filter by. "
+            "This is useful to remove outliers. "
+            "Example: `max_concurrency<1000,max_num_batched_tokens<=4096` means "
+            "plot only the points where `max_concurrency` is less than 1000 and "
+            "`max_num_batched_tokens` is no greater than 4096.",
+        )
+        parser.add_argument(
+            "--bin-by",
+            type=str,
+            default="",
+            help="A comma-separated list of statements indicating values to bin by. "
+            "This is useful to avoid plotting points that are too close together. "
+            "Example: `request_throughput%%1` means "
+            "use a bin size of 1 for the `request_throughput` variable.",
+        )
+        parser.add_argument(
+            "--scale-x",
+            type=str,
+            default=None,
+            help="The scale to use for the x-axis. "
+            "Currently only accepts string values such as 'log' and 'sqrt'. "
+            "See also: https://seaborn.pydata.org/generated/seaborn.objects.Plot.scale.html",
+        )
+        parser.add_argument(
+            "--scale-y",
+            type=str,
+            default=None,
+            help="The scale to use for the y-axis. "
+            "Currently only accepts string values such as 'log' and 'sqrt'. "
+            "See also: https://seaborn.pydata.org/generated/seaborn.objects.Plot.scale.html",
+        )
+        parser.add_argument(
+            "--dry-run",
+            action="store_true",
+            help="If set, prints the information about each figure to plot, "
+            "then exits without drawing them.",
+        )
+
+        return parser


-def main(args: argparse.Namespace):
-    output_dir = Path(args.OUTPUT_DIR)
-    if not output_dir.exists():
-        raise ValueError(f"No parameter sweep results under {output_dir}")
-
-    curve_by = [] if not args.curve_by else args.curve_by.split(",")
-    row_by = [] if not args.row_by else args.row_by.split(",")
-    col_by = [] if not args.col_by else args.col_by.split(",")
-    fig_by = [] if not args.fig_by else args.fig_by.split(",")
-
-    plot(
-        output_dir=output_dir,
-        fig_dir=output_dir / args.fig_dir,
-        fig_by=fig_by,
-        row_by=row_by,
-        col_by=col_by,
-        curve_by=curve_by,
+def run_main(args: SweepPlotArgs):
+    return plot(
+        output_dir=args.output_dir,
+        fig_dir=args.fig_dir,
+        fig_by=args.fig_by,
+        row_by=args.row_by,
+        col_by=args.col_by,
+        curve_by=args.curve_by,
         var_x=args.var_x,
         var_y=args.var_y,
-        filter_by=PlotFilters.parse_str(args.filter_by),
-        bin_by=PlotBinners.parse_str(args.bin_by),
+        filter_by=args.filter_by,
+        bin_by=args.bin_by,
         scale_x=args.scale_x,
         scale_y=args.scale_y,
         dry_run=args.dry_run,
     )


+def main(args: argparse.Namespace):
+    run_main(SweepPlotArgs.from_cli_args(args))
+
+
 if __name__ == "__main__":
-    parser = argparse.ArgumentParser(
-        description="Plot performance curves from parameter sweep results."
-    )
-    add_cli_args(parser)
+    parser = argparse.ArgumentParser(description=SweepPlotArgs.parser_help)
+    SweepPlotArgs.add_cli_args(parser)

     main(parser.parse_args())
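With the CLI surface now carried by `SweepPlotArgs`, the plot entrypoint can also be driven without argparse. A hedged sketch (the grouping variables `random_input_len`, `random_output_len`, and `max_num_seqs` are placeholder column names from a sweep run, not values required by the tool):

```python
# Sketch of calling the sweep plotter programmatically via SweepPlotArgs/run_main.
from pathlib import Path

from vllm.benchmarks.sweep.plot import (
    PlotBinners,
    PlotFilters,
    SweepPlotArgs,
    run_main,
)

args = SweepPlotArgs(
    output_dir=Path("results"),           # directory written by a sweep run
    fig_dir=Path("results"),
    fig_by=[],
    row_by=["random_input_len"],
    col_by=["random_output_len"],
    curve_by=["max_num_seqs"],
    var_x="request_throughput",
    var_y="p99_e2el_ms",
    filter_by=PlotFilters.parse_str(""),  # no filtering
    bin_by=PlotBinners.parse_str(""),     # no binning
    scale_x=None,
    scale_y=None,
    dry_run=True,                         # only print what would be plotted
)
run_main(args)
```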
@@ -7,13 +7,19 @@ import shlex
 from dataclasses import dataclass
 from datetime import datetime
 from pathlib import Path
+from typing import ClassVar

-import pandas as pd
+from vllm.utils.import_utils import PlaceholderModule

 from .param_sweep import ParameterSweep, ParameterSweepItem
 from .server import ServerProcess
 from .utils import sanitize_filename

+try:
+    import pandas as pd
+except ImportError:
+    pd = PlaceholderModule("pandas")
+

 @contextlib.contextmanager
 def run_server(
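The try/except import above follows a deferred-import pattern: the optional dependency only fails at the point of use rather than at module import. A generic sketch of the idea (this is not vLLM's `PlaceholderModule` implementation, just an illustration of the pattern):

```python
# Generic deferred-import placeholder: importing the module succeeds even
# when pandas is missing; the error is raised only when pandas is used.
class _MissingModule:
    def __init__(self, name: str) -> None:
        self._name = name

    def __getattr__(self, attr: str):
        raise ImportError(
            f"{self._name} is required for this feature; "
            "install it with the 'bench' extra."
        )


try:
    import pandas as pd
except ImportError:
    pd = _MissingModule("pandas")
```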
@@ -257,6 +263,9 @@ class SweepServeArgs:
     dry_run: bool
     resume: str | None

+    parser_name: ClassVar[str] = "serve"
+    parser_help: ClassVar[str] = "Run vLLM server benchmark under multiple settings."
+
     @classmethod
     def from_cli_args(cls, args: argparse.Namespace):
         serve_cmd = shlex.split(args.serve_cmd)
@@ -401,9 +410,7 @@ def main(args: argparse.Namespace):


 if __name__ == "__main__":
-    parser = argparse.ArgumentParser(
-        description="Run vLLM server benchmark under multiple settings."
-    )
+    parser = argparse.ArgumentParser(description=SweepServeArgs.parser_help)
     SweepServeArgs.add_cli_args(parser)

     main(parser.parse_args())
@@ -7,17 +7,23 @@ import math
 from dataclasses import asdict, dataclass
 from datetime import datetime
 from pathlib import Path
-from typing import Literal, get_args
+from typing import ClassVar, Literal, get_args

-import pandas as pd
 from typing_extensions import assert_never

+from vllm.utils.import_utils import PlaceholderModule
+
 from .param_sweep import ParameterSweep, ParameterSweepItem
 from .serve import SweepServeArgs, run_benchmark, run_server
 from .server import ServerProcess
 from .sla_sweep import SLASweep, SLASweepItem
 from .utils import sanitize_filename

+try:
+    import pandas as pd
+except ImportError:
+    pd = PlaceholderModule("pandas")
+

 def _get_sla_base_path(
     output_dir: Path,
@@ -399,6 +405,9 @@ class SweepServeSLAArgs(SweepServeArgs):
     sla_params: SLASweep
     sla_variable: SLAVariable

+    parser_name: ClassVar[str] = "serve_sla"
+    parser_help: ClassVar[str] = "Tune a variable to meet SLAs under multiple settings."
+
     @classmethod
     def from_cli_args(cls, args: argparse.Namespace):
         # NOTE: Don't use super() as `from_cli_args` calls `cls()`
@@ -419,7 +428,8 @@ class SweepServeSLAArgs(SweepServeArgs):
     def add_cli_args(cls, parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
         parser = super().add_cli_args(parser)

-        parser.add_argument(
+        sla_group = parser.add_argument_group("sla options")
+        sla_group.add_argument(
             "--sla-params",
             type=str,
             required=True,
@@ -431,7 +441,7 @@ class SweepServeSLAArgs(SweepServeArgs):
             "the maximum `sla_variable` that satisfies the constraints for "
             "each combination of `serve_params`, `bench_params`, and `sla_params`.",
         )
-        parser.add_argument(
+        sla_group.add_argument(
             "--sla-variable",
             type=str,
             choices=get_args(SLAVariable),
@@ -476,9 +486,7 @@ def main(args: argparse.Namespace):


 if __name__ == "__main__":
-    parser = argparse.ArgumentParser(
-        description="Tune a variable to meet SLAs under multiple settings."
-    )
+    parser = argparse.ArgumentParser(description=SweepServeSLAArgs.parser_help)
     SweepServeSLAArgs.add_cli_args(parser)

     main(parser.parse_args())
@@ -2,10 +2,12 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 from vllm.entrypoints.cli.benchmark.latency import BenchmarkLatencySubcommand
 from vllm.entrypoints.cli.benchmark.serve import BenchmarkServingSubcommand
+from vllm.entrypoints.cli.benchmark.sweep import BenchmarkSweepSubcommand
 from vllm.entrypoints.cli.benchmark.throughput import BenchmarkThroughputSubcommand

 __all__: list[str] = [
     "BenchmarkLatencySubcommand",
     "BenchmarkServingSubcommand",
+    "BenchmarkSweepSubcommand",
     "BenchmarkThroughputSubcommand",
 ]
@@ -6,7 +6,7 @@ from vllm.entrypoints.cli.types import CLISubcommand


 class BenchmarkSubcommandBase(CLISubcommand):
-    """The base class of subcommands for vllm bench."""
+    """The base class of subcommands for `vllm bench`."""

     help: str

@@ -7,7 +7,7 @@ from vllm.entrypoints.cli.benchmark.base import BenchmarkSubcommandBase


 class BenchmarkLatencySubcommand(BenchmarkSubcommandBase):
-    """The `latency` subcommand for vllm bench."""
+    """The `latency` subcommand for `vllm bench`."""

     name = "latency"
     help = "Benchmark the latency of a single batch of requests."

@@ -7,7 +7,7 @@ from vllm.entrypoints.cli.benchmark.base import BenchmarkSubcommandBase


 class BenchmarkServingSubcommand(BenchmarkSubcommandBase):
-    """The `serve` subcommand for vllm bench."""
+    """The `serve` subcommand for `vllm bench`."""

     name = "serve"
     help = "Benchmark the online serving throughput."
vllm/entrypoints/cli/benchmark/sweep.py (new file)
@@ -0,0 +1,21 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+import argparse
+
+from vllm.benchmarks.sweep.cli import add_cli_args, main
+from vllm.entrypoints.cli.benchmark.base import BenchmarkSubcommandBase
+
+
+class BenchmarkSweepSubcommand(BenchmarkSubcommandBase):
+    """The `sweep` subcommand for `vllm bench`."""
+
+    name = "sweep"
+    help = "Benchmark for a parameter sweep."
+
+    @classmethod
+    def add_cli_args(cls, parser: argparse.ArgumentParser) -> None:
+        add_cli_args(parser)
+
+    @staticmethod
+    def cmd(args: argparse.Namespace) -> None:
+        main(args)
@@ -7,7 +7,7 @@ from vllm.entrypoints.cli.benchmark.base import BenchmarkSubcommandBase


 class BenchmarkThroughputSubcommand(BenchmarkSubcommandBase):
-    """The `throughput` subcommand for vllm bench."""
+    """The `throughput` subcommand for `vllm bench`."""

     name = "throughput"
     help = "Benchmark offline inference throughput."
@@ -7,7 +7,6 @@ from collections.abc import Callable
 from dataclasses import asdict, dataclass, field
 from typing import Any, Optional, TypeAlias

-import pandas as pd
 from torch._C._autograd import DeviceType, _KinetoEvent, _ProfilerResult
 from torch._C._profiler import _EventType, _ExperimentalConfig, _ProfilerEvent
 from torch.autograd.profiler import FunctionEvent
@@ -21,6 +20,12 @@ from vllm.profiler.utils import (
     event_torch_op_stack_trace,
     indent_string,
 )
+from vllm.utils.import_utils import PlaceholderModule
+
+try:
+    import pandas as pd
+except ImportError:
+    pd = PlaceholderModule("pandas")


 @dataclass