[V1] [Hybrid] Mamba1 Automatic Prefix Caching (#26377)

Signed-off-by: asafg <39553475+Josephasafg@users.noreply.github.com>
Author: Asaf Joseph Gardin
Date: 2025-11-02 14:16:23 +02:00
Committed by: GitHub
parent 73444b7b56
commit 00b31a36a2
16 changed files with 442 additions and 153 deletions

View File

@@ -24,6 +24,8 @@ struct SSMParamsBase {
int64_t pad_slot_id;
bool delta_softplus;
bool cache_enabled;
int block_size;
index_t A_d_stride;
index_t A_dstate_stride;
@@ -46,8 +48,9 @@ struct SSMParamsBase {
index_t out_z_batch_stride;
index_t out_z_d_stride;
index_t ssm_states_batch_stride;
index_t ssm_states_dim_stride;
index_t ssm_states_dstate_stride;
index_t cache_indices_stride;
// Common data pointers.
void *__restrict__ A_ptr;
@@ -66,6 +69,9 @@ struct SSMParamsBase {
void *__restrict__ cache_indices_ptr;
void *__restrict__ has_initial_state_ptr;
void *__restrict__ block_idx_first_scheduled_token_ptr; // (batch,) - first block to write
void *__restrict__ block_idx_last_scheduled_token_ptr; // (batch,) - last block to write
void *__restrict__ initial_state_idx_ptr; // (batch,) - index of the initial state to use
};

View File

@@ -119,7 +119,7 @@ void selective_scan_fwd_kernel(SSMParamsBase params) {
const int* cache_indices = params.cache_indices_ptr == nullptr ? nullptr
: reinterpret_cast<int *>(params.cache_indices_ptr);
const int cache_index = cache_indices == nullptr ? batch_id : cache_indices[batch_id];
// cache_index == params.pad_slot_id is defined as padding, so we exit early
if (cache_index == params.pad_slot_id){
return;
@@ -133,9 +133,18 @@ void selective_scan_fwd_kernel(SSMParamsBase params) {
input_t *Bvar = reinterpret_cast<input_t *>(params.B_ptr) + sequence_start_index * params.B_batch_stride + group_id * params.B_group_stride;
weight_t *C = reinterpret_cast<weight_t *>(params.C_ptr) + dim_id * kNRows * params.C_d_stride;
input_t *Cvar = reinterpret_cast<input_t *>(params.C_ptr) + sequence_start_index * params.C_batch_stride + group_id * params.C_group_stride;
typename Ktraits::state_t *ssm_states;
if (params.cache_enabled) {
// APC mode: ssm_states points to the cache base; absolute cache slots are computed later
ssm_states = reinterpret_cast<typename Ktraits::state_t *>(params.ssm_states_ptr) +
dim_id * kNRows * params.ssm_states_dim_stride;
} else {
// Non-APC mode: offset by cache_index as before
ssm_states = reinterpret_cast<typename Ktraits::state_t *>(params.ssm_states_ptr) +
cache_index * params.ssm_states_batch_stride +
dim_id * kNRows * params.ssm_states_dim_stride;
}
float D_val[kNRows] = {0};
if (params.D_ptr != nullptr) {
@@ -159,7 +168,22 @@ void selective_scan_fwd_kernel(SSMParamsBase params) {
// }
constexpr int kChunkSize = kNThreads * kNItems;
// Use block_size for chunking when APC is enabled, otherwise use 2048 for backwards compatibility
const int iteration_chunk_size = params.cache_enabled ? params.block_size : 2048;
const int n_chunks = (seqlen + iteration_chunk_size - 1) / iteration_chunk_size;
const int* batch_cache_indices = cache_indices != nullptr ?
cache_indices + batch_id * params.cache_indices_stride : nullptr;
const int* block_idx_first_scheduled = params.block_idx_first_scheduled_token_ptr != nullptr ?
reinterpret_cast<const int*>(params.block_idx_first_scheduled_token_ptr) : nullptr;
const int* block_idx_last_scheduled = params.block_idx_last_scheduled_token_ptr != nullptr ?
reinterpret_cast<const int*>(params.block_idx_last_scheduled_token_ptr) : nullptr;
const int* initial_state_idx = params.initial_state_idx_ptr != nullptr ?
reinterpret_cast<const int*>(params.initial_state_idx_ptr) : nullptr;
const size_t load_cache_slot = params.cache_enabled && batch_cache_indices != nullptr ? batch_cache_indices[initial_state_idx[batch_id]] : cache_index;
for (int chunk = 0; chunk < n_chunks; ++chunk) {
input_t u_vals[kNRows][kNItems], delta_vals_load[kNRows][kNItems];
@@ -219,7 +243,7 @@ void selective_scan_fwd_kernel(SSMParamsBase params) {
if constexpr (kIsVariableC) {
auto &smem_load_weight_C = !kIsVariableB ? smem_load_weight : smem_load_weight1;
load_weight<Ktraits>(Cvar + state_idx * params.C_dstate_stride, C_vals,
smem_load_weight_C, (seqlen - chunk * kChunkSize) * (1));
if constexpr (!kIsVariableB) {
#pragma unroll
for (int r = 0; r < kNRows; ++r) {
@@ -242,7 +266,6 @@ void selective_scan_fwd_kernel(SSMParamsBase params) {
for (int i = 0; i < kNItems; ++i) {
thread_data[i] = make_float2(exp2f(delta_vals[r][i] * A_val[r]),
!kIsVariableB ? delta_u_vals[r][i] : B_vals[i] * delta_u_vals[r][i]);
if (seqlen % (kNItems * kNThreads) != 0) { // So that the last state is correct
if (threadIdx.x * kNItems + i >= seqlen - chunk * kChunkSize) {
thread_data[i] = make_float2(1.f, 0.f);
@@ -250,8 +273,24 @@ void selective_scan_fwd_kernel(SSMParamsBase params) {
}
}
// Initialize running total
scan_t running_prefix;
if (chunk > 0) {
running_prefix = smem_running_prefix[state_idx + r * MAX_DSTATE];
} else {
// Load initial state
if (params.cache_enabled && has_initial_state && batch_cache_indices != nullptr) {
size_t state_offset = load_cache_slot * params.ssm_states_batch_stride +
r * params.ssm_states_dim_stride +
state_idx * params.ssm_states_dstate_stride;
running_prefix = make_float2(1.0, float(ssm_states[state_offset]));
} else if (has_initial_state) {
// Non-APC mode: load from current batch position
running_prefix = make_float2(1.0, float(ssm_states[state_idx * params.ssm_states_dstate_stride]));
} else {
// No initial state
running_prefix = make_float2(1.0, 0.0);
}
}
SSMScanPrefixCallbackOp<weight_t> prefix_op(running_prefix);
typename Ktraits::BlockScanT(smem_scan).InclusiveScan(
@@ -260,8 +299,25 @@ void selective_scan_fwd_kernel(SSMParamsBase params) {
// There's a syncthreads in the scan op, so we don't need to sync here.
// Unless there's only 1 warp, but then it's the same thread (0) reading and writing.
if (threadIdx.x == 0) {
smem_running_prefix[state_idx + r * MAX_DSTATE] = prefix_op.running_prefix;
// Store state at the end of each chunk when cache is enabled
if (params.cache_enabled && batch_cache_indices != nullptr) {
size_t cache_slot;
if (chunk == n_chunks - 1) {
cache_slot = batch_cache_indices[block_idx_last_scheduled[batch_id]];
} else {
cache_slot = batch_cache_indices[block_idx_first_scheduled[batch_id] + chunk];
}
size_t state_offset = cache_slot * params.ssm_states_batch_stride +
r * params.ssm_states_dim_stride +
state_idx * params.ssm_states_dstate_stride;
ssm_states[state_offset] = typename Ktraits::state_t(prefix_op.running_prefix.y);
} else if (!params.cache_enabled && chunk == n_chunks - 1) {
// Non-APC mode: store only final state at current batch position
ssm_states[state_idx * params.ssm_states_dstate_stride] = typename Ktraits::state_t(prefix_op.running_prefix.y);
}
}
@@ -274,7 +330,6 @@ void selective_scan_fwd_kernel(SSMParamsBase params) {
}
}
}
input_t *out = reinterpret_cast<input_t *>(params.out_ptr) + sequence_start_index * params.out_batch_stride
+ dim_id * kNRows * params.out_d_stride + chunk * kChunkSize;
__syncthreads();
@@ -346,7 +401,9 @@ template<typename input_t, typename weight_t, typename state_t>
void selective_scan_fwd_cuda(SSMParamsBase &params, cudaStream_t stream) {
#ifndef USE_ROCM
if (params.cache_enabled && params.block_size == 1024) {
selective_scan_fwd_launch<64, 16, input_t, weight_t, state_t>(params, stream);
} else if (params.seqlen <= 128) {
selective_scan_fwd_launch<32, 4, input_t, weight_t, state_t>(params, stream);
} else if (params.seqlen <= 256) {
selective_scan_fwd_launch<32, 8, input_t, weight_t, state_t>(params, stream);
@@ -358,7 +415,9 @@ void selective_scan_fwd_cuda(SSMParamsBase &params, cudaStream_t stream) {
selective_scan_fwd_launch<128, 16, input_t, weight_t, state_t>(params, stream);
}
#else
if (params.cache_enabled && params.block_size == 1024) {
selective_scan_fwd_launch<64, 16, input_t, weight_t, state_t>(params, stream);
} else if (params.seqlen <= 256) {
selective_scan_fwd_launch<64, 4, input_t, weight_t, state_t>(params, stream);
} else if (params.seqlen <= 512) {
selective_scan_fwd_launch<64, 8, input_t, weight_t, state_t>(params, stream);
@@ -437,13 +496,17 @@ void set_ssm_params_fwd(SSMParamsBase &params,
const std::optional<at::Tensor>& D,
const std::optional<at::Tensor>& delta_bias,
const torch::Tensor ssm_states,
bool has_z,
bool delta_softplus,
const std::optional<at::Tensor>& query_start_loc,
const std::optional<at::Tensor>& cache_indices,
const std::optional<at::Tensor>& has_initial_state,
bool varlen,
int64_t pad_slot_id,
int64_t block_size,
const std::optional<torch::Tensor> &block_idx_first_scheduled_token,
const std::optional<torch::Tensor> &block_idx_last_scheduled_token,
const std::optional<torch::Tensor> &initial_state_idx) {
// Reset the parameters
memset(&params, 0, sizeof(params));
@@ -477,6 +540,14 @@ void set_ssm_params_fwd(SSMParamsBase &params,
params.cache_indices_ptr = cache_indices.has_value() ? cache_indices.value().data_ptr() : nullptr;
params.has_initial_state_ptr = has_initial_state.has_value() ? has_initial_state.value().data_ptr() : nullptr;
// Set cache parameters - cache is enabled if we have direct cache writing params
params.cache_enabled = block_idx_first_scheduled_token.has_value();
params.block_size = static_cast<int>(block_size);
// Set direct cache writing pointers
params.block_idx_first_scheduled_token_ptr = block_idx_first_scheduled_token.has_value() ? block_idx_first_scheduled_token.value().data_ptr() : nullptr;
params.block_idx_last_scheduled_token_ptr = block_idx_last_scheduled_token.has_value() ? block_idx_last_scheduled_token.value().data_ptr() : nullptr;
params.initial_state_idx_ptr = initial_state_idx.has_value() ? initial_state_idx.value().data_ptr() : nullptr;
// All stride are in elements, not bytes.
params.A_d_stride = A.stride(0);
@@ -504,9 +575,11 @@ void set_ssm_params_fwd(SSMParamsBase &params,
params.out_d_stride = out.stride(0);
params.ssm_states_batch_stride = ssm_states.stride(0);
params.ssm_states_dim_stride = ssm_states.stride(1);
params.ssm_states_dstate_stride = ssm_states.stride(2);
params.cache_indices_stride = cache_indices.has_value() ? cache_indices.value().stride(0) : 0;
}
else{
if (!is_variable_B) {
@@ -537,8 +610,10 @@ void set_ssm_params_fwd(SSMParamsBase &params,
params.out_d_stride = out.stride(1);
params.ssm_states_batch_stride = ssm_states.stride(0);
params.ssm_states_dim_stride = ssm_states.stride(1);
params.ssm_states_dstate_stride = ssm_states.stride(2);
params.cache_indices_stride = cache_indices.has_value() ? cache_indices.value().stride(0) : 0;
}
}
@@ -554,7 +629,11 @@ void selective_scan_fwd(const torch::Tensor &u, const torch::Tensor &delta,
const torch::Tensor &ssm_states,
// used to identify padding entries if cache_indices provided
// in case of padding, the kernel will return early
int64_t pad_slot_id,
int64_t block_size,
const std::optional<torch::Tensor> &block_idx_first_scheduled_token,
const std::optional<torch::Tensor> &block_idx_last_scheduled_token,
const std::optional<torch::Tensor> &initial_state_idx) {
auto input_type = u.scalar_type();
auto weight_type = A.scalar_type();
TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16);
@@ -646,7 +725,16 @@ void selective_scan_fwd(const torch::Tensor &u, const torch::Tensor &delta,
auto cache_indices_ = cache_indices.value();
TORCH_CHECK(cache_indices_.scalar_type() == at::ScalarType::Int);
TORCH_CHECK(cache_indices_.is_cuda());
// cache_indices can be either 1D (batch_size,) for non-APC mode
// or 2D (batch_size, max_positions) for APC mode
const bool is_apc_mode = block_idx_first_scheduled_token.has_value();
if (is_apc_mode) {
TORCH_CHECK(cache_indices_.dim() == 2, "cache_indices must be 2D for APC mode");
TORCH_CHECK(cache_indices_.size(0) == batch_size, "cache_indices first dimension must match batch_size");
} else {
CHECK_SHAPE(cache_indices_, batch_size);
}
}
@@ -686,7 +774,11 @@ void selective_scan_fwd(const torch::Tensor &u, const torch::Tensor &delta,
cache_indices,
has_initial_state,
varlen,
pad_slot_id,
block_size,
block_idx_first_scheduled_token,
block_idx_last_scheduled_token,
initial_state_idx
);
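To make the new slot-selection scheme concrete, here is a rough Python sketch of which cache slots the kernel reads from and writes to per chunk when APC is enabled. It only mirrors the logic of load_cache_slot and the per-chunk store above; the function name and the sample numbers are made up for illustration.

def apc_state_slots(cache_indices_row, initial_state_idx,
                    block_idx_first_scheduled, block_idx_last_scheduled,
                    n_chunks, has_initial_state):
    # Initial state is read from the block holding the last computed token.
    load_slot = cache_indices_row[initial_state_idx] if has_initial_state else None
    # Each full chunk writes its running state to consecutive blocks starting at
    # block_idx_first_scheduled; the final (possibly partial) chunk always lands
    # in block_idx_last_scheduled.
    store_slots = []
    for chunk in range(n_chunks):
        if chunk == n_chunks - 1:
            store_slots.append(cache_indices_row[block_idx_last_scheduled])
        else:
            store_slots.append(cache_indices_row[block_idx_first_scheduled + chunk])
    return load_slot, store_slots

# Request whose row of cache_indices is [7, 3, 9], with one block already
# computed (initial state at position 0) and two new blocks scheduled.
print(apc_state_slots([7, 3, 9], initial_state_idx=0,
                      block_idx_first_scheduled=1, block_idx_last_scheduled=2,
                      n_chunks=2, has_initial_state=True))
# -> (7, [3, 9])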

View File

@@ -321,17 +321,19 @@ void dynamic_per_token_scaled_fp8_quant(
torch::Tensor& out, torch::Tensor const& input, torch::Tensor& scale,
std::optional<torch::Tensor> const& scale_ub);
void selective_scan_fwd(
const torch::Tensor& u, const torch::Tensor& delta, const torch::Tensor& A,
const torch::Tensor& B, const torch::Tensor& C,
const std::optional<torch::Tensor>& D_,
const std::optional<torch::Tensor>& z_,
const std::optional<torch::Tensor>& delta_bias_, bool delta_softplus,
const std::optional<torch::Tensor>& query_start_loc,
const std::optional<torch::Tensor>& cache_indices,
const std::optional<torch::Tensor>& has_initial_state,
const torch::Tensor& ssm_states, int64_t pad_slot_id, int64_t block_size,
const std::optional<torch::Tensor>& block_idx_first_scheduled_token,
const std::optional<torch::Tensor>& block_idx_last_scheduled_token,
const std::optional<torch::Tensor>& initial_state_idx);
torch::Tensor dynamic_4bit_int_moe_cpu(
torch::Tensor x, torch::Tensor topk_ids, torch::Tensor topk_weights,

View File

@@ -611,7 +611,11 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
"Tensor? cache_indices,"
"Tensor? has_initial_state,"
"Tensor! ssm_states,"
"int pad_slot_id) -> ()");
"int pad_slot_id,"
"int block_size,"
"Tensor? block_idx_first_scheduled_token,"
"Tensor? block_idx_last_scheduled_token,"
"Tensor? initial_state_idx) -> ()");
ops.impl("selective_scan_fwd", torch::kCUDA, &selective_scan_fwd);
// Hadamard transforms

View File

@@ -179,6 +179,10 @@ def selective_scan_opcheck_fn(
has_initial_state=None,
ssm_states=None,
pad_slot_id=PAD_SLOT_ID,
block_size=2048,
block_idx_first_scheduled_token=None,
block_idx_last_scheduled_token=None,
initial_state_idx=None,
):
"""if return_last_state is True, returns (out, last_state)
last_state has shape (batch, dim, dstate).
@@ -223,6 +227,10 @@ def selective_scan_opcheck_fn(
has_initial_state,
ssm_states,
pad_slot_id,
block_size,
block_idx_first_scheduled_token,
block_idx_last_scheduled_token,
initial_state_idx,
),
test_utils=["test_schema", "test_faketensor"],
)
@@ -338,6 +346,11 @@ def test_selective_scan(
has_initial_state=torch.ones(batch_size, device=u.device, dtype=torch.bool)
if c > 0
else None,
pad_slot_id=PAD_SLOT_ID,
block_size=2048,
block_idx_first_scheduled_token=None,
block_idx_last_scheduled_token=None,
initial_state_idx=None,
)
outs.append(out)
if len(outs) > 1:
@@ -372,6 +385,7 @@ def test_selective_scan(
delta_bias=delta_bias,
delta_softplus=delta_softplus,
ssm_states=state,
block_size=2048,
)
@@ -586,6 +600,7 @@ def test_selective_scan_varlen(
padded_state_indices,
has_initial_state,
prev_state,
block_size=2048,
)

View File

@@ -19,6 +19,8 @@ pytestmark = pytest.mark.hybrid_model
# meaning that it will be used in all tests in this file
# The rest of the models will only be tested by test_models
APC_MULTIPLY_BY = 300
SSM_MODELS = [
"state-spaces/mamba-130m-hf",
"tiiuae/falcon-mamba-tiny-dev",
@@ -380,7 +382,7 @@ def _get_vLLM_output(
return outs, vllm_model
@pytest.mark.parametrize("model", [HYBRID_MODELS[0], HYBRID_MODELS[3]])
@pytest.mark.parametrize("max_tokens", [64])
@pytest.mark.parametrize("n_repetitions", [2])
# If num_logprobs is set to -1, then the stringent version
@@ -410,10 +412,8 @@ def test_apc_single_prompt(
check_logprobs_close if num_logprobs > 0 else check_outputs_equal # type: ignore
)
# Sample prompts.
generated_prompts = [APC_MULTIPLY_BY * example_prompts[0]]
max_model_len = max(len(prompt) + max_tokens for prompt in generated_prompts)
vllm_runner_kwargs = _get_vllm_runner_params(
@@ -446,7 +446,7 @@ def test_apc_single_prompt(
)
@pytest.mark.parametrize("model", [HYBRID_MODELS[0], HYBRID_MODELS[3]])
@pytest.mark.parametrize("max_tokens", [64])
@pytest.mark.parametrize("n_repetitions", [2])
# If num_logprobs is set to -1, then the stringent version
@@ -476,10 +476,8 @@ def test_apc_single_prompt_block_align_alignment(
check_logprobs_close if num_logprobs > 0 else check_outputs_equal # type: ignore
)
# Sample prompts. This custom prompt is used, as it causes the most issues
generated_prompts = ["The president of the United States is " * MULTIPLE]
generated_prompts = ["The president of the United States is " * APC_MULTIPLY_BY]
max_model_len = max(len(prompt) + max_tokens for prompt in generated_prompts)
vllm_runner_kwargs = _get_vllm_runner_params(
@@ -528,7 +526,7 @@ def test_apc_single_prompt_block_align_alignment(
)
@pytest.mark.parametrize("model", [HYBRID_MODELS[0], HYBRID_MODELS[3]])
@pytest.mark.parametrize("max_tokens", [64])
@pytest.mark.parametrize("n_repetitions", [2])
# If num_logprobs is set to -1, then the stringent version
@@ -558,10 +556,8 @@ def test_apc_multiple_prompts_all_cached_outputs(
check_logprobs_close if num_logprobs > 0 else check_outputs_equal # type: ignore
)
# Sample prompts.
generated_prompts = [APC_MULTIPLY_BY * prompt for prompt in example_prompts]
max_model_len = max(len(prompt) + max_tokens for prompt in generated_prompts)
vllm_runner_kwargs = _get_vllm_runner_params(
@@ -595,7 +591,7 @@ def test_apc_multiple_prompts_all_cached_outputs(
)
@pytest.mark.parametrize("model", [HYBRID_MODELS[0], HYBRID_MODELS[3]])
@pytest.mark.parametrize("max_tokens", [64])
@pytest.mark.parametrize("n_repetitions", [2])
# If num_logprobs is set to -1, then the stringent version
@@ -625,12 +621,12 @@ def test_apc_multiple_prompts_block_align_alignment(
check_logprobs_close if num_logprobs > 0 else check_outputs_equal # type: ignore
)
# Sample prompts. This custom prompt is used, as it causes the most issues
prompt_text = "The president of the United States is "
prompt_offsets = [0, 3, 7, 13, 17, 22, 25, 31]
generated_prompts = [
prompt_text[offset:] * APC_MULTIPLY_BY for offset in prompt_offsets
]
max_model_len = max(len(prompt) + max_tokens for prompt in generated_prompts)
vllm_runner_kwargs = _get_vllm_runner_params(
@@ -679,7 +675,7 @@ def test_apc_multiple_prompts_block_align_alignment(
)
@pytest.mark.parametrize("model", [HYBRID_MODELS[0], HYBRID_MODELS[3]])
@pytest.mark.parametrize("max_tokens", [64])
@pytest.mark.parametrize("n_repetitions", [2])
# If num_logprobs is set to -1, then the stringent version
@@ -709,10 +705,8 @@ def test_apc_multiple_prompts_partial_cached_outputs(
check_logprobs_close if num_logprobs > 0 else check_outputs_equal # type: ignore
)
# Sample prompts.
generated_prompts = [APC_MULTIPLY_BY * prompt for prompt in example_prompts]
max_model_len = max(len(prompt) + max_tokens for prompt in generated_prompts)
vllm_runner_kwargs = _get_vllm_runner_params(

View File

@@ -1719,6 +1719,10 @@ def selective_scan_fwd(
has_initial_state: torch.Tensor | None,
ssm_states: torch.Tensor,
pad_slot_id: int,
block_size: int = 1024,
block_idx_first_scheduled_token: torch.Tensor | None = None,
block_idx_last_scheduled_token: torch.Tensor | None = None,
initial_state_idx: torch.Tensor | None = None,
):
torch.ops._C.selective_scan_fwd(
u,
@@ -1735,6 +1739,10 @@ def selective_scan_fwd(
has_initial_state,
ssm_states,
pad_slot_id,
block_size,
block_idx_first_scheduled_token,
block_idx_last_scheduled_token,
initial_state_idx,
)

View File

@@ -1483,6 +1483,12 @@ class ModelConfig:
if chunk_size is None:
# used by e.g. Mamba2, NemotronH, Zamba
chunk_size = getattr(self.hf_text_config, "chunk_size", None)
# Since Mamba1 does not have a notion of chunks,
# we fall back to a default chunk size of 2048.
if chunk_size is None:
chunk_size = 2048
return chunk_size
def get_multimodal_config(self) -> MultiModalConfig:

View File

@@ -241,18 +241,21 @@ class MambaMixer(MambaBase, CustomOp):
forward_context: ForwardContext = get_forward_context()
attn_metadata = forward_context.attn_metadata
assert self.cache_config is not None
mamba_block_size = self.cache_config.mamba_block_size
prefix_caching_enabled = self.cache_config.enable_prefix_caching
if attn_metadata is not None:
assert isinstance(attn_metadata, dict)
attn_metadata = attn_metadata[self.prefix]
assert isinstance(attn_metadata, Mamba1AttentionMetadata)
query_start_loc_p = attn_metadata.query_start_loc_p
state_indices_tensor = attn_metadata.state_indices_tensor
self_kv_cache = self.kv_cache[forward_context.virtual_engine]
conv_state = self_kv_cache[0].transpose(-1, -2)
ssm_state = self_kv_cache[1]
has_initial_states_p = attn_metadata.has_initial_states_p
num_padded_decodes = attn_metadata.num_padded_decodes
# 1. Gated MLP's linear projection
projected_states = self.in_proj(hidden_states)[0].transpose(-2, -1)
@@ -279,12 +282,8 @@ class MambaMixer(MambaBase, CustomOp):
hidden_states_BC,
gate,
state_indices_tensor,
num_prefill_tokens,
num_decode_tokens,
num_prefills,
num_decodes,
num_padded_decodes,
)
hidden_states_BC_p = prefill_decode_split.hidden_states_BC_p
@@ -293,8 +292,34 @@ class MambaMixer(MambaBase, CustomOp):
gate_d = prefill_decode_split.gate_d
state_indices_tensor_p = prefill_decode_split.state_indices_tensor_p
state_indices_tensor_d = prefill_decode_split.state_indices_tensor_d
if prefix_caching_enabled:
block_idx_last_computed_token_d, block_idx_last_computed_token_p = (
torch.split(
attn_metadata.block_idx_last_computed_token,
[num_decodes, num_prefills],
dim=0,
)
)
block_idx_last_scheduled_token_d, block_idx_last_scheduled_token_p = (
torch.split(
attn_metadata.block_idx_last_scheduled_token,
[num_decodes, num_prefills],
dim=0,
)
)
block_idx_first_scheduled_token_p = (
attn_metadata.block_idx_first_scheduled_token_p
)
num_computed_tokens_p = attn_metadata.num_computed_tokens_p
else:
block_idx_last_computed_token_d = None
block_idx_last_computed_token_p = None
block_idx_last_scheduled_token_d = None
block_idx_last_scheduled_token_p = None
block_idx_first_scheduled_token_p = None
num_computed_tokens_p = None
ssm_outputs = []
@@ -309,6 +334,11 @@ class MambaMixer(MambaBase, CustomOp):
has_initial_state=has_initial_states_p,
cache_indices=state_indices_tensor_p,
query_start_loc=query_start_loc_p,
block_idx_first_scheduled_token=block_idx_first_scheduled_token_p,
block_idx_last_scheduled_token=block_idx_last_scheduled_token_p,
initial_state_idx=block_idx_last_computed_token_p,
num_computed_tokens=num_computed_tokens_p,
block_size_to_align=mamba_block_size,
)
# 3. State Space Model sequence transformations.
discrete_time_step_p, B_p, C_p = self._ssm_transform(
@@ -331,10 +361,24 @@ class MambaMixer(MambaBase, CustomOp):
cache_indices=state_indices_tensor_p,
has_initial_state=has_initial_states_p,
query_start_loc=query_start_loc_p,
block_size=mamba_block_size,
block_idx_first_scheduled_token=block_idx_first_scheduled_token_p,
block_idx_last_scheduled_token=block_idx_last_scheduled_token_p,
initial_state_idx=block_idx_last_computed_token_p,
)
ssm_outputs.append(scan_out_p)
if has_decode:
if prefix_caching_enabled:
state_indices_tensor_d_input = state_indices_tensor_d.gather(
1, block_idx_last_computed_token_d.unsqueeze(1)
).squeeze(1)
state_indices_tensor_d_output = state_indices_tensor_d.gather(
1, block_idx_last_scheduled_token_d.unsqueeze(1)
).squeeze(1)
else:
state_indices_tensor_d_input = state_indices_tensor_d
state_indices_tensor_d_output = state_indices_tensor_d
# 2. Convolution sequence transformation
conv_out_d = causal_conv1d_update(
hidden_states_BC_d.transpose(0, 1),
@@ -343,6 +387,8 @@ class MambaMixer(MambaBase, CustomOp):
self.conv1d.bias,
self.activation,
conv_state_indices=state_indices_tensor_d,
block_idx_last_scheduled_token=block_idx_last_scheduled_token_d,
initial_state_idx=block_idx_last_computed_token_d,
).transpose(0, 1)
# 3. State Space Model sequence transformation.
@@ -364,7 +410,8 @@ class MambaMixer(MambaBase, CustomOp):
gate_d.transpose(0, 1),
time_proj_bias,
dt_softplus=True,
state_batch_indices=state_indices_tensor_d_input,
dst_state_batch_indices=state_indices_tensor_d_output,
out=scan_outputs_d,
)
scan_outputs_d = scan_outputs_d.transpose(0, 1)
@@ -423,20 +470,14 @@ class PrefillDecodeSplit(NamedTuple):
gate_d: torch.Tensor
state_indices_tensor_p: torch.Tensor
state_indices_tensor_d: torch.Tensor
def split_batch_to_prefill_and_decode(
hidden_states_BC: torch.Tensor,
gate: torch.Tensor,
state_indices_tensor: torch.Tensor,
num_prefill_tokens: int,
num_decode_tokens: int,
num_prefills: int,
num_decodes: int,
num_padded_decodes: int,
) -> PrefillDecodeSplit:
num_actual_tokens = num_prefill_tokens + num_padded_decodes
@@ -457,16 +498,6 @@ def split_batch_to_prefill_and_decode(
[num_padded_decodes, num_prefills],
dim=0,
)
return PrefillDecodeSplit(
hidden_states_BC_p=hidden_states_BC_p,
@@ -475,8 +506,6 @@ def split_batch_to_prefill_and_decode(
gate_d=gate_d,
state_indices_tensor_p=state_indices_tensor_p,
state_indices_tensor_d=state_indices_tensor_d,
)
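For the decode path in the mixer above, here is a minimal torch sketch (illustrative values, not the module code) of how the per-request input and output state slots are selected from the 2D state_indices_tensor_d when prefix caching is enabled:

import torch

# Three decode requests, each with a row of physical cache block ids.
state_indices_tensor_d = torch.tensor([[4, 8, 2],
                                       [5, 1, 9],
                                       [7, 3, 6]])
# Position (within each row) of the block holding the last computed token
# (state to read) and the last scheduled token (state to write).
block_idx_last_computed_token_d = torch.tensor([0, 1, 1])
block_idx_last_scheduled_token_d = torch.tensor([1, 1, 2])

# gather(1, idx.unsqueeze(1)).squeeze(1) picks one block id per row.
input_slots = state_indices_tensor_d.gather(
    1, block_idx_last_computed_token_d.unsqueeze(1)).squeeze(1)
output_slots = state_indices_tensor_d.gather(
    1, block_idx_last_scheduled_token_d.unsqueeze(1)).squeeze(1)

print(input_slots.tolist())   # [4, 1, 3]  - slots the decode state is read from
print(output_slots.tolist())  # [8, 1, 6]  - slots the updated state is written to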

View File

@@ -375,6 +375,10 @@ def selective_scan_fn(
cache_indices=None,
has_initial_state=None,
pad_slot_id=PAD_SLOT_ID,
block_size=1024,
block_idx_first_scheduled_token=None,
block_idx_last_scheduled_token=None,
initial_state_idx=None,
) -> torch.Tensor:
"""
u: (dim, total_length) for varlen or (batch, dim, seqlen)
@@ -397,7 +401,10 @@ def selective_scan_fn(
x.shape=(dim,17)
cache_indices: (batch) int32
A tensor whose cells hold the corresponding input and output ssm_state indices.
- Without APC: (batch,) - a single state index per batch item
- With APC: (batch, max_positions) - cache block indices for read and write.
Each non-zero value indicates a cache block to load from and/or write to.
has_initial_state: (batch) bool
A tensor populated with ones and zeros,
indicate if the ssm_state at the corresponding index should be
@@ -408,6 +415,17 @@ def selective_scan_fn(
that will not be processed,
for example: cache_indices = [pad_slot_id, 1 ,20 ,pad_slot_id]
in this case, the kernel will not process entries at indices 0 and 3
block_size: int
The block size to which the cached states are aligned.
block_idx_first_scheduled_token: (batch,), dtype int32
Index into cache_indices of the first cache block to be written.
block_idx_last_scheduled_token: (batch,), dtype int32
Index into cache_indices of the last cache block to be written.
initial_state_idx: (batch,), dtype int32
Index into cache_indices of the cache block containing the initial state.
returns
output: (dim, total_length) for varlen or (batch, dim, seqlen)
supports inplace replacement
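As a concrete illustration of the two cache_indices layouts described in this docstring (all values below are made up):

import torch

# Without APC: one ssm_state slot per request.
cache_indices_flat = torch.tensor([3, 7, 1], dtype=torch.int32)  # (batch,)

# With APC: each request gets a row of cache block ids; the block_idx_* /
# initial_state_idx tensors are positions within a row, not physical slot ids.
cache_indices_apc = torch.tensor([[3, 10, 11, 0],
                                  [7, 12,  0, 0],
                                  [1,  9, 13, 14]], dtype=torch.int32)  # (batch, max_positions)
initial_state_idx = torch.tensor([0, 0, 1], dtype=torch.int32)
block_idx_first_scheduled_token = torch.tensor([1, 1, 2], dtype=torch.int32)
block_idx_last_scheduled_token = torch.tensor([2, 1, 3], dtype=torch.int32)

# e.g. request 2 loads its initial state from slot cache_indices_apc[2, 1] == 9
# and writes its final state to slot cache_indices_apc[2, 3] == 14.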
@@ -448,6 +466,10 @@ def selective_scan_fn(
has_initial_state,
ssm_states,
pad_slot_id,
block_size,
block_idx_first_scheduled_token,
block_idx_last_scheduled_token,
initial_state_idx,
)
if z is None:

View File

@@ -299,7 +299,7 @@ class MambaModelConfig(VerifyAndUpdateConfig):
if model_config.supports_mamba_prefix_caching:
logger.info(
"Warning: Prefix caching is currently enabled. "
"Its support for Mamba2 layers is experimental. "
"Its support for Mamba layers is experimental. "
"Please report any issues you may observe."
)
else:

View File

@@ -38,7 +38,13 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.llama import LlamaMLP as JambaMLP
from vllm.sequence import IntermediateTensors
from .interfaces import (
HasInnerState,
IsHybrid,
SupportsLoRA,
SupportsMambaPrefixCaching,
SupportsPP,
)
from .utils import (
AutoWeightsLoader,
WeightsMapper,
@@ -454,7 +460,14 @@ class JambaModel(nn.Module):
return loaded_params
class JambaForCausalLM(
nn.Module,
HasInnerState,
SupportsLoRA,
SupportsPP,
IsHybrid,
SupportsMambaPrefixCaching,
):
hf_to_vllm_mapper = WeightsMapper(
orig_to_new_substr={".self_attn.": ".", ".A_log": ".A"},
)
@@ -477,12 +490,8 @@ class JambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP, IsHyb
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
lora_config = vllm_config.lora_config
scheduler_config = vllm_config.scheduler_config
super().__init__()
self.config = config

View File

@@ -29,6 +29,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.interfaces import (
HasInnerState,
IsAttentionFree,
SupportsMambaPrefixCaching,
SupportsPP,
)
from vllm.sequence import IntermediateTensors
@@ -193,15 +194,13 @@ class MambaModel(nn.Module):
return loaded_params
class MambaForCausalLM(
nn.Module, HasInnerState, IsAttentionFree, SupportsPP, SupportsMambaPrefixCaching
):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
lora_config = vllm_config.lora_config
self.scheduler_config = vllm_config.scheduler_config
super().__init__()
self.config = config

View File

@@ -7,11 +7,13 @@ import torch
from vllm.attention.backends.abstract import AttentionBackend
from vllm.attention.backends.utils import PAD_SLOT_ID
from vllm.config import VllmConfig
from vllm.v1.attention.backends.mamba_attn import BaseMambaAttentionMetadataBuilder
from vllm.v1.attention.backends.utils import (
CommonAttentionMetadata,
split_decodes_and_prefills,
)
from vllm.v1.kv_cache_interface import AttentionSpec, MambaSpec
class Mamba1AttentionBackend(AttentionBackend):
@@ -22,32 +24,41 @@ class Mamba1AttentionBackend(AttentionBackend):
@dataclass
class Mamba1AttentionMetadata:
query_start_loc: torch.Tensor
query_start_loc_p: torch.Tensor
state_indices_tensor: torch.Tensor
has_initial_states_p: torch.Tensor | None
num_prefills: int
num_prefill_tokens: int
num_decodes: int
num_decode_tokens: int
num_padded_decodes: int
block_idx_last_scheduled_token: torch.Tensor # shape: [batch,]
block_idx_first_scheduled_token_p: torch.Tensor # shape: [batch,]
block_idx_last_computed_token: torch.Tensor # shape: [batch,]
num_computed_tokens_p: torch.Tensor # shape: [batch,]
class Mamba1AttentionMetadataBuilder(
BaseMambaAttentionMetadataBuilder[Mamba1AttentionMetadata]
):
def __init__(
self,
kv_cache_spec: AttentionSpec,
layer_names: list[str],
vllm_config: VllmConfig,
device: torch.device,
):
super().__init__(kv_cache_spec, layer_names, vllm_config, device)
assert isinstance(kv_cache_spec, MambaSpec)
def build(
self,
common_prefix_len: int,
common_attn_metadata: CommonAttentionMetadata,
fast_build: bool = False,
) -> Mamba1AttentionMetadata:
query_start_loc = common_attn_metadata.query_start_loc
num_reqs = common_attn_metadata.num_reqs
num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
split_decodes_and_prefills(
@@ -55,32 +66,100 @@ class Mamba1AttentionMetadataBuilder(
)
)
has_initial_states_p = None
query_start_loc_p = None
padded_decodes = num_decodes
num_computed_tokens, num_computed_tokens_p = None, None
block_idx_first_scheduled_token = None
block_idx_first_scheduled_token_p = None
# TODO(@Josephasafg) Mamba1 and Mamba2 have a lot of code in common here.
# We should consolidate this code
if self.vllm_config.cache_config.enable_prefix_caching:
# Return a tensor of shape (#requests, #max blocks)
state_indices_tensor = common_attn_metadata.block_table_tensor
mamba_block_size = self.kv_cache_spec.block_size
num_computed_tokens = common_attn_metadata.num_computed_tokens_cpu.to(
self.device
)
(
block_idx_last_computed_token,
block_idx_first_scheduled_token,
block_idx_last_scheduled_token,
) = self._compute_prefix_caching_block_indices(
common_attn_metadata, mamba_block_size
)
else:
# Always return just a single block per each request:
state_indices_tensor = common_attn_metadata.block_table_tensor[:, 0]
block_idx_last_scheduled_token = None
block_idx_last_computed_token = None
if num_prefills > 0:
query_start_loc_p = (
common_attn_metadata.query_start_loc[-num_prefills - 1 :]
- num_decode_tokens
)
has_initial_states_cpu = (
common_attn_metadata.num_computed_tokens_cpu[
num_reqs - num_prefills : num_reqs
]
> 0
)
has_initial_states_p = has_initial_states_cpu.to(
common_attn_metadata.query_start_loc.device
)
if self.vllm_config.cache_config.enable_prefix_caching:
assert num_computed_tokens is not None
num_computed_tokens_p = num_computed_tokens[
num_reqs - num_prefills : num_reqs
]
assert block_idx_first_scheduled_token is not None
block_idx_first_scheduled_token_p = block_idx_first_scheduled_token[
num_reqs - num_prefills : num_reqs
]
elif (
num_decodes > 0
and num_decodes <= self.decode_cudagraph_max_bs
and self.compilation_config.full_cuda_graph
):
state_indices_for_decode = state_indices_tensor[:num_decodes]
padded_decodes = self.vllm_config.pad_for_cudagraph(num_decodes)
self.state_indices_tensor[:num_decodes].copy_(
state_indices_tensor, non_blocking=True
)
state_indices_tensor = self.state_indices_tensor[:padded_decodes]
state_indices_tensor[num_decodes:] = PAD_SLOT_ID
if self.vllm_config.cache_config.enable_prefix_caching:
self.block_idx_last_scheduled_token[:num_decodes].copy_(
block_idx_last_scheduled_token, non_blocking=True
)
block_idx_last_scheduled_token = self.block_idx_last_scheduled_token[
:padded_decodes
]
block_idx_last_scheduled_token[num_decodes:] = 0
self.block_idx_last_computed_token[:num_decodes].copy_(
block_idx_last_computed_token, non_blocking=True
)
block_idx_last_computed_token = self.block_idx_last_computed_token[
:padded_decodes
]
block_idx_last_computed_token[num_decodes:] = 0
return Mamba1AttentionMetadata(
query_start_loc=query_start_loc,
query_start_loc_p=query_start_loc_p,
has_initial_states_p=has_initial_states_p,
state_indices_tensor=state_indices_tensor,
num_prefills=num_prefills,
num_prefill_tokens=num_prefill_tokens,
num_decodes=num_decodes,
num_decode_tokens=num_decode_tokens,
num_padded_decodes=padded_decodes,
block_idx_last_scheduled_token=block_idx_last_scheduled_token,
block_idx_first_scheduled_token_p=block_idx_first_scheduled_token_p,
block_idx_last_computed_token=block_idx_last_computed_token,
num_computed_tokens_p=num_computed_tokens_p,
)
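A small worked example (hypothetical numbers) of how the prefill-only views built above are derived for a batch ordered [decode, decode, prefill, prefill]:

import torch

num_reqs, num_prefills = 4, 2
num_decode_tokens = 2  # one token per decode request
query_start_loc = torch.tensor([0, 1, 2, 6, 11])  # prefills of 4 and 5 tokens
num_computed_tokens_cpu = torch.tensor([16, 32, 8, 0])

# Cumulative lengths restricted to the prefill tail, rebased to start at 0.
query_start_loc_p = query_start_loc[-num_prefills - 1 :] - num_decode_tokens
print(query_start_loc_p.tolist())  # [0, 4, 9]

# A prefill request has an initial state iff it already has computed tokens.
has_initial_states_p = num_computed_tokens_cpu[num_reqs - num_prefills : num_reqs] > 0
print(has_initial_states_p.tolist())  # [True, False]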

View File

@@ -147,27 +147,6 @@ class Mamba2AttentionMetadataBuilder(
assert self.chunk_size is not None, (
"chunk_size needs to be set in the model config for Mamba2 models"
)
def build(
self,
@@ -202,20 +181,13 @@ class Mamba2AttentionMetadataBuilder(
num_computed_tokens = common_attn_metadata.num_computed_tokens_cpu.to(
self.device
)
(
block_idx_last_computed_token,
block_idx_first_scheduled_token,
block_idx_last_scheduled_token,
) = self._compute_prefix_caching_block_indices(
common_attn_metadata, mamba_block_size
)
else:
# Always return just a single block per each request:
state_indices_tensor = common_attn_metadata.block_table_tensor[:, 0]

View File

@@ -7,6 +7,7 @@ from typing import ClassVar, TypeVar
import torch
from vllm.config import VllmConfig
from vllm.utils.math_utils import cdiv
from vllm.v1.attention.backends.utils import (
AttentionCGSupport,
AttentionMetadataBuilder,
@@ -38,11 +39,35 @@ class BaseMambaAttentionMetadataBuilder(AttentionMetadataBuilder[M], abc.ABC):
self.vllm_config.scheduler_config.max_num_seqs,
self.compilation_config.max_cudagraph_capture_size,
)
if self.vllm_config.cache_config.enable_prefix_caching:
self.state_indices_tensor = torch.empty(
(
self.decode_cudagraph_max_bs,
cdiv(
self.vllm_config.model_config.max_model_len,
self.kv_cache_spec.block_size,
),
),
dtype=torch.int32,
device=device,
)
self.block_idx_last_scheduled_token = torch.empty(
(self.decode_cudagraph_max_bs,),
dtype=torch.int32,
device=device,
)
self.block_idx_last_computed_token = torch.empty(
(self.decode_cudagraph_max_bs,),
dtype=torch.int32,
device=device,
)
else:
self.state_indices_tensor = torch.empty(
(self.decode_cudagraph_max_bs,),
dtype=torch.int32,
device=device,
)
def build_for_cudagraph_capture(
self, common_attn_metadata: CommonAttentionMetadata
@@ -61,3 +86,30 @@ class BaseMambaAttentionMetadataBuilder(AttentionMetadataBuilder[M], abc.ABC):
m.max_query_len = 1 # decode-only
return self.build(0, m)
def _compute_prefix_caching_block_indices(
self,
common_attn_metadata: CommonAttentionMetadata,
mamba_block_size: int,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
num_computed_tokens = common_attn_metadata.num_computed_tokens_cpu.to(
self.device
)
# Block index of the last computed token
block_idx_last_computed_token = cdiv(num_computed_tokens, mamba_block_size) - 1
# which is <= block index for the first scheduled token
block_idx_first_scheduled_token = (
cdiv(num_computed_tokens + 1, mamba_block_size) - 1
)
# which is <= block index of the last scheduled token
block_idx_last_scheduled_token = (
cdiv(common_attn_metadata.seq_lens, mamba_block_size) - 1
)
# Clamp to 0: a request with no computed tokens would otherwise get index -1
block_idx_last_computed_token = block_idx_last_computed_token.clamp(min=0)
return (
block_idx_last_computed_token,
block_idx_first_scheduled_token,
block_idx_last_scheduled_token,
)
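A worked example of the block-index arithmetic in _compute_prefix_caching_block_indices, written as plain Python with made-up numbers (cdiv is ceiling division):

def cdiv(a, b):
    return -(-a // b)

mamba_block_size = 16

# (num_computed_tokens, seq_len) for two hypothetical requests:
# A has 40 tokens cached and 70 total after this step; B is brand new.
for num_computed, seq_len in [(40, 70), (0, 10)]:
    last_computed = max(cdiv(num_computed, mamba_block_size) - 1, 0)  # clamp(min=0)
    first_scheduled = cdiv(num_computed + 1, mamba_block_size) - 1
    last_scheduled = cdiv(seq_len, mamba_block_size) - 1
    print((num_computed, seq_len), "->", last_computed, first_scheduled, last_scheduled)

# (40, 70) -> 2 2 4   initial state read from block 2, blocks 2..4 written
# (0, 10)  -> 0 0 0   no real initial state, only block 0 written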