[Core] Add xxHash as a high-performance hash option for accelerating prefix caching (#29163)

Signed-off-by: LuminolT <lumischen01@gmail.com>
Signed-off-by: Lumis Chen <lumischen01@gmail.com>
Co-authored-by: Russell Bryant <rbryant@redhat.com>
Author: Lumis Chen
Date: 2025-12-04 00:06:57 +08:00
Committed by: GitHub
Parent: 5aa9b09040
Commit: 9bcf92295a
7 changed files with 332 additions and 8 deletions


@@ -0,0 +1,120 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Micro benchmark comparing built-in hash(), SHA-256, and xxHash.
This focuses on a single test payload shaped like the prefix-cache hash input:
(32-byte bytes object, 32-int tuple)
Usage:
python benchmarks/benchmark_hash.py --iterations 20000
"""
from __future__ import annotations
import argparse
import random
import statistics
import time
from collections.abc import Callable, Iterable
from vllm.utils.hashing import sha256, xxhash


def _generate_test_data(seed: int) -> tuple[bytes, tuple[int, ...]]:
"""Generate a deterministic test payload."""
random.seed(seed)
bytes_data = bytes(random.getrandbits(8) for _ in range(32))
int_tuple = tuple(random.randint(1, 1_000_000) for _ in range(32))
return (bytes_data, int_tuple)


def _benchmark_func(func: Callable[[tuple], object], data: tuple, iterations: int) -> tuple[float, float]:
"""Return (avg_seconds, std_seconds) for hashing `data` `iterations` times."""
times: list[float] = []
# Warm-up to avoid first-run noise.
for _ in range(200):
func(data)
for _ in range(iterations):
start = time.perf_counter()
func(data)
end = time.perf_counter()
times.append(end - start)
avg = statistics.mean(times)
std = statistics.stdev(times) if len(times) > 1 else 0.0
return avg, std


def _run_benchmarks(
benchmarks: Iterable[tuple[str, Callable[[tuple], object]]],
data: tuple,
iterations: int,
):
"""Yield (name, avg, std) for each benchmark, skipping unavailable ones."""
for name, func in benchmarks:
try:
avg, std = _benchmark_func(func, data, iterations)
except ModuleNotFoundError as exc:
print(f"Skipping {name}: {exc}")
continue
yield name, avg, std


def builtin_hash(data: tuple) -> int:
"""Wrapper for Python's built-in hash()."""
return hash(data)


def main() -> None:
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument(
"--iterations",
type=int,
default=10_000,
help="Number of measured iterations per hash function.",
)
parser.add_argument(
"--seed", type=int, default=42, help="Random seed for test payload."
)
args = parser.parse_args()
data = _generate_test_data(args.seed)
benchmarks = (
("SHA256 (pickle)", sha256),
("xxHash (pickle)", xxhash),
("built-in hash()", builtin_hash),
)
print("=" * 60)
print("HASH FUNCTION MICRO BENCHMARK")
print("=" * 60)
print("Test data: (32-byte bytes object, 32-int tuple)")
print(f"Iterations: {args.iterations:,}")
print("=" * 60)
results = list(_run_benchmarks(benchmarks, data, args.iterations))
builtin_entry = next((r for r in results if r[0] == "built-in hash()"), None)
print("\nResults:")
for name, avg, std in results:
print(f" {name:16s}: {avg * 1e6:8.2f} ± {std * 1e6:6.2f} μs")
if builtin_entry:
_, builtin_avg, _ = builtin_entry
print("\n" + "=" * 60)
print("SUMMARY (relative to built-in hash())")
print("=" * 60)
for name, avg, _ in results:
if name == "built-in hash()":
continue
speed_ratio = avg / builtin_avg
print(f"{name} is {speed_ratio:.1f}x slower than built-in hash()")
else:
print("\nBuilt-in hash() result missing; cannot compute speed ratios.")


if __name__ == "__main__":
main()


@@ -0,0 +1,110 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Simple benchmark to compare prefix-cache block hashing algorithms.
Example:
python benchmarks/benchmark_prefix_block_hash.py --num-blocks 20000 --block-size 32
"""
from __future__ import annotations
import argparse
import random
import statistics
import sys
import time
from collections.abc import Callable, Iterable, Sequence
from vllm.utils.hashing import get_hash_fn_by_name
from vllm.v1.core.kv_cache_utils import BlockHash, hash_block_tokens, init_none_hash
SUPPORTED_ALGOS = ("sha256", "sha256_cbor", "xxhash", "xxhash_cbor")


def _generate_blocks(
num_blocks: int, block_size: int, vocab_size: int, seed: int
) -> list[list[int]]:
rng = random.Random(seed)
return [
[rng.randrange(vocab_size) for _ in range(block_size)]
for _ in range(num_blocks)
]


def _hash_all_blocks(
hash_fn: Callable[[object], bytes],
blocks: Iterable[Sequence[int]],
) -> float:
parent_hash: BlockHash | None = None
start = time.perf_counter()
for block in blocks:
parent_hash = hash_block_tokens(hash_fn, parent_hash, block, extra_keys=None)
end = time.perf_counter()
return end - start


def _benchmark(
hash_algo: str,
blocks: list[list[int]],
trials: int,
) -> tuple[float, float, float] | None:
try:
hash_fn = get_hash_fn_by_name(hash_algo)
init_none_hash(hash_fn)
timings = [_hash_all_blocks(hash_fn, blocks) for _ in range(trials)]
except ModuleNotFoundError as exc:
print(f"Skipping {hash_algo}: {exc}", file=sys.stderr)
return None
avg = statistics.mean(timings)
best = min(timings)
# throughput: tokens / second
tokens_hashed = len(blocks) * len(blocks[0])
throughput = tokens_hashed / best
return avg, best, throughput


def main() -> None:
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument("--num-blocks", type=int, default=10000, help="Block count.")
parser.add_argument("--block-size", type=int, default=32, help="Tokens per block.")
parser.add_argument(
"--vocab-size", type=int, default=32000, help="Token id range [0, vocab_size)."
)
parser.add_argument("--seed", type=int, default=0, help="Random seed.")
parser.add_argument(
"--trials", type=int, default=5, help="Number of timed trials per algorithm."
)
parser.add_argument(
"--algorithms",
nargs="+",
default=SUPPORTED_ALGOS,
choices=SUPPORTED_ALGOS,
help="Hash algorithms to benchmark.",
)
args = parser.parse_args()
blocks = _generate_blocks(
args.num_blocks, args.block_size, args.vocab_size, args.seed
)
print(
f"Benchmarking {len(args.algorithms)} algorithms on "
f"{args.num_blocks} blocks (block size={args.block_size})."
)
for algo in args.algorithms:
result = _benchmark(algo, blocks, args.trials)
if result is None:
continue
avg, best, throughput = result
print(
f"{algo:14s} avg: {avg:.6f}s best: {best:.6f}s "
f"throughput: {throughput / 1e6:.2f}M tokens/s"
)


if __name__ == "__main__":
main()


@@ -670,6 +670,35 @@ vllm bench serve \
</details>
### 🧪 Hashing Benchmarks
<details class="admonition abstract" markdown="1">
<summary>Show more</summary>
Two helper scripts in `benchmarks/` compare the hashing options used by prefix caching and related utilities. They are standalone (no server required) and help you choose a hash algorithm before enabling prefix caching in production.
- `benchmarks/benchmark_hash.py`: Micro-benchmark that measures per-call latency of three implementations on a representative `(bytes, tuple[int])` payload.
```bash
python benchmarks/benchmark_hash.py --iterations 20000 --seed 42
```
- `benchmarks/benchmark_prefix_block_hash.py`: End-to-end block hashing benchmark that runs the full prefix-cache hash pipeline (`hash_block_tokens`) across many fake blocks and reports throughput.
```bash
python benchmarks/benchmark_prefix_block_hash.py --num-blocks 20000 --block-size 32 --trials 5
```
Supported algorithms: `sha256`, `sha256_cbor`, `xxhash`, `xxhash_cbor`. Install the optional dependencies to exercise all variants:
```bash
uv pip install xxhash cbor2
```
If an algorithm's dependency is missing, the script skips it and continues.
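
Once you have chosen an algorithm, you can select it at server startup with the `--prefix-caching-hash-algo` flag added in this PR. A minimal sketch (the model name is only a placeholder):

```bash
vllm serve meta-llama/Llama-3.1-8B-Instruct \
    --enable-prefix-caching \
    --prefix-caching-hash-algo xxhash
```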
</details>
### ⚡ Request Prioritization Benchmark
<details class="admonition abstract" markdown="1">


@@ -9,6 +9,7 @@ from vllm.config import VllmConfig
from vllm.engine.arg_utils import EngineArgs
from vllm.usage.usage_lib import UsageContext
from vllm.utils.argparse_utils import FlexibleArgumentParser
from vllm.utils.hashing import _xxhash


def test_prefix_caching_from_cli():
@@ -48,6 +49,21 @@ def test_prefix_caching_from_cli():
args = parser.parse_args(["--prefix-caching-hash-algo", "invalid"])


@pytest.mark.skipif(_xxhash is None, reason="xxhash not installed")
def test_prefix_caching_xxhash_from_cli():
parser = EngineArgs.add_cli_args(FlexibleArgumentParser())
# set hash algorithm to xxhash (pickle)
args = parser.parse_args(["--prefix-caching-hash-algo", "xxhash"])
vllm_config = EngineArgs.from_cli_args(args=args).create_engine_config()
assert vllm_config.cache_config.prefix_caching_hash_algo == "xxhash"
# set hash algorithm to xxhash_cbor
args = parser.parse_args(["--prefix-caching-hash-algo", "xxhash_cbor"])
vllm_config = EngineArgs.from_cli_args(args=args).create_engine_config()
assert vllm_config.cache_config.prefix_caching_hash_algo == "xxhash_cbor"


def test_defaults_with_usage_context():
engine_args = EngineArgs(model="facebook/opt-125m")
vllm_config: VllmConfig = engine_args.create_engine_config(UsageContext.LLM_CLASS)


@@ -30,7 +30,7 @@ CacheDType = Literal[
"fp8_ds_mla",
]
MambaDType = Literal["auto", "float32"]
PrefixCachingHashAlgo = Literal["sha256", "sha256_cbor"]
PrefixCachingHashAlgo = Literal["sha256", "sha256_cbor", "xxhash", "xxhash_cbor"]
KVOffloadingBackend = Literal["native", "lmcache"]
@@ -77,9 +77,21 @@ class CacheConfig:
"""Whether to enable prefix caching."""
prefix_caching_hash_algo: PrefixCachingHashAlgo = "sha256"
"""Set the hash algorithm for prefix caching:\n
- "sha256" uses Pickle for object serialization before hashing.\n
- "sha256" uses Pickle for object serialization before hashing. This is the
current default, as SHA256 is the most secure choice to avoid potential
hash collisions.\n
- "sha256_cbor" provides a reproducible, cross-language compatible hash. It
serializes objects using canonical CBOR and hashes them with SHA-256."""
serializes objects using canonical CBOR and hashes them with SHA-256.\n
- "xxhash" uses Pickle serialization with xxHash (128-bit) for faster,
non-cryptographic hashing. Requires the optional ``xxhash`` package.
IMPORTANT: Use of a hashing algorithm that is not considered
cryptographically secure theoretically increases the risk of hash collisions,
which can cause undefined behavior or even leak private information in
multi-tenant environments. Even if collisions are still very unlikely, it is
important to consider your security risk tolerance against the performance
benefits before turning this on.\n
- "xxhash_cbor" combines canonical CBOR serialization with xxHash for
reproducible hashing. Requires the optional ``xxhash`` package."""
cpu_offload_gb: float = Field(default=0, ge=0)
"""The space in GiB to offload to CPU, per GPU. Default is 0, which means
no offloading. Intuitively, this argument can be seen as a virtual way to
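
For illustration, the new values can also be set programmatically. A minimal sketch mirroring the CLI test above (it assumes `EngineArgs` accepts `prefix_caching_hash_algo` as a keyword, as the new CLI flag suggests):

```python
from vllm.engine.arg_utils import EngineArgs

# Sketch: request the xxHash-based algorithm; create_engine_config()
# propagates it into CacheConfig.prefix_caching_hash_algo.
engine_args = EngineArgs(
    model="facebook/opt-125m",  # small model, as in the tests above
    prefix_caching_hash_algo="xxhash",
)
vllm_config = engine_args.create_engine_config()
assert vllm_config.cache_config.prefix_caching_hash_algo == "xxhash"
```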


@@ -11,6 +11,17 @@ from typing import Any
import cbor2
try:
# It is important that this remains an optional dependency.
# It would not be allowed in environments with strict security controls,
# so it's best not to have it installed when not in use.
import xxhash as _xxhash
if not hasattr(_xxhash, "xxh3_128_digest"):
_xxhash = None
except ImportError: # pragma: no cover
_xxhash = None


def sha256(input: Any) -> bytes:
"""Hash any picklable Python object using SHA-256.
@@ -47,6 +58,27 @@ def sha256_cbor(input: Any) -> bytes:
return hashlib.sha256(input_bytes).digest()


def _xxhash_digest(input_bytes: bytes) -> bytes:
if _xxhash is None:
raise ModuleNotFoundError(
"xxhash is required for the 'xxhash' prefix caching hash algorithms. "
"Install it via `pip install xxhash`."
)
return _xxhash.xxh3_128_digest(input_bytes)


def xxhash(input: Any) -> bytes:
"""Hash picklable objects using xxHash."""
input_bytes = pickle.dumps(input, protocol=pickle.HIGHEST_PROTOCOL)
return _xxhash_digest(input_bytes)


def xxhash_cbor(input: Any) -> bytes:
"""Hash objects serialized with CBOR using xxHash."""
input_bytes = cbor2.dumps(input, canonical=True)
return _xxhash_digest(input_bytes)


def get_hash_fn_by_name(hash_fn_name: str) -> Callable[[Any], bytes]:
"""Get a hash function by name, or raise an error if the function is not found.
@@ -60,6 +92,10 @@ def get_hash_fn_by_name(hash_fn_name: str) -> Callable[[Any], bytes]:
return sha256
if hash_fn_name == "sha256_cbor":
return sha256_cbor
if hash_fn_name == "xxhash":
return xxhash
if hash_fn_name == "xxhash_cbor":
return xxhash_cbor
raise ValueError(f"Unsupported hash function: {hash_fn_name}")
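
Taken together, callers resolve a hash function by name and apply it to any picklable payload. A minimal usage sketch (payload shaped like the micro-benchmark input; the xxHash variants require the optional `xxhash` package):

```python
from vllm.utils.hashing import get_hash_fn_by_name

# Unknown names raise ValueError; the xxHash variants raise
# ModuleNotFoundError at call time if `xxhash` is not installed.
hash_fn = get_hash_fn_by_name("xxhash_cbor")

# xxh3_128 yields a 16-byte digest (the sha256 variants yield 32 bytes).
digest = hash_fn((b"\x00" * 32, tuple(range(32))))
assert isinstance(digest, bytes) and len(digest) == 16
```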


@@ -12,7 +12,7 @@ from typing import Any, NewType, TypeAlias, overload
from vllm import envs
from vllm.config import VllmConfig
from vllm.logger import init_logger
-from vllm.utils.hashing import sha256_cbor
+from vllm.utils.hashing import sha256_cbor, xxhash_cbor
from vllm.utils.math_utils import cdiv
from vllm.utils.mem_constants import GiB_bytes
from vllm.v1.kv_cache_interface import (
@@ -83,18 +83,19 @@ logger = init_logger(__name__)
#
# The function `init_none_hash` initializes this variable globally.
NONE_HASH: BlockHash
+_CBOR_HASH_FUNCTIONS = frozenset({sha256_cbor, xxhash_cbor})


def init_none_hash(hash_fn: Callable[[Any], bytes]):
global NONE_HASH
hash_seed = os.getenv("PYTHONHASHSEED")
-if hash_seed is None and hash_fn is sha256_cbor:
+if hash_seed is None and hash_fn in _CBOR_HASH_FUNCTIONS:
logger.warning(
"PYTHONHASHSEED is not set. This will lead to non-reproducible "
"block-hashes when using sha256_cbor as the hash function."
"Consider setting PYTHONHASHSEED to a fixed value for "
"reproducibility."
"block-hashes when using CBOR-based hash functions such as "
"sha256_cbor or xxhash_cbor. Consider setting PYTHONHASHSEED to a "
"fixed value for reproducibility."
)
if hash_seed is None:
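
For reproducible block hashes, the interpreter must therefore start with a fixed PYTHONHASHSEED. A minimal end-to-end sketch using the helpers touched by this PR (run it as `PYTHONHASHSEED=0 python <script>`; the token lists are arbitrary):

```python
from vllm.utils.hashing import get_hash_fn_by_name
from vllm.v1.core.kv_cache_utils import hash_block_tokens, init_none_hash

hash_fn = get_hash_fn_by_name("xxhash_cbor")
init_none_hash(hash_fn)  # warns if PYTHONHASHSEED is unset

# Block hashes chain: each block's hash folds in its parent's hash,
# so identical prefixes yield identical hash sequences.
first = hash_block_tokens(hash_fn, None, [1, 2, 3, 4], extra_keys=None)
second = hash_block_tokens(hash_fn, first, [5, 6, 7, 8], extra_keys=None)
```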