Mirror of https://github.com/vllm-project/vllm.git
[Core] Add xxHash as a high-performance hash option for accelerating prefix caching (#29163)

Signed-off-by: LuminolT <lumischen01@gmail.com>
Signed-off-by: Lumis Chen <lumischen01@gmail.com>
Co-authored-by: Russell Bryant <rbryant@redhat.com>
benchmarks/benchmark_hash.py (new file, 120 lines)
@@ -0,0 +1,120 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Micro benchmark comparing built-in hash(), SHA-256, and xxHash.

This focuses on a single test payload shaped like the prefix-cache hash input:
(32-byte bytes object, 32-int tuple)

Usage:
    python benchmarks/benchmark_hash.py --iterations 20000
"""

from __future__ import annotations

import argparse
import random
import statistics
import time
from collections.abc import Callable, Iterable

from vllm.utils.hashing import sha256, xxhash


def _generate_test_data(seed: int) -> tuple[bytes, tuple[int, ...]]:
    """Generate a deterministic test payload."""
    random.seed(seed)
    bytes_data = bytes(random.getrandbits(8) for _ in range(32))
    int_tuple = tuple(random.randint(1, 1_000_000) for _ in range(32))
    return (bytes_data, int_tuple)


def _benchmark_func(
    func: Callable[[tuple], object], data: tuple, iterations: int
) -> tuple[float, float]:
    """Return (avg_seconds, std_seconds) for hashing `data` `iterations` times."""
    times: list[float] = []

    # Warm-up to avoid first-run noise.
    for _ in range(200):
        func(data)

    for _ in range(iterations):
        start = time.perf_counter()
        func(data)
        end = time.perf_counter()
        times.append(end - start)

    avg = statistics.mean(times)
    std = statistics.stdev(times) if len(times) > 1 else 0.0
    return avg, std


def _run_benchmarks(
    benchmarks: Iterable[tuple[str, Callable[[tuple], object]]],
    data: tuple,
    iterations: int,
):
    """Yield (name, avg, std) for each benchmark, skipping unavailable ones."""
    for name, func in benchmarks:
        try:
            avg, std = _benchmark_func(func, data, iterations)
        except ModuleNotFoundError as exc:
            print(f"Skipping {name}: {exc}")
            continue
        yield name, avg, std


def builtin_hash(data: tuple) -> int:
    """Wrapper for Python's built-in hash()."""
    return hash(data)


def main() -> None:
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument(
        "--iterations",
        type=int,
        default=10_000,
        help="Number of measured iterations per hash function.",
    )
    parser.add_argument(
        "--seed", type=int, default=42, help="Random seed for test payload."
    )
    args = parser.parse_args()

    data = _generate_test_data(args.seed)
    benchmarks = (
        ("SHA256 (pickle)", sha256),
        ("xxHash (pickle)", xxhash),
        ("built-in hash()", builtin_hash),
    )

    print("=" * 60)
    print("HASH FUNCTION MICRO BENCHMARK")
    print("=" * 60)
    print("Test data: (32-byte bytes object, 32-int tuple)")
    print(f"Iterations: {args.iterations:,}")
    print("=" * 60)

    results = list(_run_benchmarks(benchmarks, data, args.iterations))
    builtin_entry = next((r for r in results if r[0] == "built-in hash()"), None)

    print("\nResults:")
    for name, avg, std in results:
        print(f"  {name:16s}: {avg * 1e6:8.2f} ± {std * 1e6:6.2f} μs")

    if builtin_entry:
        _, builtin_avg, _ = builtin_entry
        print("\n" + "=" * 60)
        print("SUMMARY (relative to built-in hash())")
        print("=" * 60)
        for name, avg, _ in results:
            if name == "built-in hash()":
                continue
            speed_ratio = avg / builtin_avg
            print(f"• {name} is {speed_ratio:.1f}x slower than built-in hash()")
    else:
        print("\nBuilt-in hash() result missing; cannot compute speed ratios.")


if __name__ == "__main__":
    main()
benchmarks/benchmark_prefix_block_hash.py (new file, 110 lines)
@@ -0,0 +1,110 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

"""
Simple benchmark to compare prefix-cache block hashing algorithms.

Example:
    python benchmarks/benchmark_prefix_block_hash.py --num-blocks 20000 --block-size 32
"""

from __future__ import annotations

import argparse
import random
import statistics
import sys
import time
from collections.abc import Callable, Iterable, Sequence

from vllm.utils.hashing import get_hash_fn_by_name
from vllm.v1.core.kv_cache_utils import BlockHash, hash_block_tokens, init_none_hash

SUPPORTED_ALGOS = ("sha256", "sha256_cbor", "xxhash", "xxhash_cbor")


def _generate_blocks(
    num_blocks: int, block_size: int, vocab_size: int, seed: int
) -> list[list[int]]:
    rng = random.Random(seed)
    return [
        [rng.randrange(vocab_size) for _ in range(block_size)]
        for _ in range(num_blocks)
    ]


def _hash_all_blocks(
    hash_fn: Callable[[object], bytes],
    blocks: Iterable[Sequence[int]],
) -> float:
    parent_hash: BlockHash | None = None
    start = time.perf_counter()
    for block in blocks:
        parent_hash = hash_block_tokens(hash_fn, parent_hash, block, extra_keys=None)
    end = time.perf_counter()
    return end - start


def _benchmark(
    hash_algo: str,
    blocks: list[list[int]],
    trials: int,
) -> tuple[float, float, float] | None:
    try:
        hash_fn = get_hash_fn_by_name(hash_algo)
        init_none_hash(hash_fn)
        timings = [_hash_all_blocks(hash_fn, blocks) for _ in range(trials)]
    except ModuleNotFoundError as exc:
        print(f"Skipping {hash_algo}: {exc}", file=sys.stderr)
        return None

    avg = statistics.mean(timings)
    best = min(timings)
    # Throughput: tokens hashed per second, based on the best trial.
    tokens_hashed = len(blocks) * len(blocks[0])
    throughput = tokens_hashed / best
    return avg, best, throughput


def main() -> None:
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument("--num-blocks", type=int, default=10000, help="Block count.")
    parser.add_argument("--block-size", type=int, default=32, help="Tokens per block.")
    parser.add_argument(
        "--vocab-size", type=int, default=32000, help="Token id range [0, vocab_size)."
    )
    parser.add_argument("--seed", type=int, default=0, help="Random seed.")
    parser.add_argument(
        "--trials", type=int, default=5, help="Number of timed trials per algorithm."
    )
    parser.add_argument(
        "--algorithms",
        nargs="+",
        default=SUPPORTED_ALGOS,
        choices=SUPPORTED_ALGOS,
        help="Hash algorithms to benchmark.",
    )
    args = parser.parse_args()

    blocks = _generate_blocks(
        args.num_blocks, args.block_size, args.vocab_size, args.seed
    )
    print(
        f"Benchmarking {len(args.algorithms)} algorithms on "
        f"{args.num_blocks} blocks (block size={args.block_size})."
    )

    for algo in args.algorithms:
        result = _benchmark(algo, blocks, args.trials)
        if result is None:
            continue

        avg, best, throughput = result
        print(
            f"{algo:14s} avg: {avg:.6f}s best: {best:.6f}s "
            f"throughput: {throughput / 1e6:.2f}M tokens/s"
        )


if __name__ == "__main__":
    main()
@@ -670,6 +670,35 @@ vllm bench serve \

</details>

### 🧪 Hashing Benchmarks

<details class="admonition abstract" markdown="1">
<summary>Show more</summary>

Two helper scripts live in `benchmarks/` to compare the hashing options used by prefix caching and related utilities. They are standalone (no server required) and help you choose a hash algorithm before enabling prefix caching in production.

- `benchmarks/benchmark_hash.py`: Micro-benchmark that measures per-call latency of three implementations on a representative `(bytes, tuple[int])` payload.

  ```bash
  python benchmarks/benchmark_hash.py --iterations 20000 --seed 42
  ```

- `benchmarks/benchmark_prefix_block_hash.py`: End-to-end block hashing benchmark that runs the full prefix-cache hash pipeline (`hash_block_tokens`) across many synthetic blocks and reports throughput.

  ```bash
  python benchmarks/benchmark_prefix_block_hash.py --num-blocks 20000 --block-size 32 --trials 5
  ```

Supported algorithms: `sha256`, `sha256_cbor`, `xxhash`, `xxhash_cbor`. Install the optional dependencies to exercise all variants:

```bash
uv pip install xxhash cbor2
```

If an algorithm's dependency is missing, the script skips it and continues.
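
Once you have picked an algorithm, the same setting can be applied when constructing the engine programmatically. A minimal sketch mirroring the CLI test added in this commit; the `enable_prefix_caching` and `prefix_caching_hash_algo` keyword names are assumed to match the corresponding CLI flags, and `facebook/opt-125m` is just a small placeholder model:

```python
from vllm.engine.arg_utils import EngineArgs

# Assumed EngineArgs fields; --prefix-caching-hash-algo maps onto
# prefix_caching_hash_algo in the resulting CacheConfig.
engine_args = EngineArgs(
    model="facebook/opt-125m",
    enable_prefix_caching=True,
    prefix_caching_hash_algo="xxhash",  # or "xxhash_cbor" / "sha256" / "sha256_cbor"
)
vllm_config = engine_args.create_engine_config()
assert vllm_config.cache_config.prefix_caching_hash_algo == "xxhash"
```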

</details>

### ⚡ Request Prioritization Benchmark

<details class="admonition abstract" markdown="1">
@@ -9,6 +9,7 @@ from vllm.config import VllmConfig
 from vllm.engine.arg_utils import EngineArgs
 from vllm.usage.usage_lib import UsageContext
 from vllm.utils.argparse_utils import FlexibleArgumentParser
+from vllm.utils.hashing import _xxhash


 def test_prefix_caching_from_cli():
@@ -48,6 +49,21 @@ def test_prefix_caching_from_cli():
     args = parser.parse_args(["--prefix-caching-hash-algo", "invalid"])


+@pytest.mark.skipif(_xxhash is None, reason="xxhash not installed")
+def test_prefix_caching_xxhash_from_cli():
+    parser = EngineArgs.add_cli_args(FlexibleArgumentParser())
+
+    # set hash algorithm to xxhash (pickle)
+    args = parser.parse_args(["--prefix-caching-hash-algo", "xxhash"])
+    vllm_config = EngineArgs.from_cli_args(args=args).create_engine_config()
+    assert vllm_config.cache_config.prefix_caching_hash_algo == "xxhash"
+
+    # set hash algorithm to xxhash_cbor
+    args = parser.parse_args(["--prefix-caching-hash-algo", "xxhash_cbor"])
+    vllm_config = EngineArgs.from_cli_args(args=args).create_engine_config()
+    assert vllm_config.cache_config.prefix_caching_hash_algo == "xxhash_cbor"
+
+
 def test_defaults_with_usage_context():
     engine_args = EngineArgs(model="facebook/opt-125m")
     vllm_config: VllmConfig = engine_args.create_engine_config(UsageContext.LLM_CLASS)
@@ -30,7 +30,7 @@ CacheDType = Literal[
     "fp8_ds_mla",
 ]
 MambaDType = Literal["auto", "float32"]
-PrefixCachingHashAlgo = Literal["sha256", "sha256_cbor"]
+PrefixCachingHashAlgo = Literal["sha256", "sha256_cbor", "xxhash", "xxhash_cbor"]
 KVOffloadingBackend = Literal["native", "lmcache"]
@@ -77,9 +77,21 @@ class CacheConfig:
     """Whether to enable prefix caching."""
     prefix_caching_hash_algo: PrefixCachingHashAlgo = "sha256"
     """Set the hash algorithm for prefix caching:\n
-    - "sha256" uses Pickle for object serialization before hashing.\n
+    - "sha256" uses Pickle for object serialization before hashing. This is
+      the current default, as SHA-256 is the most secure choice to avoid
+      potential hash collisions.\n
     - "sha256_cbor" provides a reproducible, cross-language compatible hash. It
-      serializes objects using canonical CBOR and hashes them with SHA-256."""
+      serializes objects using canonical CBOR and hashes them with SHA-256.\n
+    - "xxhash" uses Pickle serialization with xxHash (128-bit) for faster,
+      non-cryptographic hashing. Requires the optional ``xxhash`` package.
+      IMPORTANT: using a hash algorithm that is not cryptographically secure
+      theoretically increases the risk of hash collisions, which can cause
+      undefined behavior or even leak private information in multi-tenant
+      environments. Even though collisions remain very unlikely, weigh your
+      security risk tolerance against the performance benefits before
+      selecting this algorithm.\n
+    - "xxhash_cbor" combines canonical CBOR serialization with xxHash for
+      reproducible hashing. Requires the optional ``xxhash`` package."""
     cpu_offload_gb: float = Field(default=0, ge=0)
     """The space in GiB to offload to CPU, per GPU. Default is 0, which means
     no offloading. Intuitively, this argument can be seen as a virtual way to
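
For intuition on why the `*_cbor` variants are the reproducible ones: canonical CBOR always maps equal values to identical bytes, whereas pickle framing is an implementation detail of the running Python and can differ across versions and protocols. A hedged sketch (arbitrary payload; uses only `pickle` and the `cbor2` dependency this module already imports):

```python
import pickle

import cbor2

payload = (b"\x00" * 4, (1, 2, 3))

# Canonical CBOR: a stable byte stream for equal values, across runs and languages.
print(cbor2.dumps(payload, canonical=True).hex())

# Pickle: deterministic within one interpreter, but the encoding is
# Python-specific and may change across versions/protocols.
print(pickle.dumps(payload, protocol=pickle.HIGHEST_PROTOCOL).hex())
```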
@@ -11,6 +11,17 @@ from typing import Any

 import cbor2

+try:
+    # It is important that this remains an optional dependency.
+    # It would not be allowed in environments with strict security controls,
+    # so it's best not to have it installed when not in use.
+    import xxhash as _xxhash
+
+    if not hasattr(_xxhash, "xxh3_128_digest"):
+        _xxhash = None
+except ImportError:  # pragma: no cover
+    _xxhash = None


 def sha256(input: Any) -> bytes:
     """Hash any picklable Python object using SHA-256.
@@ -47,6 +58,27 @@ def sha256_cbor(input: Any) -> bytes:
     return hashlib.sha256(input_bytes).digest()


+def _xxhash_digest(input_bytes: bytes) -> bytes:
+    if _xxhash is None:
+        raise ModuleNotFoundError(
+            "xxhash is required for the 'xxhash' prefix caching hash algorithms. "
+            "Install it via `pip install xxhash`."
+        )
+    return _xxhash.xxh3_128_digest(input_bytes)
+
+
+def xxhash(input: Any) -> bytes:
+    """Hash picklable objects using xxHash."""
+    input_bytes = pickle.dumps(input, protocol=pickle.HIGHEST_PROTOCOL)
+    return _xxhash_digest(input_bytes)
+
+
+def xxhash_cbor(input: Any) -> bytes:
+    """Hash objects serialized with CBOR using xxHash."""
+    input_bytes = cbor2.dumps(input, canonical=True)
+    return _xxhash_digest(input_bytes)
+
+
 def get_hash_fn_by_name(hash_fn_name: str) -> Callable[[Any], bytes]:
     """Get a hash function by name, or raise an error if the function is not found.
@@ -60,6 +92,10 @@ def get_hash_fn_by_name(hash_fn_name: str) -> Callable[[Any], bytes]:
         return sha256
     if hash_fn_name == "sha256_cbor":
         return sha256_cbor
+    if hash_fn_name == "xxhash":
+        return xxhash
+    if hash_fn_name == "xxhash_cbor":
+        return xxhash_cbor

     raise ValueError(f"Unsupported hash function: {hash_fn_name}")
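
The dispatch above makes algorithm selection purely name-driven. A small sketch exercising all four registered names with the payload shape used by the micro-benchmark (assumes a checkout containing this change; without the optional `xxhash` package the xx* variants raise `ModuleNotFoundError`, which we catch):

```python
from vllm.utils.hashing import get_hash_fn_by_name

payload = (b"\x00" * 32, tuple(range(32)))  # same shape as the micro-benchmark

for name in ("sha256", "sha256_cbor", "xxhash", "xxhash_cbor"):
    try:
        digest = get_hash_fn_by_name(name)(payload)
    except ModuleNotFoundError as exc:  # raised by _xxhash_digest when absent
        print(f"skipping {name}: {exc}")
        continue
    # SHA-256 digests are 32 bytes; xxh3_128 digests are 16 bytes.
    print(f"{name:12s} {len(digest):2d} bytes  {digest.hex()[:16]}...")
```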
@@ -12,7 +12,7 @@ from typing import Any, NewType, TypeAlias, overload
 from vllm import envs
 from vllm.config import VllmConfig
 from vllm.logger import init_logger
-from vllm.utils.hashing import sha256_cbor
+from vllm.utils.hashing import sha256_cbor, xxhash_cbor
 from vllm.utils.math_utils import cdiv
 from vllm.utils.mem_constants import GiB_bytes
 from vllm.v1.kv_cache_interface import (
@@ -83,18 +83,19 @@ logger = init_logger(__name__)
 #
 # The function `init_none_hash` initializes this variable globally.
 NONE_HASH: BlockHash
+_CBOR_HASH_FUNCTIONS = frozenset({sha256_cbor, xxhash_cbor})


 def init_none_hash(hash_fn: Callable[[Any], bytes]):
     global NONE_HASH

     hash_seed = os.getenv("PYTHONHASHSEED")
-    if hash_seed is None and hash_fn is sha256_cbor:
+    if hash_seed is None and hash_fn in _CBOR_HASH_FUNCTIONS:
         logger.warning(
             "PYTHONHASHSEED is not set. This will lead to non-reproducible "
-            "block-hashes when using sha256_cbor as the hash function."
-            "Consider setting PYTHONHASHSEED to a fixed value for "
-            "reproducibility."
+            "block-hashes when using CBOR-based hash functions such as "
+            "sha256_cbor or xxhash_cbor. Consider setting PYTHONHASHSEED to a "
+            "fixed value for reproducibility."
         )

     if hash_seed is None:
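
To see the widened warning path and the chained block hashing in context, here is a brief sketch of the same calls the block benchmark above makes (assumes the optional `xxhash` package is installed; token ids are arbitrary):

```python
from vllm.utils.hashing import get_hash_fn_by_name
from vllm.v1.core.kv_cache_utils import hash_block_tokens, init_none_hash

hash_fn = get_hash_fn_by_name("xxhash_cbor")
init_none_hash(hash_fn)  # warns if PYTHONHASHSEED is unset, per the change above

# Block hashes chain through their parent, so equal prefixes yield equal hashes.
first = hash_block_tokens(hash_fn, None, [1, 2, 3, 4], extra_keys=None)
second = hash_block_tokens(hash_fn, first, [5, 6, 7, 8], extra_keys=None)
sibling = hash_block_tokens(hash_fn, first, [9, 9, 9, 9], extra_keys=None)
assert second != sibling  # same parent, different tokens -> different hashes
```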