Mirror of https://github.com/vllm-project/vllm.git (synced 2025-12-06 06:53:12 +08:00)
[Feat] Drop-in Torch CUDA Profiler (#27841)
Signed-off-by: Benjamin Chislett <bchislett@nvidia.com>
Committed by GitHub
parent 77d702a22b
commit 975676d174
@@ -39,7 +39,7 @@ Refer to [examples/offline_inference/simple_profiling.py](../../examples/offline
 
 ```bash
 VLLM_TORCH_PROFILER_DIR=./vllm_profile \
-vllm serve meta-llama/Meta-Llama-3-70B
+vllm serve meta-llama/Llama-3.1-8B-Instruct
 ```
 
 vllm bench command:
@@ -47,7 +47,7 @@ vllm bench command:
 ```bash
 vllm bench serve \
     --backend vllm \
-    --model meta-llama/Meta-Llama-3-70B \
+    --model meta-llama/Llama-3.1-8B-Instruct \
     --dataset-name sharegpt \
     --dataset-path sharegpt.json \
     --profile \
@@ -70,18 +70,21 @@ apt update
 apt install nsight-systems-cli
 ```
 
-### Example commands and usage
 !!! tip
-    When profiling with `nsys`, it is advisable to set the environment variable `VLLM_WORKER_MULTIPROC_METHOD=spawn`. The default is to use the `fork` method instead of `spawn`. More information on the topic can be found in the [Nsight Systems release notes](https://docs.nvidia.com/nsight-systems/ReleaseNotes/index.html#general-issues).
+    When profiling with `nsys`, it is advisable to set the environment variable `VLLM_WORKER_MULTIPROC_METHOD=spawn`. The default is to use the `fork` method instead of `spawn`. More information on the topic can be found in the [Nsight Systems release notes](https://docs.nvidia.com/nsight-systems/ReleaseNotes/index.html#general-issues).
+
+    The Nsight Systems profiler can be launched with `nsys profile ...`, with a few recommended flags for vLLM: `--trace-fork-before-exec=true --cuda-graph-trace=node`.
+
+### Example commands and usage
 
 #### Offline Inference
 
-For basic usage, you can just append `nsys profile -o report.nsys-rep --trace-fork-before-exec=true --cuda-graph-trace=node` before any existing script you would run for offline inference.
+For basic usage, you can just append the profiling command before any existing script you would run for offline inference.
 
 The following is an example using the `vllm bench latency` script:
 
 ```bash
-nsys profile -o report.nsys-rep \
+nsys profile \
     --trace-fork-before-exec=true \
     --cuda-graph-trace=node \
     vllm bench latency \
@@ -95,40 +98,29 @@ vllm bench latency \
 
 #### OpenAI Server
 
-To profile the server, you will want to prepend your `vllm serve` command with `nsys profile` just like for offline inference, however you must specify `--delay XX --duration YY` parameters according to the needs of your benchmark. After the duration time has been used up, the server will be killed.
+To profile the server, you will want to prepend your `vllm serve` command with `nsys profile` just like for offline inference, but you will need to specify a few other arguments to enable dynamic capture similarly to the Torch Profiler:
 
 ```bash
 # server
-nsys profile -o report.nsys-rep \
+VLLM_TORCH_CUDA_PROFILE=1 \
+nsys profile \
     --trace-fork-before-exec=true \
     --cuda-graph-trace=node \
-    --delay 30 \
-    --duration 60 \
+    --capture-range=cudaProfilerApi \
+    --capture-range-end repeat \
     vllm serve meta-llama/Llama-3.1-8B-Instruct
 
 # client
 vllm bench serve \
     --backend vllm \
     --model meta-llama/Llama-3.1-8B-Instruct \
-    --num-prompts 1 \
-    --dataset-name random \
-    --random-input 1024 \
-    --random-output 512
+    --dataset-name sharegpt \
+    --dataset-path sharegpt.json \
+    --profile \
+    --num-prompts 2
 ```
 
-In practice, you should set the `--duration` argument to a large value. Whenever you want the server to stop profiling, run:
-
-```bash
-nsys sessions list
-```
-
-to get the session id in the form of `profile-XXXXX`, then run:
-
-```bash
-nsys stop --session=profile-XXXXX
-```
-
-to manually kill the profiler and generate your `nsys-rep` report.
+With `--profile`, vLLM will capture a profile for each run of `vllm bench serve`. Once the server is killed, the profiles will all be saved.
 
 #### Analysis
 
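Under the hood, the `--profile` flag drives the server's profiler over HTTP. As a minimal manual sketch (an assumption-laden illustration, not part of this commit: it presumes the server listens on `localhost:8000` and exposes a `/stop_profile` route matching the `/start_profile` route registered in the API-server change below), the same capture can be toggled with plain POST requests:

```python
# Sketch: manually toggling the server-side profiler instead of passing
# --profile to `vllm bench serve`. Assumes the server runs on localhost:8000
# and exposes /start_profile and /stop_profile (hypothetical setup).
import urllib.request


def post(path: str) -> None:
    # The profiler endpoints take no request body; an empty POST is enough.
    req = urllib.request.Request(
        f"http://localhost:8000{path}", data=b"", method="POST"
    )
    with urllib.request.urlopen(req) as resp:
        print(path, resp.status)


post("/start_profile")  # worker enters the cudaProfilerApi capture range
# ... send the requests you want captured, e.g. with an OpenAI client ...
post("/stop_profile")   # worker leaves the capture range
```

With `--capture-range-end repeat`, each such start/stop pair should show up as its own capture in the resulting `nsys` report.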
@@ -1280,10 +1280,16 @@ async def invocations(raw_request: Request):
 
 
     if envs.VLLM_TORCH_PROFILER_DIR:
-        logger.warning(
+        logger.warning_once(
             "Torch Profiler is enabled in the API server. This should ONLY be "
             "used for local development!"
         )
+    elif envs.VLLM_TORCH_CUDA_PROFILE:
+        logger.warning_once(
+            "CUDA Profiler is enabled in the API server. This should ONLY be "
+            "used for local development!"
+        )
+
+    if envs.VLLM_TORCH_PROFILER_DIR or envs.VLLM_TORCH_CUDA_PROFILE:
 
         @router.post("/start_profile")
         async def start_profile(raw_request: Request):
@@ -87,6 +87,7 @@ if TYPE_CHECKING:
     VLLM_HTTP_TIMEOUT_KEEP_ALIVE: int = 5  # seconds
     VLLM_PLUGINS: list[str] | None = None
     VLLM_LORA_RESOLVER_CACHE_DIR: str | None = None
+    VLLM_TORCH_CUDA_PROFILE: bool = False
     VLLM_TORCH_PROFILER_DIR: str | None = None
     VLLM_TORCH_PROFILER_RECORD_SHAPES: bool = False
     VLLM_TORCH_PROFILER_WITH_PROFILE_MEMORY: bool = False
@@ -815,6 +816,11 @@ environment_variables: dict[str, Callable[[], Any]] = {
     "VLLM_LORA_RESOLVER_CACHE_DIR": lambda: os.getenv(
         "VLLM_LORA_RESOLVER_CACHE_DIR", None
     ),
+    # Enables torch CUDA profiling if set.
+    # On NVIDIA GPUs, this will start/stop cudaProfilerApi when triggered.
+    "VLLM_TORCH_CUDA_PROFILE": lambda: bool(
+        os.getenv("VLLM_TORCH_CUDA_PROFILE", "0") != "0"
+    ),
     # Enables torch profiler if set.
     # Both AsyncLLM's CPU traces as well as workers'
     # traces (CPU & GPU) will be saved under this directory.
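The parsing rule above treats any value other than an unset variable or the literal string `"0"` as enabling the flag. A quick standalone check (illustrative only, not part of the commit):

```python
# Illustration of the VLLM_TORCH_CUDA_PROFILE parsing rule: the flag is
# considered enabled for any value other than unset or "0".
import os

for value in (None, "0", "1", "true"):
    if value is None:
        os.environ.pop("VLLM_TORCH_CUDA_PROFILE", None)
    else:
        os.environ["VLLM_TORCH_CUDA_PROFILE"] = value
    enabled = os.getenv("VLLM_TORCH_CUDA_PROFILE", "0") != "0"
    print(value, "->", enabled)  # None -> False, "0" -> False, "1"/"true" -> True
```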
vllm/profiler/gpu_profiler.py (new file, 37 lines)
@@ -0,0 +1,37 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+from vllm.logger import init_logger
+
+logger = init_logger(__name__)
+
+
+class CudaProfilerWrapper:
+    def __init__(self) -> None:
+        self._profiler_running = False
+        # Note: lazy import to avoid dependency issues if CUDA is not available.
+        import torch.cuda.profiler as cuda_profiler
+
+        self._cuda_profiler = cuda_profiler
+
+    def start(self) -> None:
+        try:
+            self._cuda_profiler.start()
+            self._profiler_running = True
+            logger.info_once("Started CUDA profiler")
+        except Exception as e:
+            logger.warning_once("Failed to start CUDA profiler: %s", e)
+
+    def stop(self) -> None:
+        if self._profiler_running:
+            try:
+                self._cuda_profiler.stop()
+                logger.info_once("Stopped CUDA profiler")
+            except Exception as e:
+                logger.warning_once("Failed to stop CUDA profiler: %s", e)
+            finally:
+                self._profiler_running = False
+
+    def shutdown(self) -> None:
+        """Ensure profiler is stopped when shutting down."""
+        self.stop()
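For context, a minimal usage sketch of the wrapper above (a hypothetical standalone driver, not part of this commit). When the process is launched under `nsys profile --capture-range=cudaProfilerApi ...`, the `start()`/`stop()` calls delimit the captured region:

```python
# Hypothetical standalone driver for CudaProfilerWrapper (illustration only).
# Run under `nsys profile --capture-range=cudaProfilerApi ...` so that the
# start()/stop() calls below bound the captured region.
from vllm.profiler.gpu_profiler import CudaProfilerWrapper

profiler = CudaProfilerWrapper()
profiler.start()     # calls torch.cuda.profiler.start(); logs "Started CUDA profiler"
# ... run the GPU workload to be captured ...
profiler.stop()      # calls torch.cuda.profiler.stop(); a no-op if start() failed
profiler.shutdown()  # idempotent cleanup: stop() does nothing once the flag is cleared
```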
@@ -35,6 +35,7 @@ from vllm.model_executor import set_random_seed
 from vllm.model_executor.models.interfaces import is_mixture_of_experts
 from vllm.model_executor.warmup.kernel_warmup import kernel_warmup
 from vllm.platforms import current_platform
+from vllm.profiler.gpu_profiler import CudaProfilerWrapper
 from vllm.sequence import IntermediateTensors
 from vllm.tasks import SupportedTask
 from vllm.utils.mem_constants import GiB_bytes
@@ -116,6 +117,8 @@ class Worker(WorkerBase):
                     torch_profiler_trace_dir, worker_name=worker_name, use_gzip=True
                 ),
             )
+        elif envs.VLLM_TORCH_CUDA_PROFILE:
+            self.profiler = CudaProfilerWrapper()
         else:
             self.profiler = None
 
@@ -593,7 +596,10 @@ class Worker(WorkerBase):
         else:
             self.profiler.stop()
             # only print profiler results on rank 0
-            if self.local_rank == 0:
+            if (
+                isinstance(self.profiler, torch.profiler.profile)
+                and self.local_rank == 0
+            ):
                 print(
                     self.profiler.key_averages().table(sort_by="self_cuda_time_total")
                 )