[BugFix] Fix DBO assert `assert B_block_table == B_q` (#29933)

Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
Author: Lucas Wilkinson
Date: 2025-12-04 14:48:54 -05:00 (committed by GitHub)
Parent: 48a5fff66e
Commit: c8ab988b15
5 changed files with 83 additions and 63 deletions

View File

@@ -13,7 +13,7 @@ from vllm.v1.attention.backends.utils import (
     split_attn_metadata,
     split_decodes_and_prefills,
 )
-from vllm.v1.worker.ubatch_utils import create_ubatch_slices
+from vllm.v1.worker.ubatch_utils import maybe_create_ubatch_slices
 @pytest.fixture
@@ -294,8 +294,14 @@ def test_prefill_split_across_ubatches(
     qsl_np = common.query_start_loc_cpu.numpy()
     num_tokens = common.num_actual_tokens
-    ubatch_slices = create_ubatch_slices(num_scheduled_tokens, split_point)
-    assert len(ubatch_slices) == 2
+    ubatch_slices, _ = maybe_create_ubatch_slices(
+        True,
+        num_scheduled_tokens,
+        num_tokens,
+        batch_spec.batch_size,
+        split_point=split_point,
+    )
+    assert ubatch_slices is not None and len(ubatch_slices) == 2
     first_meta = _make_metadata_with_slice(ubatch_slices[0], common)
     second_meta = _make_metadata_with_slice(ubatch_slices[1], common)
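
The test now exercises the new helper end to end: maybe_create_ubatch_slices returns a pair of slice lists (unpadded, padded) and (None, None) when microbatching is off. A minimal usage sketch along the lines of the updated test, assuming a vLLM checkout that includes this change (toy token counts, no extra DP padding):

import numpy as np

from vllm.v1.worker.ubatch_utils import maybe_create_ubatch_slices

# Four toy requests with 8 tokens each; split the 32 tokens into two ubatches of 16.
num_scheduled_tokens = np.array([8, 8, 8, 8], dtype=np.int32)

ubatch_slices, ubatch_slices_padded = maybe_create_ubatch_slices(
    True,                              # should_ubatch
    num_scheduled_tokens,
    int(num_scheduled_tokens.sum()),   # num_tokens_padded (no DP padding in this toy case)
    len(num_scheduled_tokens),         # num_reqs_padded
    split_point=16,
)
assert ubatch_slices is not None and len(ubatch_slices) == 2
assert sum(s.num_tokens for s in ubatch_slices_padded) == 32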

View File

@@ -1258,7 +1258,7 @@ class EagleProposer:
         num_tokens_padded: int,
     ) -> tuple[int, torch.Tensor]:
         # TODO(Flechman): support DBO ubatching
-        ubatch_slices, num_toks_across_dp = coordinate_batch_across_dp(
+        should_ubatch, num_toks_across_dp = coordinate_batch_across_dp(
             num_tokens_unpadded=num_tokens_unpadded,
             parallel_config=self.vllm_config.parallel_config,
             allow_microbatching=False,
@@ -1267,7 +1267,7 @@
             uniform_decode=None,
             num_scheduled_tokens_per_request=None,
         )
-        assert ubatch_slices is None, "DBO ubatching not implemented for EAGLE"
+        assert not should_ubatch, "DBO ubatching not implemented for EAGLE"
         num_tokens_dp_padded = num_tokens_padded
         if num_toks_across_dp is not None:
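
EagleProposer only needs the DP-padding half of the coordination, so under the new contract it simply asserts that the microbatching flag is off. A standalone sketch of the new return shape, using a hypothetical stub in place of the real distributed call (coordinate_batch_across_dp_stub is not a vLLM API):

import torch

# Hypothetical stand-in for coordinate_batch_across_dp, only to illustrate the new
# (should_ubatch: bool, num_tokens_across_dp: Tensor | None) return contract.
def coordinate_batch_across_dp_stub(
    num_tokens_unpadded: int, allow_microbatching: bool, dp_size: int
) -> tuple[bool, torch.Tensor | None]:
    if dp_size == 1:
        return False, None  # early exit, mirroring the real function
    num_tokens_across_dp = torch.full((dp_size,), num_tokens_unpadded, dtype=torch.int32)
    return allow_microbatching, num_tokens_across_dp

should_ubatch, num_toks_across_dp = coordinate_batch_across_dp_stub(256, False, dp_size=2)
assert not should_ubatch, "DBO ubatching not implemented for EAGLE"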

View File

@@ -1,6 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+import numpy as np
 import torch
 import torch.distributed as dist
@@ -9,10 +10,7 @@ from vllm.config import ParallelConfig
 from vllm.distributed.parallel_state import get_dp_group
 from vllm.logger import init_logger
 from vllm.v1.worker.ubatch_utils import (
-    UBatchSlice,
-    UBatchSlices,
     check_ubatch_thresholds,
-    create_ubatch_slices,
     is_second_ubatch_empty,
 )
@@ -91,20 +89,6 @@ def _post_process_dp_padding(tensor: torch.Tensor, should_dp_pad: bool) -> torch
     return num_tokens_across_dp.cpu()
-# This just pads the second ubatch slice out to the total number of tokens
-# (num_tokens + padding) since we do `create_ubatch_slices` before applying DP padding.
-def _pad_out_ubatch_slice(
-    ubatch_slices: UBatchSlices, num_total_tokens: int
-) -> UBatchSlices:
-    padded_second_token_slice = slice(
-        ubatch_slices[1].token_slice.start, num_total_tokens
-    )
-    ubatch_slices[1] = UBatchSlice(
-        ubatch_slices[1].request_slice, padded_second_token_slice
-    )
-    return ubatch_slices
 def _synchronize_dp_ranks(
     num_tokens_unpadded: int,
     num_tokens_padded: int,
@@ -175,7 +159,7 @@ def coordinate_batch_across_dp(
     num_tokens_padded: int | None = None,
     uniform_decode: bool | None = None,
     num_scheduled_tokens_per_request: np.ndarray | None = None,
-) -> tuple[UBatchSlices | None, torch.Tensor | None]:
+) -> tuple[bool, torch.Tensor | None]:
     """
     Coordinates amongst all DP ranks to determine if and how the full batch
     should be split into microbatches.
@@ -204,7 +188,7 @@
     """
     if parallel_config.data_parallel_size == 1:
         # Early exit.
-        return None, None
+        return False, None
     # If the caller has explicitly enabled microbatching.
     should_attempt_ubatching = False
@@ -228,23 +212,4 @@
         parallel_config,
     )
-    # Don't microbatch unless every other DP worker is also microbatching
-    if not should_ubatch:
-        return (None, num_tokens_after_padding)
-    # This doesn't actually pad the ubatch slices. It just initializes the
-    # split point to the padded value so that padding can be applied
-    # to the second ubatch in pad_out_ubatch_slice after attention
-    # metadata creation
-    assert num_tokens_after_padding is not None
-    num_tokens_padded = int(num_tokens_after_padding[0].item())
-    token_split_point = int(num_tokens_padded) // 2
-    assert num_scheduled_tokens_per_request is not None
-    ubatch_slices = create_ubatch_slices(
-        num_scheduled_tokens_per_request, token_split_point
-    )
-    ubatch_slices = _pad_out_ubatch_slice(ubatch_slices, num_tokens_padded)
-    assert sum(s.num_tokens for s in ubatch_slices) == num_tokens_padded
-    return (ubatch_slices, num_tokens_after_padding)
+    return (should_ubatch, num_tokens_after_padding)
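
After this change, coordinate_batch_across_dp only reports the coordinated decision (plus the per-rank token counts); building and padding the ubatch slices moves to the caller via maybe_create_ubatch_slices in ubatch_utils.py. A toy illustration (not the vLLM implementation) of the "don't microbatch unless every other DP worker is also microbatching" rule, where microbatching happens only if every DP rank opts in:

# Toy stand-in for the cross-rank coordination; the real code agrees on this
# across the DP group (see _synchronize_dp_ranks) rather than a local list.
def coordinate(should_ubatch_per_rank: list[bool]) -> bool:
    return all(should_ubatch_per_rank)

assert coordinate([True, True]) is True
assert coordinate([True, False]) is False  # one rank opting out disables microbatching for all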

View File

@@ -153,6 +153,7 @@ from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin
 from vllm.v1.worker.ubatch_utils import (
     UBatchSlices,
     check_ubatch_thresholds,
+    maybe_create_ubatch_slices,
 )
 from vllm.v1.worker.utils import is_residual_scattered_for_sp
@@ -2743,7 +2744,7 @@ class GPUModelRunner(
     ) -> tuple[
         CUDAGraphMode,
         BatchDescriptor,
-        UBatchSlices | None,
+        bool,
         torch.Tensor | None,
         CUDAGraphStat | None,
     ]:
@@ -2779,7 +2780,7 @@
         # Extra coordination when running data-parallel since we need to coordinate
         # across ranks
-        ubatch_slices, num_tokens_across_dp = None, None
+        should_ubatch, num_tokens_across_dp = False, None
         if self.vllm_config.parallel_config.data_parallel_size > 1:
             # Disable DP padding when running eager to avoid excessive padding when
             # running prefills. This lets us set cudagraph_mode="NONE" on the prefiller
@@ -2789,8 +2790,8 @@
                 self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE
             )
-            ubatch_slices, num_tokens_across_dp = coordinate_batch_across_dp(
-                num_tokens_unpadded=num_tokens_padded,
+            should_ubatch, num_tokens_across_dp = coordinate_batch_across_dp(
+                num_tokens_unpadded=num_tokens,
                 parallel_config=self.parallel_config,
                 allow_microbatching=allow_microbatching,
                 allow_dp_padding=allow_dp_padding,
@@ -2822,7 +2823,7 @@
         return (
             cudagraph_mode,
             batch_descriptor,
-            ubatch_slices,
+            should_ubatch,
             num_tokens_across_dp,
             cudagraph_stats,
         )
@@ -2921,7 +2922,7 @@
         (
             cudagraph_mode,
             batch_desc,
-            ubatch_slices,
+            should_ubatch,
             num_tokens_across_dp,
             cudagraph_stats,
         ) = self._determine_batch_execution_and_padding(
@@ -2934,10 +2935,10 @@
         logger.debug(
             "Running batch with cudagraph_mode: %s, batch_descriptor: %s, "
-            "ubatch_slices: %s, num_tokens_across_dp: %s",
+            "should_ubatch: %s, num_tokens_across_dp: %s",
             cudagraph_mode,
             batch_desc,
-            ubatch_slices,
+            should_ubatch,
             num_tokens_across_dp,
         )
@@ -2945,9 +2946,17 @@
         num_reqs_padded = (
             batch_desc.num_reqs if batch_desc.num_reqs is not None else num_reqs
         )
+        ubatch_slices, ubatch_slices_padded = maybe_create_ubatch_slices(
+            should_ubatch,
+            num_scheduled_tokens_np,
+            num_tokens_padded,
+            num_reqs_padded,
+        )
+        pad_attn = cudagraph_mode == CUDAGraphMode.FULL
         use_spec_decode = len(scheduler_output.scheduled_spec_decode_tokens) > 0
-        pad_attn = cudagraph_mode == CUDAGraphMode.FULL
+        ubatch_slices_attn = ubatch_slices_padded if pad_attn else ubatch_slices
         (attn_metadata, spec_decode_common_attn_metadata) = (
             self._build_attention_metadata(
@@ -2956,7 +2965,7 @@
                 num_reqs=num_reqs,
                 num_reqs_padded=num_reqs_padded if pad_attn else None,
                 max_query_len=max_num_scheduled_tokens,
-                ubatch_slices=ubatch_slices,
+                ubatch_slices=ubatch_slices_attn,
                 logits_indices=logits_indices,
                 use_spec_decode=use_spec_decode,
                 num_scheduled_tokens=scheduler_output.num_scheduled_tokens,
@@ -2993,7 +3002,7 @@
                 num_tokens_across_dp=num_tokens_across_dp,
                 cudagraph_runtime_mode=cudagraph_mode,
                 batch_descriptor=batch_desc,
-                ubatch_slices=ubatch_slices,
+                ubatch_slices=ubatch_slices_padded,
             ),
             record_function_or_nullcontext("gpu_model_runner: forward"),
             self.maybe_get_kv_connector_output(scheduler_output) as kv_connector_output,
@@ -3945,7 +3954,7 @@
         num_sampled_tokens = np.ones(num_reqs, dtype=np.int32)
-        _cudagraph_mode, batch_desc, ubatch_slices, num_tokens_across_dp, _ = (
+        _cudagraph_mode, batch_desc, should_ubatch, num_tokens_across_dp, _ = (
             self._determine_batch_execution_and_padding(
                 num_tokens=num_tokens_unpadded,
                 num_reqs=num_reqs,
@@ -3979,6 +3988,9 @@
         num_reqs_padded = (
             batch_desc.num_reqs if batch_desc.num_reqs is not None else num_reqs
         )
+        ubatch_slices, ubatch_slices_padded = maybe_create_ubatch_slices(
+            should_ubatch, num_scheduled_tokens, num_tokens_padded, num_reqs_padded
+        )
         attn_metadata: PerLayerAttnMetadata | None = None
@@ -4000,11 +4012,12 @@
             self.query_start_loc.np[1 : num_reqs + 1] = cum_num_tokens
             self.query_start_loc.copy_to_gpu()
+            pad_attn = cudagraph_runtime_mode == CUDAGraphMode.FULL
             attn_metadata, _ = self._build_attention_metadata(
                 num_tokens=num_tokens_unpadded,
                 num_reqs=num_reqs_padded,
                 max_query_len=max_query_len,
-                ubatch_slices=ubatch_slices,
+                ubatch_slices=ubatch_slices_padded if pad_attn else ubatch_slices,
                 for_cudagraph_capture=is_graph_capturing,
             )
@@ -4056,11 +4069,11 @@
             num_tokens_padded, None, False
         )
-        if ubatch_slices is not None:
+        if ubatch_slices_padded is not None:
            # Adjust values to reflect a single ubatch.
            # TODO(sage,lucas): this is cruft that should be addressed in
            # the padding refactor.
-            num_tokens_padded = ubatch_slices[0].num_tokens
+            num_tokens_padded = ubatch_slices_padded[0].num_tokens
            if num_tokens_across_dp is not None:
                num_tokens_across_dp[:] = num_tokens_padded
@@ -4073,7 +4086,7 @@
                 num_tokens_across_dp=num_tokens_across_dp,
                 cudagraph_runtime_mode=cudagraph_runtime_mode,
                 batch_descriptor=batch_desc,
-                ubatch_slices=ubatch_slices,
+                ubatch_slices=ubatch_slices_padded,
             ),
         ):
             outputs = self.model(
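
The runner now keeps both the unpadded and the padded slice lists: attention metadata gets the padded variant only when attention itself is padded for FULL cudagraph execution (pad_attn), while the forward context always gets the padded variant. A toy sketch of that selection (plain strings standing in for UBatchSlice objects, "FULL"/"PIECEWISE" standing in for CUDAGraphMode values):

def pick_attn_slices(cudagraph_mode: str, unpadded: list[str], padded: list[str]) -> list[str]:
    # Mirrors `ubatch_slices_attn = ubatch_slices_padded if pad_attn else ubatch_slices`.
    pad_attn = cudagraph_mode == "FULL"
    return padded if pad_attn else unpadded

assert pick_attn_slices("FULL", ["u0", "u1"], ["p0", "p1"]) == ["p0", "p1"]
assert pick_attn_slices("PIECEWISE", ["u0", "u1"], ["p0", "p1"]) == ["u0", "u1"]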

View File

@@ -42,9 +42,37 @@ def check_ubatch_thresholds(
     return num_tokens >= config.dbo_prefill_token_threshold
-def create_ubatch_slices(
-    num_scheduled_tokens: np.ndarray, split_point: int
+# This just pads the second ubatch slice out to the total number of tokens
+# (num_tokens + padding) since we do `create_ubatch_slices` before applying DP padding.
+def _pad_out_ubatch_slices(
+    ubatch_slices: UBatchSlices, num_total_tokens: int, num_reqs_padded: int
 ) -> UBatchSlices:
+    # TODO(lucas): handle empty second ubatch
+    padded_second_request_slice = slice(
+        ubatch_slices[1].request_slice.start, num_reqs_padded
+    )
+    padded_second_token_slice = slice(
+        ubatch_slices[1].token_slice.start, num_total_tokens
+    )
+    return [
+        ubatch_slices[0],
+        UBatchSlice(padded_second_request_slice, padded_second_token_slice),
+    ]
+def maybe_create_ubatch_slices(
+    should_ubatch: bool,
+    num_scheduled_tokens: np.ndarray,
+    num_tokens_padded: int,
+    num_reqs_padded: int,
+    split_point: int | None = None,
+) -> tuple[UBatchSlices | None, UBatchSlices | None]:
+    if not should_ubatch:
+        return None, None
+    if split_point is None:
+        split_point = int(num_tokens_padded) // 2
     # TODO(lucas): Refactor the gpu_model_runner.py so we can pass
     # in cu_num_tokens directly (i.e. query_start_loc)
     cu_num_tokens = np.zeros(len(num_scheduled_tokens) + 1, dtype=np.int32)
@@ -67,7 +95,15 @@
     )
     second_ubatch_req_slice = slice(second_ubatch_req_start, len(cu_num_tokens) - 1)
-    return [
+    ubatch_slices = [
         UBatchSlice(first_ubatch_req_slice, first_ubatch_token_slice),
         UBatchSlice(second_ubatch_req_slice, second_ubatch_token_slice),
     ]
+    ubatch_slices_padded = _pad_out_ubatch_slices(
+        ubatch_slices, num_tokens_padded, num_reqs_padded
+    )
+    assert sum(s.num_tokens for s in ubatch_slices_padded) == num_tokens_padded
+    return ubatch_slices, ubatch_slices_padded
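
For reference, a standalone toy re-implementation (not the vLLM code) of the padding step above: all padding is attributed to the second ubatch, so the padded slices cover exactly num_tokens_padded tokens and num_reqs_padded requests, and the token total is what the new assert checks:

from dataclasses import dataclass

@dataclass
class ToyUBatchSlice:
    request_slice: slice
    token_slice: slice

    @property
    def num_tokens(self) -> int:
        return self.token_slice.stop - self.token_slice.start

def pad_out_second_ubatch(
    slices: list[ToyUBatchSlice], num_tokens_padded: int, num_reqs_padded: int
) -> list[ToyUBatchSlice]:
    first, second = slices
    return [
        first,
        ToyUBatchSlice(
            slice(second.request_slice.start, num_reqs_padded),
            slice(second.token_slice.start, num_tokens_padded),
        ),
    ]

# 30 real tokens over 4 requests, padded out to 32 tokens; the 2 pad tokens land in ubatch 1.
unpadded = [
    ToyUBatchSlice(slice(0, 2), slice(0, 16)),
    ToyUBatchSlice(slice(2, 4), slice(16, 30)),
]
padded = pad_out_second_ubatch(unpadded, num_tokens_padded=32, num_reqs_padded=4)
assert sum(s.num_tokens for s in padded) == 32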