[V1] [Hybrid] Mamba1 Automatic Prefix Caching (#26377)
Signed-off-by: asafg <39553475+Josephasafg@users.noreply.github.com>
Committed by: GitHub
Parent: 73444b7b56
Commit: 00b31a36a2
@@ -24,6 +24,8 @@ struct SSMParamsBase {
    int64_t pad_slot_id;

    bool delta_softplus;
+   bool cache_enabled;
+   int block_size;

    index_t A_d_stride;
    index_t A_dstate_stride;
@@ -46,8 +48,9 @@ struct SSMParamsBase {
    index_t out_z_batch_stride;
    index_t out_z_d_stride;
    index_t ssm_states_batch_stride;
    index_t ssm_states_dim_stride;
    index_t ssm_states_dstate_stride;
+   index_t cache_indices_stride;

    // Common data pointers.
    void *__restrict__ A_ptr;
@@ -66,6 +69,9 @@ struct SSMParamsBase {
    void *__restrict__ cache_indices_ptr;
    void *__restrict__ has_initial_state_ptr;

+   void *__restrict__ block_idx_first_scheduled_token_ptr;  // (batch,) - first block to write
+   void *__restrict__ block_idx_last_scheduled_token_ptr;   // (batch,) - last block to write
+   void *__restrict__ initial_state_idx_ptr;                // (batch,) - index of the initial state to use
};
@@ -119,7 +119,7 @@ void selective_scan_fwd_kernel(SSMParamsBase params) {

    const int* cache_indices = params.cache_indices_ptr == nullptr ? nullptr
        : reinterpret_cast<int *>(params.cache_indices_ptr);
    const int cache_index = cache_indices == nullptr ? batch_id : cache_indices[batch_id];
    // cache_index == params.pad_slot_id is defined as padding, so we exit early
    if (cache_index == params.pad_slot_id){
        return;
@@ -133,9 +133,18 @@ void selective_scan_fwd_kernel(SSMParamsBase params) {
    input_t *Bvar = reinterpret_cast<input_t *>(params.B_ptr) + sequence_start_index * params.B_batch_stride + group_id * params.B_group_stride;
    weight_t *C = reinterpret_cast<weight_t *>(params.C_ptr) + dim_id * kNRows * params.C_d_stride;
    input_t *Cvar = reinterpret_cast<input_t *>(params.C_ptr) + sequence_start_index * params.C_batch_stride + group_id * params.C_group_stride;
-   typename Ktraits::state_t *ssm_states = reinterpret_cast<typename Ktraits::state_t *>(params.ssm_states_ptr) +
-       cache_index * params.ssm_states_batch_stride +
-       dim_id * kNRows * params.ssm_states_dim_stride;
+
+   typename Ktraits::state_t *ssm_states;
+   if (params.cache_enabled) {
+       // APC mode: ssm_states points to the base, we'll use absolute cache slots later
+       ssm_states = reinterpret_cast<typename Ktraits::state_t *>(params.ssm_states_ptr) +
+           dim_id * kNRows * params.ssm_states_dim_stride;
+   } else {
+       // Non-APC mode: offset by cache_index as before
+       ssm_states = reinterpret_cast<typename Ktraits::state_t *>(params.ssm_states_ptr) +
+           cache_index * params.ssm_states_batch_stride +
+           dim_id * kNRows * params.ssm_states_dim_stride;
+   }

    float D_val[kNRows] = {0};
    if (params.D_ptr != nullptr) {
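To make the addressing switch above concrete: with APC the kernel keeps ssm_states at the base of the state pool (plus the per-dim offset) and resolves an absolute cache slot for each chunk later, while the non-APC path folds the request's single cache_index into the pointer once. A rough Python sketch of that offset arithmetic (a sketch only; names mirror the kernel fields and strides are assumed to be in elements):

def ssm_state_base_offset(dim_id, kNRows, dim_stride,
                          cache_index, batch_stride, cache_enabled):
    # Sketch of the pointer setup in the hunk above, not the CUDA code itself.
    if cache_enabled:
        # APC mode: leave out the batch/slot term; it is added per chunk from an
        # absolute cache slot taken from the request's row of cache_indices.
        return dim_id * kNRows * dim_stride
    # Non-APC mode: the request's single state slot is baked in up front.
    return cache_index * batch_stride + dim_id * kNRows * dim_stride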
@@ -159,7 +168,22 @@ void selective_scan_fwd_kernel(SSMParamsBase params) {
    // }

    constexpr int kChunkSize = kNThreads * kNItems;
-   const int n_chunks = (seqlen + 2048 - 1) / 2048;
+
+   // Use block_size for chunking when APC is enabled, otherwise use 2048 for backwards compatibility
+   const int iteration_chunk_size = params.cache_enabled ? params.block_size : 2048;
+   const int n_chunks = (seqlen + iteration_chunk_size - 1) / iteration_chunk_size;
+
+   const int* batch_cache_indices = cache_indices != nullptr ?
+       cache_indices + batch_id * params.cache_indices_stride : nullptr;
+   const int* block_idx_first_scheduled = params.block_idx_first_scheduled_token_ptr != nullptr ?
+       reinterpret_cast<const int*>(params.block_idx_first_scheduled_token_ptr) : nullptr;
+   const int* block_idx_last_scheduled = params.block_idx_last_scheduled_token_ptr != nullptr ?
+       reinterpret_cast<const int*>(params.block_idx_last_scheduled_token_ptr) : nullptr;
+   const int* initial_state_idx = params.initial_state_idx_ptr != nullptr ?
+       reinterpret_cast<const int*>(params.initial_state_idx_ptr) : nullptr;
+
+   const size_t load_cache_slot = params.cache_enabled && batch_cache_indices != nullptr ? batch_cache_indices[initial_state_idx[batch_id]] : cache_index;
+
    for (int chunk = 0; chunk < n_chunks; ++chunk) {
        input_t u_vals[kNRows][kNItems], delta_vals_load[kNRows][kNItems];
@@ -219,7 +243,7 @@ void selective_scan_fwd_kernel(SSMParamsBase params) {
            if constexpr (kIsVariableC) {
                auto &smem_load_weight_C = !kIsVariableB ? smem_load_weight : smem_load_weight1;
                load_weight<Ktraits>(Cvar + state_idx * params.C_dstate_stride, C_vals,
-                   smem_load_weight_C, (seqlen - chunk * kChunkSize) * (1 ));
+                   smem_load_weight_C, (seqlen - chunk * kChunkSize) * (1));
                if constexpr (!kIsVariableB) {
                    #pragma unroll
                    for (int r = 0; r < kNRows; ++r) {
@@ -242,7 +266,6 @@ void selective_scan_fwd_kernel(SSMParamsBase params) {
                for (int i = 0; i < kNItems; ++i) {
                    thread_data[i] = make_float2(exp2f(delta_vals[r][i] * A_val[r]),
                                                 !kIsVariableB ? delta_u_vals[r][i] : B_vals[i] * delta_u_vals[r][i]);
-
                    if (seqlen % (kNItems * kNThreads) != 0) {  // So that the last state is correct
                        if (threadIdx.x * kNItems + i >= seqlen - chunk * kChunkSize) {
                            thread_data[i] = make_float2(1.f, 0.f);
@@ -250,8 +273,24 @@ void selective_scan_fwd_kernel(SSMParamsBase params) {
                    }
                }
                // Initialize running total
-               scan_t running_prefix = chunk > 0 ? smem_running_prefix[state_idx + r * MAX_DSTATE] : make_float2(1.0, has_initial_state ? float(ssm_states[state_idx * params.ssm_states_dstate_stride]): 0.0);
+               scan_t running_prefix;
+               if (chunk > 0) {
+                   running_prefix = smem_running_prefix[state_idx + r * MAX_DSTATE];
+               } else {
+                   // Load initial state
+                   if (params.cache_enabled && has_initial_state && batch_cache_indices != nullptr) {
+                       size_t state_offset = load_cache_slot * params.ssm_states_batch_stride +
+                                             r * params.ssm_states_dim_stride +
+                                             state_idx * params.ssm_states_dstate_stride;
+                       running_prefix = make_float2(1.0, float(ssm_states[state_offset]));
+                   } else if (has_initial_state) {
+                       // Non-APC mode: load from current batch position
+                       running_prefix = make_float2(1.0, float(ssm_states[state_idx * params.ssm_states_dstate_stride]));
+                   } else {
+                       // No initial state
+                       running_prefix = make_float2(1.0, 0.0);
+                   }
+               }

                SSMScanPrefixCallbackOp<weight_t> prefix_op(running_prefix);
                typename Ktraits::BlockScanT(smem_scan).InclusiveScan(
@@ -260,8 +299,25 @@ void selective_scan_fwd_kernel(SSMParamsBase params) {
                // There's a syncthreads in the scan op, so we don't need to sync here.
                // Unless there's only 1 warp, but then it's the same thread (0) reading and writing.
                if (threadIdx.x == 0) {
-                   smem_running_prefix[state_idx] = prefix_op.running_prefix;
-                   if (chunk == n_chunks - 1) {
+                   smem_running_prefix[state_idx + r * MAX_DSTATE] = prefix_op.running_prefix;
+
+                   // Store state at the end of each chunk when cache is enabled
+                   if (params.cache_enabled && batch_cache_indices != nullptr) {
+                       size_t cache_slot;
+                       if (chunk == n_chunks - 1) {
+                           cache_slot = batch_cache_indices[block_idx_last_scheduled[batch_id]];
+                       } else {
+                           cache_slot = batch_cache_indices[block_idx_first_scheduled[batch_id] + chunk];
+                       }
+
+                       size_t state_offset = cache_slot * params.ssm_states_batch_stride +
+                                             r * params.ssm_states_dim_stride +
+                                             state_idx * params.ssm_states_dstate_stride;
+
+                       ssm_states[state_offset] = typename Ktraits::state_t(prefix_op.running_prefix.y);
+                   } else if (!params.cache_enabled && chunk == n_chunks - 1) {
+                       // Non-APC mode: store only final state at current batch position
+                       ssm_states[state_idx * params.ssm_states_dstate_stride] = typename Ktraits::state_t(prefix_op.running_prefix.y);
+                   }
                }
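The mapping from scan chunks to cache slots above can be summarised as: intermediate chunks fill consecutive blocks starting at the first scheduled block, the final chunk writes the block of the last scheduled token, and the prefix state is read from the block named by initial_state_idx. A hedged Python sketch of that lookup (batch_cache_indices stands for one request's row of the 2D cache_indices tensor; names mirror the kernel parameters):

def chunk_cache_slot(chunk, n_chunks, batch_cache_indices,
                     block_idx_first_scheduled, block_idx_last_scheduled):
    if chunk == n_chunks - 1:
        # Final chunk: state belongs to the last scheduled token's block.
        return batch_cache_indices[block_idx_last_scheduled]
    # Intermediate chunks fill consecutive blocks from the first scheduled block.
    return batch_cache_indices[block_idx_first_scheduled + chunk]

def initial_state_slot(batch_cache_indices, initial_state_idx):
    # The prefix (already computed) state is loaded from this block.
    return batch_cache_indices[initial_state_idx]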
@@ -274,7 +330,6 @@ void selective_scan_fwd_kernel(SSMParamsBase params) {
            }
        }
    }
-
    input_t *out = reinterpret_cast<input_t *>(params.out_ptr) + sequence_start_index * params.out_batch_stride
        + dim_id * kNRows * params.out_d_stride + chunk * kChunkSize;
    __syncthreads();
@@ -346,7 +401,9 @@ template<typename input_t, typename weight_t, typename state_t>
void selective_scan_fwd_cuda(SSMParamsBase &params, cudaStream_t stream) {

    #ifndef USE_ROCM
-   if (params.seqlen <= 128) {
+   if (params.cache_enabled && params.block_size == 1024) {
+       selective_scan_fwd_launch<64, 16, input_t, weight_t, state_t>(params, stream);
+   } else if (params.seqlen <= 128) {
        selective_scan_fwd_launch<32, 4, input_t, weight_t, state_t>(params, stream);
    } else if (params.seqlen <= 256) {
        selective_scan_fwd_launch<32, 8, input_t, weight_t, state_t>(params, stream);
@@ -358,7 +415,9 @@ void selective_scan_fwd_cuda(SSMParamsBase &params, cudaStream_t stream) {
        selective_scan_fwd_launch<128, 16, input_t, weight_t, state_t>(params, stream);
    }
    #else
-   if (params.seqlen <= 256) {
+   if (params.cache_enabled && params.block_size == 1024) {
+       selective_scan_fwd_launch<64, 16, input_t, weight_t, state_t>(params, stream);
+   } else if (params.seqlen <= 256) {
        selective_scan_fwd_launch<64, 4, input_t, weight_t, state_t>(params, stream);
    } else if (params.seqlen <= 512) {
        selective_scan_fwd_launch<64, 8, input_t, weight_t, state_t>(params, stream);
@@ -437,13 +496,17 @@ void set_ssm_params_fwd(SSMParamsBase &params,
                        const std::optional<at::Tensor>& D,
                        const std::optional<at::Tensor>& delta_bias,
                        const torch::Tensor ssm_states,
                        bool has_z,
                        bool delta_softplus,
                        const std::optional<at::Tensor>& query_start_loc,
                        const std::optional<at::Tensor>& cache_indices,
                        const std::optional<at::Tensor>& has_initial_state,
                        bool varlen,
-                       int64_t pad_slot_id) {
+                       int64_t pad_slot_id,
+                       int64_t block_size,
+                       const std::optional<torch::Tensor> &block_idx_first_scheduled_token,
+                       const std::optional<torch::Tensor> &block_idx_last_scheduled_token,
+                       const std::optional<torch::Tensor> &initial_state_idx) {

  // Reset the parameters
  memset(&params, 0, sizeof(params));
@@ -477,6 +540,14 @@ void set_ssm_params_fwd(SSMParamsBase &params,
  params.cache_indices_ptr = cache_indices.has_value() ? cache_indices.value().data_ptr() : nullptr;
  params.has_initial_state_ptr = has_initial_state.has_value() ? has_initial_state.value().data_ptr() : nullptr;

+ // Set cache parameters - cache is enabled if we have direct cache writing params
+ params.cache_enabled = block_idx_first_scheduled_token.has_value();
+ params.block_size = static_cast<int>(block_size);
+
+ // Set direct cache writing pointers
+ params.block_idx_first_scheduled_token_ptr = block_idx_first_scheduled_token.has_value() ? block_idx_first_scheduled_token.value().data_ptr() : nullptr;
+ params.block_idx_last_scheduled_token_ptr = block_idx_last_scheduled_token.has_value() ? block_idx_last_scheduled_token.value().data_ptr() : nullptr;
+ params.initial_state_idx_ptr = initial_state_idx.has_value() ? initial_state_idx.value().data_ptr() : nullptr;
+
  // All stride are in elements, not bytes.
  params.A_d_stride = A.stride(0);
@@ -504,9 +575,11 @@ void set_ssm_params_fwd(SSMParamsBase &params,
    params.out_d_stride = out.stride(0);

    params.ssm_states_batch_stride = ssm_states.stride(0);
    params.ssm_states_dim_stride = ssm_states.stride(1);
    params.ssm_states_dstate_stride = ssm_states.stride(2);
+
+   params.cache_indices_stride = cache_indices.has_value() ? cache_indices.value().stride(0) : 0;
  }
  else{
    if (!is_variable_B) {
@@ -537,8 +610,10 @@ void set_ssm_params_fwd(SSMParamsBase &params,
      params.out_d_stride = out.stride(1);

      params.ssm_states_batch_stride = ssm_states.stride(0);
      params.ssm_states_dim_stride = ssm_states.stride(1);
      params.ssm_states_dstate_stride = ssm_states.stride(2);
+
+     params.cache_indices_stride = cache_indices.has_value() ? cache_indices.value().stride(0) : 0;
    }
  }
@@ -554,7 +629,11 @@ void selective_scan_fwd(const torch::Tensor &u, const torch::Tensor &delta,
                        const torch::Tensor &ssm_states,
                        // used to identify padding entries if cache_indices provided
                        // in case of padding, the kernel will return early
-                       int64_t pad_slot_id) {
+                       int64_t pad_slot_id,
+                       int64_t block_size,
+                       const std::optional<torch::Tensor> &block_idx_first_scheduled_token,
+                       const std::optional<torch::Tensor> &block_idx_last_scheduled_token,
+                       const std::optional<torch::Tensor> &initial_state_idx) {
  auto input_type = u.scalar_type();
  auto weight_type = A.scalar_type();
  TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16);
@@ -646,7 +725,16 @@ void selective_scan_fwd(const torch::Tensor &u, const torch::Tensor &delta,
    auto cache_indices_ = cache_indices.value();
    TORCH_CHECK(cache_indices_.scalar_type() == at::ScalarType::Int);
    TORCH_CHECK(cache_indices_.is_cuda());
-   CHECK_SHAPE(cache_indices_, batch_size);
+
+   // cache_indices can be either 1D (batch_size,) for non-APC mode
+   // or 2D (batch_size, max_positions) for APC mode
+   const bool is_apc_mode = block_idx_first_scheduled_token.has_value();
+   if (is_apc_mode) {
+     TORCH_CHECK(cache_indices_.dim() == 2, "cache_indices must be 2D for APC mode");
+     TORCH_CHECK(cache_indices_.size(0) == batch_size, "cache_indices first dimension must match batch_size");
+   } else {
+     CHECK_SHAPE(cache_indices_, batch_size);
+   }
  }
@@ -686,7 +774,11 @@ void selective_scan_fwd(const torch::Tensor &u, const torch::Tensor &delta,
                     cache_indices,
                     has_initial_state,
                     varlen,
-                    pad_slot_id
+                    pad_slot_id,
+                    block_size,
+                    block_idx_first_scheduled_token,
+                    block_idx_last_scheduled_token,
+                    initial_state_idx
                     );
csrc/ops.h (24 lines changed)
@@ -321,17 +321,19 @@ void dynamic_per_token_scaled_fp8_quant(
    torch::Tensor& out, torch::Tensor const& input, torch::Tensor& scale,
    std::optional<torch::Tensor> const& scale_ub);

-void selective_scan_fwd(const torch::Tensor& u, const torch::Tensor& delta,
-                        const torch::Tensor& A, const torch::Tensor& B,
-                        const torch::Tensor& C,
-                        const std::optional<torch::Tensor>& D_,
-                        const std::optional<torch::Tensor>& z_,
-                        const std::optional<torch::Tensor>& delta_bias_,
-                        bool delta_softplus,
-                        const std::optional<torch::Tensor>& query_start_loc,
-                        const std::optional<torch::Tensor>& cache_indices,
-                        const std::optional<torch::Tensor>& has_initial_state,
-                        const torch::Tensor& ssm_states, int64_t pad_slot_id);
+void selective_scan_fwd(
+    const torch::Tensor& u, const torch::Tensor& delta, const torch::Tensor& A,
+    const torch::Tensor& B, const torch::Tensor& C,
+    const std::optional<torch::Tensor>& D_,
+    const std::optional<torch::Tensor>& z_,
+    const std::optional<torch::Tensor>& delta_bias_, bool delta_softplus,
+    const std::optional<torch::Tensor>& query_start_loc,
+    const std::optional<torch::Tensor>& cache_indices,
+    const std::optional<torch::Tensor>& has_initial_state,
+    const torch::Tensor& ssm_states, int64_t pad_slot_id, int64_t block_size,
+    const std::optional<torch::Tensor>& block_idx_first_scheduled_token,
+    const std::optional<torch::Tensor>& block_idx_last_scheduled_token,
+    const std::optional<torch::Tensor>& initial_state_idx);

torch::Tensor dynamic_4bit_int_moe_cpu(
    torch::Tensor x, torch::Tensor topk_ids, torch::Tensor topk_weights,
@@ -611,7 +611,11 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
      "Tensor? cache_indices,"
      "Tensor? has_initial_state,"
      "Tensor! ssm_states,"
-     "int pad_slot_id) -> ()");
+     "int pad_slot_id,"
+     "int block_size,"
+     "Tensor? block_idx_first_scheduled_token,"
+     "Tensor? block_idx_last_scheduled_token,"
+     "Tensor? initial_state_idx) -> ()");
  ops.impl("selective_scan_fwd", torch::kCUDA, &selective_scan_fwd);

  // Hadamard transforms
@@ -179,6 +179,10 @@ def selective_scan_opcheck_fn(
    has_initial_state=None,
    ssm_states=None,
    pad_slot_id=PAD_SLOT_ID,
+   block_size=2048,
+   block_idx_first_scheduled_token=None,
+   block_idx_last_scheduled_token=None,
+   initial_state_idx=None,
):
    """if return_last_state is True, returns (out, last_state)
    last_state has shape (batch, dim, dstate).
@@ -223,6 +227,10 @@ def selective_scan_opcheck_fn(
            has_initial_state,
            ssm_states,
            pad_slot_id,
+           block_size,
+           block_idx_first_scheduled_token,
+           block_idx_last_scheduled_token,
+           initial_state_idx,
        ),
        test_utils=["test_schema", "test_faketensor"],
    )
@@ -338,6 +346,11 @@ def test_selective_scan(
            has_initial_state=torch.ones(batch_size, device=u.device, dtype=torch.bool)
            if c > 0
            else None,
            pad_slot_id=PAD_SLOT_ID,
+           block_size=2048,
+           block_idx_first_scheduled_token=None,
+           block_idx_last_scheduled_token=None,
+           initial_state_idx=None,
        )
        outs.append(out)
        if len(outs) > 1:
@@ -372,6 +385,7 @@ def test_selective_scan(
        delta_bias=delta_bias,
        delta_softplus=delta_softplus,
        ssm_states=state,
+       block_size=2048,
    )
@@ -586,6 +600,7 @@ def test_selective_scan_varlen(
        padded_state_indices,
        has_initial_state,
        prev_state,
+       block_size=2048,
    )
@@ -19,6 +19,8 @@ pytestmark = pytest.mark.hybrid_model
# meaning that it will be used in all tests in this file
# The rest of the models will only be tested by test_models

+APC_MULTIPLY_BY = 300
+
SSM_MODELS = [
    "state-spaces/mamba-130m-hf",
    "tiiuae/falcon-mamba-tiny-dev",
@@ -380,7 +382,7 @@ def _get_vLLM_output(
    return outs, vllm_model


-@pytest.mark.parametrize("model", [HYBRID_MODELS[3]])
+@pytest.mark.parametrize("model", [HYBRID_MODELS[0], HYBRID_MODELS[3]])
@pytest.mark.parametrize("max_tokens", [64])
@pytest.mark.parametrize("n_repetitions", [2])
# If num_logprobs is set to -1, then the stringent version
@@ -410,10 +412,8 @@ def test_apc_single_prompt(
        check_logprobs_close if num_logprobs > 0 else check_outputs_equal  # type: ignore
    )

-   MULTIPLE = 300
-
    # Sample prompts.
-   generated_prompts = [MULTIPLE * example_prompts[0]]
+   generated_prompts = [APC_MULTIPLY_BY * example_prompts[0]]

    max_model_len = max(len(prompt) + max_tokens for prompt in generated_prompts)
    vllm_runner_kwargs = _get_vllm_runner_params(
@@ -446,7 +446,7 @@ def test_apc_single_prompt(
    )


-@pytest.mark.parametrize("model", [HYBRID_MODELS[3]])
+@pytest.mark.parametrize("model", [HYBRID_MODELS[0], HYBRID_MODELS[3]])
@pytest.mark.parametrize("max_tokens", [64])
@pytest.mark.parametrize("n_repetitions", [2])
# If num_logprobs is set to -1, then the stringent version
@@ -476,10 +476,8 @@ def test_apc_single_prompt_block_align_alignment(
        check_logprobs_close if num_logprobs > 0 else check_outputs_equal  # type: ignore
    )

-   MULTIPLE = 300
-
    # Sample prompts. This custom prompt is used, as it causes the most issues
-   generated_prompts = ["The president of the United States is " * MULTIPLE]
+   generated_prompts = ["The president of the United States is " * APC_MULTIPLY_BY]

    max_model_len = max(len(prompt) + max_tokens for prompt in generated_prompts)
    vllm_runner_kwargs = _get_vllm_runner_params(
@@ -528,7 +526,7 @@ def test_apc_single_prompt_block_align_alignment(
    )


-@pytest.mark.parametrize("model", [HYBRID_MODELS[3]])
+@pytest.mark.parametrize("model", [HYBRID_MODELS[0], HYBRID_MODELS[3]])
@pytest.mark.parametrize("max_tokens", [64])
@pytest.mark.parametrize("n_repetitions", [2])
# If num_logprobs is set to -1, then the stringent version
@@ -558,10 +556,8 @@ def test_apc_multiple_prompts_all_cached_outputs(
        check_logprobs_close if num_logprobs > 0 else check_outputs_equal  # type: ignore
    )

-   MULTIPLE = 300
-
    # Sample prompts.
-   generated_prompts = [MULTIPLE * prompt for prompt in example_prompts]
+   generated_prompts = [APC_MULTIPLY_BY * prompt for prompt in example_prompts]

    max_model_len = max(len(prompt) + max_tokens for prompt in generated_prompts)
    vllm_runner_kwargs = _get_vllm_runner_params(
@@ -595,7 +591,7 @@ def test_apc_multiple_prompts_all_cached_outputs(
    )


-@pytest.mark.parametrize("model", [HYBRID_MODELS[3]])
+@pytest.mark.parametrize("model", [HYBRID_MODELS[0], HYBRID_MODELS[3]])
@pytest.mark.parametrize("max_tokens", [64])
@pytest.mark.parametrize("n_repetitions", [2])
# If num_logprobs is set to -1, then the stringent version
@@ -625,12 +621,12 @@ def test_apc_multiple_prompts_block_align_alignment(
        check_logprobs_close if num_logprobs > 0 else check_outputs_equal  # type: ignore
    )

-   MULTIPLE = 300
-
    # Sample prompts. This custom prompt is used, as it causes the most issues
    prompt_text = "The president of the United States is "
    prompt_offsets = [0, 3, 7, 13, 17, 22, 25, 31]
-   generated_prompts = [prompt_text[offset:] * MULTIPLE for offset in prompt_offsets]
+   generated_prompts = [
+       prompt_text[offset:] * APC_MULTIPLY_BY for offset in prompt_offsets
+   ]

    max_model_len = max(len(prompt) + max_tokens for prompt in generated_prompts)
    vllm_runner_kwargs = _get_vllm_runner_params(
@@ -679,7 +675,7 @@ def test_apc_multiple_prompts_block_align_alignment(
    )


-@pytest.mark.parametrize("model", [HYBRID_MODELS[3]])
+@pytest.mark.parametrize("model", [HYBRID_MODELS[0], HYBRID_MODELS[3]])
@pytest.mark.parametrize("max_tokens", [64])
@pytest.mark.parametrize("n_repetitions", [2])
# If num_logprobs is set to -1, then the stringent version
@@ -709,10 +705,8 @@ def test_apc_multiple_prompts_partial_cached_outputs(
        check_logprobs_close if num_logprobs > 0 else check_outputs_equal  # type: ignore
    )

-   MULTIPLE = 300
-
    # Sample prompts.
-   generated_prompts = [MULTIPLE * prompt for prompt in example_prompts]
+   generated_prompts = [APC_MULTIPLY_BY * prompt for prompt in example_prompts]

    max_model_len = max(len(prompt) + max_tokens for prompt in generated_prompts)
    vllm_runner_kwargs = _get_vllm_runner_params(
@@ -1719,6 +1719,10 @@ def selective_scan_fwd(
    has_initial_state: torch.Tensor | None,
    ssm_states: torch.Tensor,
    pad_slot_id: int,
+   block_size: int = 1024,
+   block_idx_first_scheduled_token: torch.Tensor | None = None,
+   block_idx_last_scheduled_token: torch.Tensor | None = None,
+   initial_state_idx: torch.Tensor | None = None,
):
    torch.ops._C.selective_scan_fwd(
        u,
@@ -1735,6 +1739,10 @@ def selective_scan_fwd(
        has_initial_state,
        ssm_states,
        pad_slot_id,
+       block_size,
+       block_idx_first_scheduled_token,
+       block_idx_last_scheduled_token,
+       initial_state_idx,
    )
@@ -1483,6 +1483,12 @@ class ModelConfig:
        if chunk_size is None:
            # used by e.g. Mamba2, NemotronH, Zamba
            chunk_size = getattr(self.hf_text_config, "chunk_size", None)

+       # Since Mamba1 does not have a chunk notion,
+       # we use a default chunk size of 2048.
+       if chunk_size is None:
+           chunk_size = 2048
+
        return chunk_size

    def get_multimodal_config(self) -> MultiModalConfig:
@@ -241,18 +241,21 @@ class MambaMixer(MambaBase, CustomOp):
        forward_context: ForwardContext = get_forward_context()
        attn_metadata = forward_context.attn_metadata

+       assert self.cache_config is not None
+       mamba_block_size = self.cache_config.mamba_block_size
+       prefix_caching_enabled = self.cache_config.enable_prefix_caching
+
        if attn_metadata is not None:
            assert isinstance(attn_metadata, dict)
            attn_metadata = attn_metadata[self.prefix]
-           mamba1_metadata = attn_metadata
-           assert isinstance(mamba1_metadata, Mamba1AttentionMetadata)
-           query_start_loc = mamba1_metadata.query_start_loc
-           state_indices_tensor = mamba1_metadata.state_indices_tensor
+           assert isinstance(attn_metadata, Mamba1AttentionMetadata)
+           query_start_loc_p = attn_metadata.query_start_loc_p
+           state_indices_tensor = attn_metadata.state_indices_tensor
            self_kv_cache = self.kv_cache[forward_context.virtual_engine]
            conv_state = self_kv_cache[0].transpose(-1, -2)
            ssm_state = self_kv_cache[1]
-           has_initial_states = mamba1_metadata.has_initial_states
-           num_padded_decodes = mamba1_metadata.num_padded_decodes
+           has_initial_states_p = attn_metadata.has_initial_states_p
+           num_padded_decodes = attn_metadata.num_padded_decodes

        # 1. Gated MLP's linear projection
        projected_states = self.in_proj(hidden_states)[0].transpose(-2, -1)
@@ -279,12 +282,8 @@ class MambaMixer(MambaBase, CustomOp):
            hidden_states_BC,
            gate,
            state_indices_tensor,
-           query_start_loc,
-           has_initial_states,
            num_prefill_tokens,
            num_decode_tokens,
            num_prefills,
            num_decodes,
            num_padded_decodes,
        )
        hidden_states_BC_p = prefill_decode_split.hidden_states_BC_p
@@ -293,8 +292,34 @@ class MambaMixer(MambaBase, CustomOp):
        gate_d = prefill_decode_split.gate_d
        state_indices_tensor_p = prefill_decode_split.state_indices_tensor_p
        state_indices_tensor_d = prefill_decode_split.state_indices_tensor_d
-       query_start_loc_p = prefill_decode_split.query_start_loc_p
-       has_initial_states_p = prefill_decode_split.has_initial_states_p
+
+       if prefix_caching_enabled:
+           block_idx_last_computed_token_d, block_idx_last_computed_token_p = (
+               torch.split(
+                   attn_metadata.block_idx_last_computed_token,
+                   [num_decodes, num_prefills],
+                   dim=0,
+               )
+           )
+           block_idx_last_scheduled_token_d, block_idx_last_scheduled_token_p = (
+               torch.split(
+                   attn_metadata.block_idx_last_scheduled_token,
+                   [num_decodes, num_prefills],
+                   dim=0,
+               )
+           )
+
+           block_idx_first_scheduled_token_p = (
+               attn_metadata.block_idx_first_scheduled_token_p
+           )
+           num_computed_tokens_p = attn_metadata.num_computed_tokens_p
+       else:
+           block_idx_last_computed_token_d = None
+           block_idx_last_computed_token_p = None
+           block_idx_last_scheduled_token_d = None
+           block_idx_last_scheduled_token_p = None
+           block_idx_first_scheduled_token_p = None
+           num_computed_tokens_p = None

        ssm_outputs = []
@@ -309,6 +334,11 @@ class MambaMixer(MambaBase, CustomOp):
                has_initial_state=has_initial_states_p,
                cache_indices=state_indices_tensor_p,
                query_start_loc=query_start_loc_p,
+               block_idx_first_scheduled_token=block_idx_first_scheduled_token_p,
+               block_idx_last_scheduled_token=block_idx_last_scheduled_token_p,
+               initial_state_idx=block_idx_last_computed_token_p,
+               num_computed_tokens=num_computed_tokens_p,
+               block_size_to_align=mamba_block_size,
            )
            # 3. State Space Model sequence transformations.
            discrete_time_step_p, B_p, C_p = self._ssm_transform(
@@ -331,10 +361,24 @@ class MambaMixer(MambaBase, CustomOp):
                cache_indices=state_indices_tensor_p,
                has_initial_state=has_initial_states_p,
                query_start_loc=query_start_loc_p,
+               block_size=mamba_block_size,
+               block_idx_first_scheduled_token=block_idx_first_scheduled_token_p,
+               block_idx_last_scheduled_token=block_idx_last_scheduled_token_p,
+               initial_state_idx=block_idx_last_computed_token_p,
            )
            ssm_outputs.append(scan_out_p)

        if has_decode:
+           if prefix_caching_enabled:
+               state_indices_tensor_d_input = state_indices_tensor_d.gather(
+                   1, block_idx_last_computed_token_d.unsqueeze(1)
+               ).squeeze(1)
+               state_indices_tensor_d_output = state_indices_tensor_d.gather(
+                   1, block_idx_last_scheduled_token_d.unsqueeze(1)
+               ).squeeze(1)
+           else:
+               state_indices_tensor_d_input = state_indices_tensor_d
+               state_indices_tensor_d_output = state_indices_tensor_d
            # 2. Convolution sequence transformation
            conv_out_d = causal_conv1d_update(
                hidden_states_BC_d.transpose(0, 1),
@@ -343,6 +387,8 @@ class MambaMixer(MambaBase, CustomOp):
                self.conv1d.bias,
                self.activation,
                conv_state_indices=state_indices_tensor_d,
+               block_idx_last_scheduled_token=block_idx_last_scheduled_token_d,
+               initial_state_idx=block_idx_last_computed_token_d,
            ).transpose(0, 1)

            # 3. State Space Model sequence transformation.
@@ -364,7 +410,8 @@ class MambaMixer(MambaBase, CustomOp):
                gate_d.transpose(0, 1),
                time_proj_bias,
                dt_softplus=True,
-               state_batch_indices=state_indices_tensor_d,
+               state_batch_indices=state_indices_tensor_d_input,
+               dst_state_batch_indices=state_indices_tensor_d_output,
                out=scan_outputs_d,
            )
            scan_outputs_d = scan_outputs_d.transpose(0, 1)
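For the decode path with prefix caching on, each request reads its state from the block holding the last computed token and writes the updated state to the block of the last scheduled token; both are looked up per request in the 2D state_indices_tensor. A small illustrative sketch of that gather (values and shapes are invented; the tensor names follow the code above):

import torch

# Assumed shapes: state_indices_tensor_d is (num_decodes, max_blocks),
# block_idx_* are (num_decodes,) indices into each request's row.
state_indices_tensor_d = torch.tensor([[7, 12, 3], [9, 4, 0]], dtype=torch.int32)
block_idx_last_computed_token_d = torch.tensor([1, 0])
block_idx_last_scheduled_token_d = torch.tensor([1, 1])

read_slots = state_indices_tensor_d.gather(
    1, block_idx_last_computed_token_d.unsqueeze(1).to(torch.int64)
).squeeze(1)   # -> tensor([12, 9]): slots holding the current states
write_slots = state_indices_tensor_d.gather(
    1, block_idx_last_scheduled_token_d.unsqueeze(1).to(torch.int64)
).squeeze(1)   # -> tensor([12, 4]): slots that receive the updated states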
@@ -423,20 +470,14 @@ class PrefillDecodeSplit(NamedTuple):
    gate_d: torch.Tensor
    state_indices_tensor_p: torch.Tensor
    state_indices_tensor_d: torch.Tensor
-   query_start_loc_p: torch.Tensor | None
-   has_initial_states_p: torch.Tensor | None


def split_batch_to_prefill_and_decode(
    hidden_states_BC: torch.Tensor,
    gate: torch.Tensor,
    state_indices_tensor: torch.Tensor,
-   query_start_loc: torch.Tensor,
-   has_initial_states: torch.Tensor | None,
    num_prefill_tokens: int,
    num_decode_tokens: int,
    num_prefills: int,
    num_decodes: int,
    num_padded_decodes: int,
) -> PrefillDecodeSplit:
    num_actual_tokens = num_prefill_tokens + num_padded_decodes
@@ -457,16 +498,6 @@ def split_batch_to_prefill_and_decode(
        [num_padded_decodes, num_prefills],
        dim=0,
    )
-   query_start_loc_p = (
-       query_start_loc[-num_prefills - 1 :] - num_padded_decodes
-       if num_prefills > 0
-       else None
-   )
-   has_initial_states_p = (
-       has_initial_states[-num_prefills:]
-       if (has_initial_states is not None and num_prefills > 0)
-       else None
-   )

    return PrefillDecodeSplit(
        hidden_states_BC_p=hidden_states_BC_p,
@@ -475,8 +506,6 @@ def split_batch_to_prefill_and_decode(
        gate_d=gate_d,
        state_indices_tensor_p=state_indices_tensor_p,
        state_indices_tensor_d=state_indices_tensor_d,
-       query_start_loc_p=query_start_loc_p,
-       has_initial_states_p=has_initial_states_p,
    )
@@ -375,6 +375,10 @@ def selective_scan_fn(
    cache_indices=None,
    has_initial_state=None,
    pad_slot_id=PAD_SLOT_ID,
+   block_size=1024,
+   block_idx_first_scheduled_token=None,
+   block_idx_last_scheduled_token=None,
+   initial_state_idx=None,
) -> torch.Tensor:
    """
    u: (dim, total_length) for varlen or (batch, dim, seqlen)
@@ -397,7 +401,10 @@ def selective_scan_fn(
                   x.shape=(dim,17)
    cache_indices: (batch) int32
        A tensor with each cell is a correspondent
-       input and output ssm_state index
+       input and output ssm_state indices
+       - Without APC: (batch,) - single state index per batch item
+       - With APC: (batch, max_positions) - cache block indices for read/write
+       Each non-zero value indicates a cache block to load from and/or write to.
    has_initial_state: (batch) bool
        A tensor populated with ones and zeros,
        indicate if the ssm_state at the corresponding index should be
@@ -408,6 +415,17 @@ def selective_scan_fn(
        that will not be processed,
        for example: cache_indices = [pad_slot_id, 1 ,20 ,pad_slot_id]
        in this case, the kernel will not process entries at indices 0 and 3
+   block_size: int
+       The block size to align the cached states to
+   block_idx_first_scheduled_token: (batch,), dtype int32
+       The pointer into cache_indices, where the first
+       cache block to be filled is located.
+   block_idx_last_scheduled_token: (batch,), dtype int32
+       The pointer into cache_indices, where the last cache block
+       to be filled is located.
+   initial_state_idx: (batch,), dtype int32
+       The pointer into cache_indices, where the cache block
+       containing the initial state is located.
    returns
        output: (dim, total_length) for varlen or (batch, dim, seqlen)
        supports inplace replacement
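A hedged usage sketch of the extended interface for a varlen prefill with prefix caching; the keyword names follow the docstring above, but the tensor shapes, the leading positional arguments, and the variable names are assumptions for illustration only (the real call sites are the MambaMixer paths elsewhere in this diff):

# Illustrative only: all tensors here are placeholders, not a real run.
out = selective_scan_fn(
    u, delta, A, B, C, D,
    z=z,
    delta_bias=delta_bias,
    delta_softplus=True,
    query_start_loc=query_start_loc,        # (num_prefills + 1,)
    cache_indices=cache_indices,            # (batch, max_positions) in APC mode
    has_initial_state=has_initial_state,    # (batch,) bool
    ssm_states=ssm_states,                  # state pool indexed by cache block
    pad_slot_id=PAD_SLOT_ID,
    block_size=1024,                        # mamba block size used for chunking
    block_idx_first_scheduled_token=block_idx_first_scheduled_token,  # (batch,)
    block_idx_last_scheduled_token=block_idx_last_scheduled_token,    # (batch,)
    initial_state_idx=initial_state_idx,    # (batch,)
)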
@@ -448,6 +466,10 @@ def selective_scan_fn(
        has_initial_state,
        ssm_states,
        pad_slot_id,
+       block_size,
+       block_idx_first_scheduled_token,
+       block_idx_last_scheduled_token,
+       initial_state_idx,
    )

    if z is None:
@@ -299,7 +299,7 @@ class MambaModelConfig(VerifyAndUpdateConfig):
        if model_config.supports_mamba_prefix_caching:
            logger.info(
                "Warning: Prefix caching is currently enabled. "
-               "Its support for Mamba2 layers is experimental. "
+               "Its support for Mamba layers is experimental. "
                "Please report any issues you may observe."
            )
        else:
@@ -38,7 +38,13 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.llama import LlamaMLP as JambaMLP
from vllm.sequence import IntermediateTensors

-from .interfaces import HasInnerState, IsHybrid, SupportsLoRA, SupportsPP
+from .interfaces import (
+    HasInnerState,
+    IsHybrid,
+    SupportsLoRA,
+    SupportsMambaPrefixCaching,
+    SupportsPP,
+)
from .utils import (
    AutoWeightsLoader,
    WeightsMapper,
@@ -454,7 +460,14 @@ class JambaModel(nn.Module):
        return loaded_params


-class JambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP, IsHybrid):
+class JambaForCausalLM(
+    nn.Module,
+    HasInnerState,
+    SupportsLoRA,
+    SupportsPP,
+    IsHybrid,
+    SupportsMambaPrefixCaching,
+):
    hf_to_vllm_mapper = WeightsMapper(
        orig_to_new_substr={".self_attn.": ".", ".A_log": ".A"},
    )
@@ -477,12 +490,8 @@ class JambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP, IsHyb

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        lora_config = vllm_config.lora_config
        scheduler_config = vllm_config.scheduler_config
-       assert not cache_config.enable_prefix_caching, (
-           "Jamba currently does not support prefix caching"
-       )

        super().__init__()
        self.config = config
@@ -29,6 +29,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.interfaces import (
    HasInnerState,
    IsAttentionFree,
+   SupportsMambaPrefixCaching,
    SupportsPP,
)
from vllm.sequence import IntermediateTensors
@@ -193,15 +194,13 @@ class MambaModel(nn.Module):
        return loaded_params


-class MambaForCausalLM(nn.Module, HasInnerState, IsAttentionFree, SupportsPP):
+class MambaForCausalLM(
+    nn.Module, HasInnerState, IsAttentionFree, SupportsPP, SupportsMambaPrefixCaching
+):
    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        lora_config = vllm_config.lora_config
        self.scheduler_config = vllm_config.scheduler_config
-       assert not cache_config.enable_prefix_caching, (
-           "Mamba does not support prefix caching"
-       )

        super().__init__()
        self.config = config
@@ -7,11 +7,13 @@ import torch

from vllm.attention.backends.abstract import AttentionBackend
from vllm.attention.backends.utils import PAD_SLOT_ID
from vllm.config import VllmConfig
from vllm.v1.attention.backends.mamba_attn import BaseMambaAttentionMetadataBuilder
from vllm.v1.attention.backends.utils import (
    CommonAttentionMetadata,
    split_decodes_and_prefills,
)
from vllm.v1.kv_cache_interface import AttentionSpec, MambaSpec


class Mamba1AttentionBackend(AttentionBackend):
@@ -22,32 +24,41 @@ class Mamba1AttentionBackend(AttentionBackend):

@dataclass
class Mamba1AttentionMetadata:
    query_start_loc: torch.Tensor
-   context_lens_tensor: torch.Tensor
+   query_start_loc_p: torch.Tensor
    state_indices_tensor: torch.Tensor
-   has_initial_states: torch.Tensor | None
+   has_initial_states_p: torch.Tensor | None
    num_prefills: int
    num_prefill_tokens: int
    num_decodes: int
    num_decode_tokens: int
    num_padded_decodes: int
+
+   block_idx_last_scheduled_token: torch.Tensor  # shape: [batch,]
+   block_idx_first_scheduled_token_p: torch.Tensor  # shape: [batch,]
+   block_idx_last_computed_token: torch.Tensor  # shape: [batch,]
+   num_computed_tokens_p: torch.Tensor  # shape: [batch,]


class Mamba1AttentionMetadataBuilder(
    BaseMambaAttentionMetadataBuilder[Mamba1AttentionMetadata]
):
+   def __init__(
+       self,
+       kv_cache_spec: AttentionSpec,
+       layer_names: list[str],
+       vllm_config: VllmConfig,
+       device: torch.device,
+   ):
+       super().__init__(kv_cache_spec, layer_names, vllm_config, device)
+       assert isinstance(kv_cache_spec, MambaSpec)
+
    def build(
        self,
        common_prefix_len: int,
        common_attn_metadata: CommonAttentionMetadata,
        fast_build: bool = False,
    ) -> Mamba1AttentionMetadata:
        query_start_loc = common_attn_metadata.query_start_loc
-
-       state_indices_tensor = common_attn_metadata.block_table_tensor[:, 0]
-       context_lens_tensor = common_attn_metadata.num_computed_tokens_cpu.to(
-           query_start_loc.device
-       )
        num_reqs = common_attn_metadata.num_reqs

        num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
            split_decodes_and_prefills(
@@ -55,32 +66,100 @@ class Mamba1AttentionMetadataBuilder(
            )
        )

-       has_initial_states = None
+       has_initial_states_p = None
+       query_start_loc_p = None
        padded_decodes = num_decodes
+       num_computed_tokens, num_computed_tokens_p = None, None
+       block_idx_first_scheduled_token = None
+       block_idx_first_scheduled_token_p = None
+
+       # TODO(@Josephasafg) Mamba1 and Mamba2 have a lot of code in common here.
+       # We should consolidate this code
+       if self.vllm_config.cache_config.enable_prefix_caching:
+           # Return a tensor of shape (#requests, #max blocks)
+           state_indices_tensor = common_attn_metadata.block_table_tensor
+           mamba_block_size = self.kv_cache_spec.block_size
+           num_computed_tokens = common_attn_metadata.num_computed_tokens_cpu.to(
+               self.device
+           )
+           (
+               block_idx_last_computed_token,
+               block_idx_first_scheduled_token,
+               block_idx_last_scheduled_token,
+           ) = self._compute_prefix_caching_block_indices(
+               common_attn_metadata, mamba_block_size
+           )
+       else:
+           # Always return just a single block per each request:
+           state_indices_tensor = common_attn_metadata.block_table_tensor[:, 0]
+           block_idx_last_scheduled_token = None
+           block_idx_last_computed_token = None

        if num_prefills > 0:
-           has_initial_states = context_lens_tensor > 0
+           query_start_loc_p = (
+               common_attn_metadata.query_start_loc[-num_prefills - 1 :]
+               - num_decode_tokens
+           )
+           has_initial_states_cpu = (
+               common_attn_metadata.num_computed_tokens_cpu[
+                   num_reqs - num_prefills : num_reqs
+               ]
+               > 0
+           )
+           has_initial_states_p = has_initial_states_cpu.to(
+               common_attn_metadata.query_start_loc.device
+           )
+
+           if self.vllm_config.cache_config.enable_prefix_caching:
+               assert num_computed_tokens is not None
+               num_computed_tokens_p = num_computed_tokens[
+                   num_reqs - num_prefills : num_reqs
+               ]
+               assert block_idx_first_scheduled_token is not None
+               block_idx_first_scheduled_token_p = block_idx_first_scheduled_token[
+                   num_reqs - num_prefills : num_reqs
+               ]

        elif (
            num_decodes > 0
            and num_decodes <= self.decode_cudagraph_max_bs
            and self.compilation_config.full_cuda_graph
        ):
-           state_indices_for_decode = state_indices_tensor[:num_decodes]
            padded_decodes = self.vllm_config.pad_for_cudagraph(num_decodes)
            self.state_indices_tensor[:num_decodes].copy_(
-               state_indices_for_decode, non_blocking=True
+               state_indices_tensor, non_blocking=True
            )
            state_indices_tensor = self.state_indices_tensor[:padded_decodes]
            state_indices_tensor[num_decodes:] = PAD_SLOT_ID

+           if self.vllm_config.cache_config.enable_prefix_caching:
+               self.block_idx_last_scheduled_token[:num_decodes].copy_(
+                   block_idx_last_scheduled_token, non_blocking=True
+               )
+               block_idx_last_scheduled_token = self.block_idx_last_scheduled_token[
+                   :padded_decodes
+               ]
+               block_idx_last_scheduled_token[num_decodes:] = 0
+
+               self.block_idx_last_computed_token[:num_decodes].copy_(
+                   block_idx_last_computed_token, non_blocking=True
+               )
+               block_idx_last_computed_token = self.block_idx_last_computed_token[
+                   :padded_decodes
+               ]
+               block_idx_last_computed_token[num_decodes:] = 0
+
        return Mamba1AttentionMetadata(
            query_start_loc=query_start_loc,
-           context_lens_tensor=context_lens_tensor,
-           has_initial_states=has_initial_states,
+           query_start_loc_p=query_start_loc_p,
+           has_initial_states_p=has_initial_states_p,
            state_indices_tensor=state_indices_tensor,
            num_prefills=num_prefills,
            num_prefill_tokens=num_prefill_tokens,
            num_decodes=num_decodes,
            num_decode_tokens=num_decode_tokens,
            num_padded_decodes=padded_decodes,
+           block_idx_last_scheduled_token=block_idx_last_scheduled_token,
+           block_idx_first_scheduled_token_p=block_idx_first_scheduled_token_p,
+           block_idx_last_computed_token=block_idx_last_computed_token,
+           num_computed_tokens_p=num_computed_tokens_p,
        )
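When the decode-only CUDA-graph path above is taken, the per-step metadata is staged into persistent buffers and padded out to the captured batch size; padded rows get PAD_SLOT_ID (and 0 for the block-index tensors) so the kernel ignores them. A rough sketch of that padding step, with invented sizes and a placeholder PAD_SLOT_ID (vLLM defines its own constant), buffer names mirroring the builder attributes:

import torch

PAD_SLOT_ID = -1          # placeholder for illustration
num_decodes, padded_decodes, max_blocks = 3, 4, 8

# Persistent buffers allocated once per builder (see the base-class hunk below).
state_indices_buf = torch.zeros((padded_decodes, max_blocks), dtype=torch.int32)
block_idx_last_scheduled_buf = torch.zeros((padded_decodes,), dtype=torch.int32)

# Per-step metadata for the real decode requests.
state_indices = torch.randint(0, 100, (num_decodes, max_blocks), dtype=torch.int32)
block_idx_last_scheduled = torch.tensor([3, 1, 2], dtype=torch.int32)

state_indices_buf[:num_decodes].copy_(state_indices)
state_indices_tensor = state_indices_buf[:padded_decodes]
state_indices_tensor[num_decodes:] = PAD_SLOT_ID     # padded rows are skipped

block_idx_last_scheduled_buf[:num_decodes].copy_(block_idx_last_scheduled)
block_idx_last_scheduled_token = block_idx_last_scheduled_buf[:padded_decodes]
block_idx_last_scheduled_token[num_decodes:] = 0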
@@ -147,27 +147,6 @@ class Mamba2AttentionMetadataBuilder(
        assert self.chunk_size is not None, (
            "chunk_size needs to be set in the model config for Mamba2 models"
        )
-       if self.vllm_config.cache_config.enable_prefix_caching:
-           self.state_indices_tensor = torch.empty(
-               (
-                   self.decode_cudagraph_max_bs,
-                   cdiv(
-                       vllm_config.model_config.max_model_len, kv_cache_spec.block_size
-                   ),
-               ),
-               dtype=torch.int32,
-               device=device,
-           )
-           self.block_idx_last_scheduled_token = torch.empty(
-               (self.decode_cudagraph_max_bs,),
-               dtype=torch.int32,
-               device=device,
-           )
-           self.block_idx_last_computed_token = torch.empty(
-               (self.decode_cudagraph_max_bs,),
-               dtype=torch.int32,
-               device=device,
-           )

    def build(
        self,
@@ -202,20 +181,13 @@ class Mamba2AttentionMetadataBuilder(
            num_computed_tokens = common_attn_metadata.num_computed_tokens_cpu.to(
                self.device
            )
-           # Block index of the last computed token
-           block_idx_last_computed_token = (
-               cdiv(num_computed_tokens, mamba_block_size) - 1
-           )
-           # which is <= block index for the first scheduled token
-           block_idx_first_scheduled_token = (
-               cdiv(num_computed_tokens + 1, mamba_block_size) - 1
-           )
-           # which is <= block index of the last scheduled token
-           block_idx_last_scheduled_token = (
-               cdiv(common_attn_metadata.seq_lens, mamba_block_size) - 1
-           )
-           # -1 in case it's non-computed and causes later issues with indexing
-           block_idx_last_computed_token = block_idx_last_computed_token.clamp(min=0)
+           (
+               block_idx_last_computed_token,
+               block_idx_first_scheduled_token,
+               block_idx_last_scheduled_token,
+           ) = self._compute_prefix_caching_block_indices(
+               common_attn_metadata, mamba_block_size
+           )
        else:
            # Always return just a single block per each request:
            state_indices_tensor = common_attn_metadata.block_table_tensor[:, 0]
@@ -7,6 +7,7 @@ from typing import ClassVar, TypeVar
import torch

from vllm.config import VllmConfig
+from vllm.utils.math_utils import cdiv
from vllm.v1.attention.backends.utils import (
    AttentionCGSupport,
    AttentionMetadataBuilder,
@@ -38,11 +39,35 @@ class BaseMambaAttentionMetadataBuilder(AttentionMetadataBuilder[M], abc.ABC):
            self.vllm_config.scheduler_config.max_num_seqs,
            self.compilation_config.max_cudagraph_capture_size,
        )
-       self.state_indices_tensor = torch.empty(
-           (self.decode_cudagraph_max_bs,),
-           dtype=torch.int32,
-           device=device,
-       )
+
+       if self.vllm_config.cache_config.enable_prefix_caching:
+           self.state_indices_tensor = torch.empty(
+               (
+                   self.decode_cudagraph_max_bs,
+                   cdiv(
+                       self.vllm_config.model_config.max_model_len,
+                       self.kv_cache_spec.block_size,
+                   ),
+               ),
+               dtype=torch.int32,
+               device=device,
+           )
+           self.block_idx_last_scheduled_token = torch.empty(
+               (self.decode_cudagraph_max_bs,),
+               dtype=torch.int32,
+               device=device,
+           )
+           self.block_idx_last_computed_token = torch.empty(
+               (self.decode_cudagraph_max_bs,),
+               dtype=torch.int32,
+               device=device,
+           )
+       else:
+           self.state_indices_tensor = torch.empty(
+               (self.decode_cudagraph_max_bs,),
+               dtype=torch.int32,
+               device=device,
+           )

    def build_for_cudagraph_capture(
        self, common_attn_metadata: CommonAttentionMetadata
@@ -61,3 +86,30 @@ class BaseMambaAttentionMetadataBuilder(AttentionMetadataBuilder[M], abc.ABC):
        m.max_query_len = 1  # decode-only

        return self.build(0, m)
+
+   def _compute_prefix_caching_block_indices(
+       self,
+       common_attn_metadata: CommonAttentionMetadata,
+       mamba_block_size: int,
+   ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+       num_computed_tokens = common_attn_metadata.num_computed_tokens_cpu.to(
+           self.device
+       )
+       # Block index of the last computed token
+       block_idx_last_computed_token = cdiv(num_computed_tokens, mamba_block_size) - 1
+       # which is <= block index for the first scheduled token
+       block_idx_first_scheduled_token = (
+           cdiv(num_computed_tokens + 1, mamba_block_size) - 1
+       )
+       # which is <= block index of the last scheduled token
+       block_idx_last_scheduled_token = (
+           cdiv(common_attn_metadata.seq_lens, mamba_block_size) - 1
+       )
+       # -1 in case it's non-computed and causes later issues with indexing
+       block_idx_last_computed_token = block_idx_last_computed_token.clamp(min=0)
+
+       return (
+           block_idx_last_computed_token,
+           block_idx_first_scheduled_token,
+           block_idx_last_scheduled_token,
+       )
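To make the block-index arithmetic above concrete, a small worked example in plain Python (values invented): with a mamba block size of 4, a request that already has 8 computed tokens and 13 tokens in total this step resolves to last-computed block 1, first-scheduled block 2, and last-scheduled block 3.

def cdiv(a: int, b: int) -> int:
    # Ceiling division, matching vllm.utils.math_utils.cdiv for positive ints.
    return -(-a // b)

mamba_block_size = 4
num_computed_tokens = 8      # tokens whose state is already cached
seq_len = 13                 # computed + newly scheduled tokens

block_idx_last_computed_token = max(cdiv(num_computed_tokens, mamba_block_size) - 1, 0)  # 1
block_idx_first_scheduled_token = cdiv(num_computed_tokens + 1, mamba_block_size) - 1    # 2
block_idx_last_scheduled_token = cdiv(seq_len, mamba_block_size) - 1                     # 3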