Mirror of https://github.com/vllm-project/vllm.git (synced 2025-12-06 15:04:47 +08:00)
[BugFix] Fix DBO assert assert B_block_table == B_q (#29933)
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
@@ -13,7 +13,7 @@ from vllm.v1.attention.backends.utils import (
     split_attn_metadata,
     split_decodes_and_prefills,
 )
-from vllm.v1.worker.ubatch_utils import create_ubatch_slices
+from vllm.v1.worker.ubatch_utils import maybe_create_ubatch_slices


 @pytest.fixture
@@ -294,8 +294,14 @@ def test_prefill_split_across_ubatches(
     qsl_np = common.query_start_loc_cpu.numpy()
     num_tokens = common.num_actual_tokens

-    ubatch_slices = create_ubatch_slices(num_scheduled_tokens, split_point)
-    assert len(ubatch_slices) == 2
+    ubatch_slices, _ = maybe_create_ubatch_slices(
+        True,
+        num_scheduled_tokens,
+        num_tokens,
+        batch_spec.batch_size,
+        split_point=split_point,
+    )
+    assert ubatch_slices is not None and len(ubatch_slices) == 2

     first_meta = _make_metadata_with_slice(ubatch_slices[0], common)
     second_meta = _make_metadata_with_slice(ubatch_slices[1], common)
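For context, here is a minimal, self-contained sketch of the property the updated test asserts: splitting a batch at a token split point yields two ubatch slices that together cover every scheduled token, with the straddling prefill request appearing in both. `ToyUBatchSlice` and `toy_split` are hypothetical stand-ins, not the real vLLM helpers.

```python
from dataclasses import dataclass

import numpy as np


@dataclass
class ToyUBatchSlice:
    """Stand-in for vLLM's UBatchSlice: a request range plus a token range."""

    request_slice: slice
    token_slice: slice

    @property
    def num_tokens(self) -> int:
        return self.token_slice.stop - self.token_slice.start


def toy_split(num_scheduled_tokens: np.ndarray, split_point: int) -> list[ToyUBatchSlice]:
    """Split a batch at `split_point` tokens into two ubatch slices (toy version)."""
    cu = np.concatenate(([0], np.cumsum(num_scheduled_tokens)))
    total = int(cu[-1])
    assert 0 < split_point < total
    # Index of the request whose tokens contain the split point; it shows up
    # in both ubatches when a prefill straddles the boundary.
    straddler = int(np.searchsorted(cu, split_point, side="right")) - 1
    return [
        ToyUBatchSlice(slice(0, straddler + 1), slice(0, split_point)),
        ToyUBatchSlice(slice(straddler, len(num_scheduled_tokens)), slice(split_point, total)),
    ]


slices = toy_split(np.array([6, 10], dtype=np.int32), split_point=9)
assert len(slices) == 2
assert sum(s.num_tokens for s in slices) == 16
```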
@@ -1258,7 +1258,7 @@ class EagleProposer:
         num_tokens_padded: int,
     ) -> tuple[int, torch.Tensor]:
         # TODO(Flechman): support DBO ubatching
-        ubatch_slices, num_toks_across_dp = coordinate_batch_across_dp(
+        should_ubatch, num_toks_across_dp = coordinate_batch_across_dp(
             num_tokens_unpadded=num_tokens_unpadded,
             parallel_config=self.vllm_config.parallel_config,
             allow_microbatching=False,
@@ -1267,7 +1267,7 @@ class EagleProposer:
             uniform_decode=None,
             num_scheduled_tokens_per_request=None,
         )
-        assert ubatch_slices is None, "DBO ubatching not implemented for EAGLE"
+        assert not should_ubatch, "DBO ubatching not implemented for EAGLE"

         num_tokens_dp_padded = num_tokens_padded
         if num_toks_across_dp is not None:
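The EAGLE path above is the simplest consumer of the new contract: it passes `allow_microbatching=False` and only needs the boolean plus the per-rank token counts. A hedged sketch of that consumer pattern, with `fake_coordinate_batch_across_dp` as an illustrative stub rather than vLLM's real function:

```python
import torch


def fake_coordinate_batch_across_dp(
    num_tokens_unpadded: int,
    dp_size: int,
    allow_microbatching: bool,
) -> tuple[bool, torch.Tensor | None]:
    """Toy stand-in returning (should_ubatch, num_tokens_across_dp)."""
    if dp_size == 1:
        # Mirrors the early exit in the real function: single-rank DP never coordinates.
        return False, None
    num_tokens_across_dp = torch.full((dp_size,), num_tokens_unpadded, dtype=torch.int32)
    if not allow_microbatching:
        # Callers that never microbatch still get the per-rank token counts
        # so DP padding can proceed.
        return False, num_tokens_across_dp
    # A real implementation would also check token thresholds and require
    # agreement across all ranks before returning True.
    return True, num_tokens_across_dp


should_ubatch, num_toks_across_dp = fake_coordinate_batch_across_dp(
    num_tokens_unpadded=33, dp_size=2, allow_microbatching=False
)
assert not should_ubatch, "DBO ubatching not implemented for EAGLE"
```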
@@ -1,6 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project

+import numpy as np
 import torch
 import torch.distributed as dist
@@ -9,10 +10,7 @@ from vllm.config import ParallelConfig
 from vllm.distributed.parallel_state import get_dp_group
 from vllm.logger import init_logger
 from vllm.v1.worker.ubatch_utils import (
-    UBatchSlice,
-    UBatchSlices,
     check_ubatch_thresholds,
-    create_ubatch_slices,
     is_second_ubatch_empty,
 )

@@ -91,20 +89,6 @@ def _post_process_dp_padding(tensor: torch.Tensor, should_dp_pad: bool) -> torch
     return num_tokens_across_dp.cpu()


-# This just pads the second ubatch slice out to the total number of tokens
-# (num_tokens + padding) since we do `create_ubatch_slices` before applying DP padding.
-def _pad_out_ubatch_slice(
-    ubatch_slices: UBatchSlices, num_total_tokens: int
-) -> UBatchSlices:
-    padded_second_token_slice = slice(
-        ubatch_slices[1].token_slice.start, num_total_tokens
-    )
-    ubatch_slices[1] = UBatchSlice(
-        ubatch_slices[1].request_slice, padded_second_token_slice
-    )
-    return ubatch_slices
-
-
 def _synchronize_dp_ranks(
     num_tokens_unpadded: int,
     num_tokens_padded: int,
@@ -175,7 +159,7 @@ def coordinate_batch_across_dp(
     num_tokens_padded: int | None = None,
     uniform_decode: bool | None = None,
     num_scheduled_tokens_per_request: np.ndarray | None = None,
-) -> tuple[UBatchSlices | None, torch.Tensor | None]:
+) -> tuple[bool, torch.Tensor | None]:
     """
     Coordinates amongst all DP ranks to determine if and how the full batch
     should be split into microbatches.
@@ -204,7 +188,7 @@ def coordinate_batch_across_dp(
     """
     if parallel_config.data_parallel_size == 1:
         # Early exit.
-        return None, None
+        return False, None

     # If the caller has explicitly enabled microbatching.
     should_attempt_ubatching = False
@@ -228,23 +212,4 @@ def coordinate_batch_across_dp(
         parallel_config,
     )

-    # Don't microbatch unless every other DP worker is also microbatching
-    if not should_ubatch:
-        return (None, num_tokens_after_padding)
-
-    # This doesn't actually pad the ubatch slices. It just initializes the
-    # split point to the padded value so that padding can be applied
-    # to the second ubatch in pad_out_ubatch_slice after attention
-    # metadata creation
-    assert num_tokens_after_padding is not None
-    num_tokens_padded = int(num_tokens_after_padding[0].item())
-    token_split_point = int(num_tokens_padded) // 2
-
-    assert num_scheduled_tokens_per_request is not None
-    ubatch_slices = create_ubatch_slices(
-        num_scheduled_tokens_per_request, token_split_point
-    )
-    ubatch_slices = _pad_out_ubatch_slice(ubatch_slices, num_tokens_padded)
-    assert sum(s.num_tokens for s in ubatch_slices) == num_tokens_padded
-
-    return (ubatch_slices, num_tokens_after_padding)
+    return (should_ubatch, num_tokens_after_padding)
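With this hunk, `coordinate_batch_across_dp` only decides whether every DP rank should microbatch; building and padding the ubatch slices moves to `maybe_create_ubatch_slices` in ubatch_utils. A rough sketch of that two-phase split, using simplified stand-in functions (`decide_should_ubatch`, `build_slices_if_needed`) rather than the actual vLLM signatures:

```python
import numpy as np


def decide_should_ubatch(num_tokens_per_rank: list[int], threshold: int) -> bool:
    """Phase 1 (toy): the cross-rank decision. Every rank must agree, so the
    result is a plain bool; no slices are constructed here."""
    return all(n >= threshold for n in num_tokens_per_rank)


def build_slices_if_needed(
    should_ubatch: bool, num_scheduled_tokens: np.ndarray, num_tokens_padded: int
):
    """Phase 2 (toy): slice construction now lives with the caller, which also
    knows the padded token count, so padded slices can be built in one place."""
    if not should_ubatch:
        return None, None
    split_point = num_tokens_padded // 2
    total = int(num_scheduled_tokens.sum())
    unpadded = [slice(0, split_point), slice(split_point, total)]
    padded = [slice(0, split_point), slice(split_point, num_tokens_padded)]
    return unpadded, padded


tokens_per_rank = [48, 64]
should_ubatch = decide_should_ubatch(tokens_per_rank, threshold=32)
ubatch_slices, ubatch_slices_padded = build_slices_if_needed(
    should_ubatch, np.array([30, 18], dtype=np.int32), num_tokens_padded=64
)
assert should_ubatch and ubatch_slices_padded[1].stop == 64
```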
@@ -153,6 +153,7 @@ from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin
 from vllm.v1.worker.ubatch_utils import (
     UBatchSlices,
     check_ubatch_thresholds,
+    maybe_create_ubatch_slices,
 )
 from vllm.v1.worker.utils import is_residual_scattered_for_sp

@@ -2743,7 +2744,7 @@ class GPUModelRunner(
     ) -> tuple[
         CUDAGraphMode,
         BatchDescriptor,
-        UBatchSlices | None,
+        bool,
         torch.Tensor | None,
         CUDAGraphStat | None,
     ]:
@@ -2779,7 +2780,7 @@ class GPUModelRunner(

         # Extra coordination when running data-parallel since we need to coordinate
         # across ranks
-        ubatch_slices, num_tokens_across_dp = None, None
+        should_ubatch, num_tokens_across_dp = False, None
         if self.vllm_config.parallel_config.data_parallel_size > 1:
             # Disable DP padding when running eager to avoid excessive padding when
             # running prefills. This lets us set cudagraph_mode="NONE" on the prefiller
@@ -2789,8 +2790,8 @@ class GPUModelRunner(
                 self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE
             )

-            ubatch_slices, num_tokens_across_dp = coordinate_batch_across_dp(
-                num_tokens_unpadded=num_tokens_padded,
+            should_ubatch, num_tokens_across_dp = coordinate_batch_across_dp(
+                num_tokens_unpadded=num_tokens,
                 parallel_config=self.parallel_config,
                 allow_microbatching=allow_microbatching,
                 allow_dp_padding=allow_dp_padding,
@@ -2822,7 +2823,7 @@ class GPUModelRunner(
         return (
             cudagraph_mode,
             batch_descriptor,
-            ubatch_slices,
+            should_ubatch,
             num_tokens_across_dp,
             cudagraph_stats,
         )
@@ -2921,7 +2922,7 @@ class GPUModelRunner(
         (
             cudagraph_mode,
             batch_desc,
-            ubatch_slices,
+            should_ubatch,
             num_tokens_across_dp,
             cudagraph_stats,
         ) = self._determine_batch_execution_and_padding(
@@ -2934,10 +2935,10 @@ class GPUModelRunner(

         logger.debug(
             "Running batch with cudagraph_mode: %s, batch_descriptor: %s, "
-            "ubatch_slices: %s, num_tokens_across_dp: %s",
+            "should_ubatch: %s, num_tokens_across_dp: %s",
             cudagraph_mode,
             batch_desc,
-            ubatch_slices,
+            should_ubatch,
             num_tokens_across_dp,
         )

@@ -2945,9 +2946,17 @@ class GPUModelRunner(
         num_reqs_padded = (
             batch_desc.num_reqs if batch_desc.num_reqs is not None else num_reqs
         )
+        ubatch_slices, ubatch_slices_padded = maybe_create_ubatch_slices(
+            should_ubatch,
+            num_scheduled_tokens_np,
+            num_tokens_padded,
+            num_reqs_padded,
+        )

-        pad_attn = cudagraph_mode == CUDAGraphMode.FULL

         use_spec_decode = len(scheduler_output.scheduled_spec_decode_tokens) > 0
+        pad_attn = cudagraph_mode == CUDAGraphMode.FULL
+        ubatch_slices_attn = ubatch_slices_padded if pad_attn else ubatch_slices

         (attn_metadata, spec_decode_common_attn_metadata) = (
             self._build_attention_metadata(
@@ -2956,7 +2965,7 @@ class GPUModelRunner(
                 num_reqs=num_reqs,
                 num_reqs_padded=num_reqs_padded if pad_attn else None,
                 max_query_len=max_num_scheduled_tokens,
-                ubatch_slices=ubatch_slices,
+                ubatch_slices=ubatch_slices_attn,
                 logits_indices=logits_indices,
                 use_spec_decode=use_spec_decode,
                 num_scheduled_tokens=scheduler_output.num_scheduled_tokens,
@@ -2993,7 +3002,7 @@ class GPUModelRunner(
                 num_tokens_across_dp=num_tokens_across_dp,
                 cudagraph_runtime_mode=cudagraph_mode,
                 batch_descriptor=batch_desc,
-                ubatch_slices=ubatch_slices,
+                ubatch_slices=ubatch_slices_padded,
             ),
             record_function_or_nullcontext("gpu_model_runner: forward"),
             self.maybe_get_kv_connector_output(scheduler_output) as kv_connector_output,
@@ -3945,7 +3954,7 @@ class GPUModelRunner(

         num_sampled_tokens = np.ones(num_reqs, dtype=np.int32)

-        _cudagraph_mode, batch_desc, ubatch_slices, num_tokens_across_dp, _ = (
+        _cudagraph_mode, batch_desc, should_ubatch, num_tokens_across_dp, _ = (
             self._determine_batch_execution_and_padding(
                 num_tokens=num_tokens_unpadded,
                 num_reqs=num_reqs,
@@ -3979,6 +3988,9 @@ class GPUModelRunner(
         num_reqs_padded = (
             batch_desc.num_reqs if batch_desc.num_reqs is not None else num_reqs
         )
+        ubatch_slices, ubatch_slices_padded = maybe_create_ubatch_slices(
+            should_ubatch, num_scheduled_tokens, num_tokens_padded, num_reqs_padded
+        )

         attn_metadata: PerLayerAttnMetadata | None = None

@@ -4000,11 +4012,12 @@ class GPUModelRunner(
             self.query_start_loc.np[1 : num_reqs + 1] = cum_num_tokens
             self.query_start_loc.copy_to_gpu()

+            pad_attn = cudagraph_runtime_mode == CUDAGraphMode.FULL
             attn_metadata, _ = self._build_attention_metadata(
                 num_tokens=num_tokens_unpadded,
                 num_reqs=num_reqs_padded,
                 max_query_len=max_query_len,
-                ubatch_slices=ubatch_slices,
+                ubatch_slices=ubatch_slices_padded if pad_attn else ubatch_slices,
                 for_cudagraph_capture=is_graph_capturing,
             )

@@ -4056,11 +4069,11 @@ class GPUModelRunner(
             num_tokens_padded, None, False
         )

-        if ubatch_slices is not None:
+        if ubatch_slices_padded is not None:
             # Adjust values to reflect a single ubatch.
             # TODO(sage,lucas): this is cruft that should be addressed in
             # the padding refactor.
-            num_tokens_padded = ubatch_slices[0].num_tokens
+            num_tokens_padded = ubatch_slices_padded[0].num_tokens
             if num_tokens_across_dp is not None:
                 num_tokens_across_dp[:] = num_tokens_padded

@@ -4073,7 +4086,7 @@ class GPUModelRunner(
                 num_tokens_across_dp=num_tokens_across_dp,
                 cudagraph_runtime_mode=cudagraph_runtime_mode,
                 batch_descriptor=batch_desc,
-                ubatch_slices=ubatch_slices,
+                ubatch_slices=ubatch_slices_padded,
             ),
         ):
             outputs = self.model(
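Taken together, the runner-side flow is: coordination returns `should_ubatch`, `maybe_create_ubatch_slices` turns it into an unpadded and a padded slice list, attention metadata gets the padded list only when full CUDA graphs require padded shapes, and the forward context uses the padded list. A self-contained sketch of that control flow with toy types and helpers (`ToySlice`, `toy_maybe_create_ubatch_slices`), not the `GPUModelRunner` API:

```python
from dataclasses import dataclass

import numpy as np


@dataclass
class ToySlice:
    token_slice: slice

    @property
    def num_tokens(self) -> int:
        return self.token_slice.stop - self.token_slice.start


def toy_maybe_create_ubatch_slices(should_ubatch, num_scheduled_tokens, num_tokens_padded):
    """Toy version: return (unpadded, padded) slice lists, or (None, None)."""
    if not should_ubatch:
        return None, None
    split = num_tokens_padded // 2
    total = int(num_scheduled_tokens.sum())
    unpadded = [ToySlice(slice(0, split)), ToySlice(slice(split, total))]
    padded = [ToySlice(slice(0, split)), ToySlice(slice(split, num_tokens_padded))]
    return unpadded, padded


# Runner-side flow, mirroring the diff at a high level.
should_ubatch = True                      # from the cross-DP coordination step
full_cudagraph = True                     # stands in for cudagraph_mode == FULL
num_scheduled_tokens = np.array([40, 20], dtype=np.int32)
num_tokens_padded = 64

ubatch_slices, ubatch_slices_padded = toy_maybe_create_ubatch_slices(
    should_ubatch, num_scheduled_tokens, num_tokens_padded
)
pad_attn = full_cudagraph
# Attention metadata sees padded slices only when attention itself is padded;
# the forward context always receives the padded slices.
ubatch_slices_attn = ubatch_slices_padded if pad_attn else ubatch_slices
assert sum(s.num_tokens for s in ubatch_slices_padded) == num_tokens_padded
```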
@@ -42,9 +42,37 @@ def check_ubatch_thresholds(
     return num_tokens >= config.dbo_prefill_token_threshold


-def create_ubatch_slices(
-    num_scheduled_tokens: np.ndarray, split_point: int
+# This just pads the second ubatch slice out to the total number of tokens
+# (num_tokens + padding) since we do `create_ubatch_slices` before applying DP padding.
+def _pad_out_ubatch_slices(
+    ubatch_slices: UBatchSlices, num_total_tokens: int, num_reqs_padded: int
 ) -> UBatchSlices:
+    # TODO(lucas): handle empty second ubatch
+    padded_second_request_slice = slice(
+        ubatch_slices[1].request_slice.start, num_reqs_padded
+    )
+    padded_second_token_slice = slice(
+        ubatch_slices[1].token_slice.start, num_total_tokens
+    )
+    return [
+        ubatch_slices[0],
+        UBatchSlice(padded_second_request_slice, padded_second_token_slice),
+    ]
+
+
+def maybe_create_ubatch_slices(
+    should_ubatch: bool,
+    num_scheduled_tokens: np.ndarray,
+    num_tokens_padded: int,
+    num_reqs_padded: int,
+    split_point: int | None = None,
+) -> tuple[UBatchSlices | None, UBatchSlices | None]:
+    if not should_ubatch:
+        return None, None
+
+    if split_point is None:
+        split_point = int(num_tokens_padded) // 2
+
     # TODO(lucas): Refactor the gpu_model_runner.py so we can pass
     # in cu_num_tokens directly (i.e. query_start_loc)
     cu_num_tokens = np.zeros(len(num_scheduled_tokens) + 1, dtype=np.int32)
@@ -67,7 +95,15 @@ def create_ubatch_slices(
     )
     second_ubatch_req_slice = slice(second_ubatch_req_start, len(cu_num_tokens) - 1)

-    return [
+    ubatch_slices = [
         UBatchSlice(first_ubatch_req_slice, first_ubatch_token_slice),
         UBatchSlice(second_ubatch_req_slice, second_ubatch_token_slice),
     ]
+
+    ubatch_slices_padded = _pad_out_ubatch_slices(
+        ubatch_slices, num_tokens_padded, num_reqs_padded
+    )
+
+    assert sum(s.num_tokens for s in ubatch_slices_padded) == num_tokens_padded
+
+    return ubatch_slices, ubatch_slices_padded
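The invariant enforced here is that only the returned padded copy is stretched: the second ubatch's request and token slices are extended to the padded counts, while the unpadded slices keep the true extents. A small stand-alone sketch of that invariant, using a toy `ToyUBatchSlice`, a hypothetical `toy_pad_out` helper, and illustrative numbers rather than the vLLM implementation:

```python
from dataclasses import dataclass


@dataclass
class ToyUBatchSlice:
    request_slice: slice
    token_slice: slice

    @property
    def num_tokens(self) -> int:
        return self.token_slice.stop - self.token_slice.start


def toy_pad_out(slices: list[ToyUBatchSlice], num_tokens_padded: int, num_reqs_padded: int):
    """Pad only the second ubatch out to the padded token and request counts."""
    first, second = slices
    return [
        first,
        ToyUBatchSlice(
            slice(second.request_slice.start, num_reqs_padded),
            slice(second.token_slice.start, num_tokens_padded),
        ),
    ]


# Unpadded slices for a 2-request, 50-token batch split at token 32.
unpadded = [
    ToyUBatchSlice(slice(0, 2), slice(0, 32)),
    ToyUBatchSlice(slice(1, 2), slice(32, 50)),
]
padded = toy_pad_out(unpadded, num_tokens_padded=64, num_reqs_padded=4)

assert sum(s.num_tokens for s in unpadded) == 50
assert sum(s.num_tokens for s in padded) == 64  # mirrors the new assert in the diff
assert padded[1].request_slice.stop == 4
```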