#ifndef CPU_ATTN_HPP
#define CPU_ATTN_HPP

#include <type_traits>
#include <cstddef>

#if defined(__APPLE__)
  #include <sys/sysctl.h>
#endif

#include "cpu_types.hpp"
#include "scratchpad_manager.h"
#include "cpu_attn_macros.h"
#include "utils.hpp"

namespace cpu_attention {
enum class ISA { AMX, VEC, VEC16, NEON };

template <ISA isa, typename scalar_t, int64_t head_dim>
class AttentionImpl {};

struct AttentionWorkItemGroup {
  int32_t req_id;
  int32_t q_token_id_start;
  int32_t q_token_num;
  int32_t kv_split_pos_start;
  int32_t kv_split_pos_end;

  int64_t total_kv_len;
  int32_t split_id;
  int32_t local_split_id;

  AttentionWorkItemGroup(const int32_t req_id, const int32_t q_token_id_start,
                         const int32_t kv_split_pos_start,
                         const int32_t kv_split_pos_end)
      : req_id(req_id),
        q_token_id_start(q_token_id_start),
        q_token_num(0),
        kv_split_pos_start(kv_split_pos_start),
        kv_split_pos_end(kv_split_pos_end),
        total_kv_len(0),
        split_id(-1),
        local_split_id(0) {}

  std::string to_string() const {
    std::stringstream ss;
    ss << '[' << "req_id: " << req_id << ",\n";
    ss << "q_token_id_start: " << q_token_id_start << ",\n";
    ss << "q_token_num: " << q_token_num << ",\n";
    ss << "kv_split_pos_start: " << kv_split_pos_start << ",\n";
    ss << "kv_split_pos_end: " << kv_split_pos_end << ",\n";
    ss << "total_kv_len: " << total_kv_len << ",\n";
    ss << "split_id: " << split_id << ",\n";
    ss << "local_split_id: " << local_split_id << ",\n";
    ss << ']';

    return ss.str();
  }
};

struct ReductionWorkItemGroup {
  int32_t req_id;
  int32_t q_token_id_start;
  int32_t q_token_id_num;
  int32_t split_start_id;
  int32_t split_num;

  ReductionWorkItemGroup(const int32_t req_id, const int32_t q_token_id_start,
                         const int32_t q_token_id_num,
                         const int32_t split_start_id)
      : req_id(req_id),
        q_token_id_start(q_token_id_start),
        q_token_id_num(q_token_id_num),
        split_start_id(split_start_id),
        split_num(0) {}

  std::string to_string() const {
    std::stringstream ss;
    ss << '[' << "req_id: " << req_id << ",\n";
    ss << "q_token_id_start: " << q_token_id_start << ",\n";
    ss << "q_token_id_num: " << q_token_id_num << ",\n";
    ss << "split_start_id: " << split_start_id << ",\n";
    ss << "split_num: " << split_num << ",\n";
    ss << ']';

    return ss.str();
  }
};

struct AttentionMetadata {
  std::atomic_int64_t counter;
  char _padding1[56];
  ISA isa;
  int32_t workitem_group_num;
  int32_t reduction_item_num;
  int32_t reduction_split_num;
  int32_t thread_num;
  int32_t effective_thread_num;  // non-zero item num in workitem_num_per_thread
  int32_t split_kv_q_token_num_threshold;
  int64_t attention_scratchpad_size_per_thread;
  int64_t reduction_scratchpad_size_per_kv_head;
  AttentionWorkItemGroup* workitem_groups_ptr;
  ReductionWorkItemGroup* reduction_items_ptr;
  int32_t cu_workitem_num_per_thread[1025] = {
      0};  // prefix sum of workitem_num_per_thread
  char _padding2[56];

  AttentionMetadata(ISA isa, int32_t workitem_group_num,
                    int32_t reduction_item_num, int32_t reduction_split_num,
                    int32_t split_kv_q_token_num_threshold)
      : isa(isa),
        workitem_group_num(workitem_group_num),
        reduction_item_num(reduction_item_num),
        reduction_split_num(reduction_split_num),
        thread_num(omp_get_max_threads()),
        effective_thread_num(thread_num),
        split_kv_q_token_num_threshold(split_kv_q_token_num_threshold),
        attention_scratchpad_size_per_thread(0),
        reduction_scratchpad_size_per_kv_head(0),
        workitem_groups_ptr(
            (AttentionWorkItemGroup*)((char*)this + sizeof(AttentionMetadata))),
        reduction_items_ptr(
            (ReductionWorkItemGroup*)((char*)this + sizeof(AttentionMetadata) +
                                      workitem_group_num *
                                          sizeof(AttentionWorkItemGroup))),
        counter(0) {
    TORCH_CHECK_LE(thread_num, 1024);
    static_assert(sizeof(AttentionMetadata) % 64 == 0);
    TORCH_CHECK(reinterpret_cast<size_t>(this) % 64 == 0);
  }
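  // Layout note: schedule() packs everything into one contiguous tensor,
  // [AttentionMetadata | workitem_group_num x AttentionWorkItemGroup |
  //  reduction_item_num x ReductionWorkItemGroup], so the two pointers above
  // are derived from `this` rather than pointing at separate allocations.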

  void reset_counter() { counter.store(0); }

  int64_t acquire_counter() { return counter++; }
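  // acquire_counter() is the dynamic work distribution mechanism: in
  // AttentionMainLoop::operator() each thread repeatedly pulls the next
  // global task index from this shared atomic until all attention and
  // reduction tasks have been handed out.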

  void print() const {
    std::stringstream ss;
    ss << "ISA: ";
    switch (isa) {
      case ISA::AMX:
        ss << "AMX, ";
        break;
      case ISA::VEC:
        ss << "VEC, ";
        break;
      case ISA::VEC16:
        ss << "VEC16, ";
        break;
      case ISA::NEON:
        ss << "NEON, ";
        break;
    }
    ss << "workitem_group_num: " << workitem_group_num
       << ", reduction_item_num: " << reduction_item_num
       << ", reduction_split_num: " << reduction_split_num
       << ", thread_num: " << thread_num
       << ", effective_thread_num: " << effective_thread_num
       << ", attention_scratchpad_size_per_thread: "
       << attention_scratchpad_size_per_thread
       << ", reduction_scratchpad_size_per_kv_head: "
       << reduction_scratchpad_size_per_kv_head << ", workitem groups:\n";
    for (int32_t i = 0; i < workitem_group_num; ++i) {
      ss << (workitem_groups_ptr + i)->to_string() << ",\n";
    }

    ss << "cu_workitem_num_per_thread: [";
    for (int32_t i = 0; i < thread_num + 1; ++i) {
      ss << cu_workitem_num_per_thread[i] << ", ";
    }
    ss << "]\n";

    ss << "reduction items: \n";

    for (int32_t i = 0; i < reduction_item_num; ++i) {
      ss << (reduction_items_ptr + i)->to_string() << ",\n";
    }

    std::printf("%s", ss.str().c_str());
  }
};

// Thread attention scratchpad contains:
// - Q: q_tile_size * head_dim * q_buffer_elem_size, gather Q heads, especially
//   for GQA
// - Q@K^T: max_num_q_per_iter * k_tile_size * logits_buffer_elem_size, logits
// - Intermediate outputs: q_tile_size * head_dim * output_buffer_elem_size + 2
//   * q_tile_size * 4, partial output, max + sum (float)
// Reduction scratchpad contains:
// - flags: bool array to indicate whether the split is finished
// - outputs: split_num * q_tile_size * head_dim * output_buffer_elem_size
// - max, sum: 2 * split_num * q_tile_size * 4
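// For a rough sense of scale (illustrative numbers only, not a requirement):
// with head_dim = 128, q_tile_size = 32, max_num_q_per_iter = 32,
// k_tile_size = 512 and 4-byte buffer elements, the per-thread scratchpad is
// roughly 16 KiB (Q) + 64 KiB (Q@K^T logits) + 16 KiB (partial output) plus
// two small max/sum arrays, each region rounded up to a 64-byte boundary.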
class AttentionScratchPad {
 public:
  AttentionScratchPad(int64_t thread_id,
                      const AttentionMetadata& attention_metadata,
                      void* scratchpad_ptr)
      : thread_scratchpad_ptr(
            static_cast<int8_t*>(scratchpad_ptr) +
            thread_id *
                attention_metadata.attention_scratchpad_size_per_thread),
        reduction_scratchpad_ptr(
            static_cast<int8_t*>(scratchpad_ptr) +
            attention_metadata.thread_num *
                attention_metadata.attention_scratchpad_size_per_thread),
        reduction_scratchpad_size_per_kv_head(
            attention_metadata.reduction_scratchpad_size_per_kv_head) {}

  // for attention
  void update(const int64_t head_dim, const int64_t q_buffer_elem_size,
              const int64_t logits_buffer_elem_size,
              const int64_t output_buffer_elem_size,
              const int64_t max_num_q_per_iter, const int64_t q_head_tile_size,
              const int64_t kv_tile_size) {
    int64_t buffer_offset = 0;
    q_buffer_offset_ = buffer_offset;
    buffer_offset +=
        calcu_q_buffer_size(q_head_tile_size, head_dim, q_buffer_elem_size);
    logits_buffer_offset_ = buffer_offset;
    buffer_offset += calcu_logits_buffer_size(max_num_q_per_iter, kv_tile_size,
                                              logits_buffer_elem_size);
    output_buffer_offset_ = buffer_offset;
    buffer_offset += calcu_partial_output_buffer_size(
        q_head_tile_size, head_dim, output_buffer_elem_size);
    max_buffer_offset_ = buffer_offset;
    buffer_offset += calcu_partial_output_max_sum_buffer_size(q_head_tile_size);
    sum_buffer_offset_ = buffer_offset;
  }

  // for reduction
  void update(const int32_t kv_head_idx, const int32_t total_split_num,
              const int64_t head_dim, const int64_t q_head_tile_size,
              const int64_t output_buffer_elem_size) {
    int64_t buffer_offset = kv_head_idx * reduction_scratchpad_size_per_kv_head;
    reduce_flag_buffer_offset_ = buffer_offset;
    buffer_offset += calcu_reduce_flag_buffer_size(total_split_num);
    reduce_output_buffer_offset_ = buffer_offset;
    buffer_offset += calcu_reduce_output_buffer_size(
        total_split_num, q_head_tile_size, head_dim, output_buffer_elem_size);
    reduce_max_buffer_offset_ = buffer_offset;
    buffer_offset +=
        calcu_reduce_max_sum_buffer_size(total_split_num, q_head_tile_size);
    reduce_sum_buffer_offset_ = buffer_offset;
  }

  template <typename T>
  T* get_q_buffer() {
    return reinterpret_cast<T*>(thread_scratchpad_ptr + q_buffer_offset_);
  }

  float* get_logits_buffer() {
    return reinterpret_cast<float*>(thread_scratchpad_ptr +
                                    logits_buffer_offset_);
  }

  float* get_output_buffer() {
    return reinterpret_cast<float*>(thread_scratchpad_ptr +
                                    output_buffer_offset_);
  }

  float* get_max_buffer() {
    return reinterpret_cast<float*>(thread_scratchpad_ptr + max_buffer_offset_);
  }

  float* get_sum_buffer() {
    return reinterpret_cast<float*>(thread_scratchpad_ptr + sum_buffer_offset_);
  }

  volatile bool* get_reduce_flag_buffer() {
    return reinterpret_cast<volatile bool*>(reduction_scratchpad_ptr +
                                            reduce_flag_buffer_offset_);
  }

  float* get_reduce_output_buffer() {
    return reinterpret_cast<float*>(reduction_scratchpad_ptr +
                                    reduce_output_buffer_offset_);
  }

  float* get_reduce_max_buffer() {
    return reinterpret_cast<float*>(reduction_scratchpad_ptr +
                                    reduce_max_buffer_offset_);
  }

  float* get_reduce_sum_buffer() {
    return reinterpret_cast<float*>(reduction_scratchpad_ptr +
                                    reduce_sum_buffer_offset_);
  }

  int64_t get_thread_scratchpad_size() const {
    return 2 * sum_buffer_offset_ - max_buffer_offset_;
  }

  int64_t get_reduction_scratchpad_size() const {
    return 2 * reduce_sum_buffer_offset_ - reduce_max_buffer_offset_;
  }

 private:
  static int64_t round_to_64(const int64_t num) {
    return ((num + 63) >> 6) << 6;
  }
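  // Every scratchpad region below is padded to a 64-byte (cache line)
  // boundary; e.g. round_to_64(1) == 64 and round_to_64(128) == 128.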

  static int64_t calcu_q_buffer_size(const int64_t q_tile_size,
                                     const int64_t head_dim,
                                     const int64_t elem_size) {
    return round_to_64(q_tile_size * head_dim * elem_size);
  }

  static int64_t calcu_logits_buffer_size(const int64_t max_num_q_per_iter,
                                          const int64_t k_tile_size,
                                          const int64_t elem_size) {
    return round_to_64(elem_size * max_num_q_per_iter * k_tile_size);
  }

  static int64_t calcu_partial_output_buffer_size(const int64_t q_tile_size,
                                                  const int64_t head_dim,
                                                  const int64_t elem_size) {
    return round_to_64(q_tile_size * head_dim * elem_size);
  }

  static int64_t calcu_partial_output_max_sum_buffer_size(
      const int64_t q_tile_size) {
    return round_to_64(q_tile_size * sizeof(float));
  }

  static int64_t calcu_reduce_flag_buffer_size(const int64_t total_split_num) {
    return round_to_64(total_split_num * sizeof(bool));
  }

  static int64_t calcu_reduce_max_sum_buffer_size(
      const int64_t total_split_num, const int32_t q_head_tile_size) {
    return round_to_64(total_split_num * q_head_tile_size * sizeof(float));
  }

  static int64_t calcu_reduce_output_buffer_size(
      const int64_t total_split_num, const int64_t q_head_tile_size,
      const int64_t head_dim, const int64_t output_buffer_elem_size) {
    return round_to_64(total_split_num * q_head_tile_size * head_dim *
                       output_buffer_elem_size);
  }

 private:
  int8_t* thread_scratchpad_ptr;
  int8_t* reduction_scratchpad_ptr;
  int64_t reduction_scratchpad_size_per_kv_head;
  // attention buffers
  int64_t q_buffer_offset_;
  int64_t logits_buffer_offset_;
  int64_t output_buffer_offset_;
  int64_t max_buffer_offset_;
  int64_t sum_buffer_offset_;
  // reduction buffers
  int64_t reduce_flag_buffer_offset_;
  int64_t reduce_output_buffer_offset_;
  int64_t reduce_max_buffer_offset_;
  int64_t reduce_sum_buffer_offset_;
};

class AttentionScheduler {
 public:
  struct ScheduleInput {
    int32_t num_reqs;
    int32_t elem_size;
    int32_t q_buffer_elem_size;
    int32_t logits_buffer_elem_size;
    int32_t output_buffer_elem_size;
    int32_t num_heads_q;
    int32_t num_heads_kv;
    int32_t head_dim;
    int32_t* query_start_loc;
    int32_t* seq_lens;
    int32_t left_sliding_window_size;
    int32_t right_sliding_window_size;
    bool casual;
    cpu_attention::ISA isa;
    int32_t max_num_q_per_iter;  // max number of Q heads that can be held in registers
    int32_t kv_block_alignment;  // context length alignment requirement
    bool enable_kv_split;
  };

  static constexpr int32_t MaxQTileIterNum = 128;

  AttentionScheduler() : available_cache_size_(get_available_l2_size()) {}

  torch::Tensor schedule(const ScheduleInput& input) const {
    const bool casual = input.casual;
    const int32_t thread_num = omp_get_max_threads();
    const int64_t cache_size = get_available_l2_size();
    const int32_t max_num_q_per_iter = input.max_num_q_per_iter;
    const int32_t kv_len_alignment = input.kv_block_alignment;
    int32_t q_head_per_kv = input.num_heads_q / input.num_heads_kv;
    const bool use_gqa = (max_num_q_per_iter % q_head_per_kv == 0);
    if (!use_gqa) {
      q_head_per_kv = 1;  // fallback to MHA
    }
    const int32_t min_split_kv_len =
        ((max_num_q_per_iter * 4 + kv_len_alignment - 1) / kv_len_alignment) *
        kv_len_alignment;
    const int32_t max_num_q_token_per_iter = max_num_q_per_iter / q_head_per_kv;
    const int64_t default_tile_size = calcu_default_tile_size(
        cache_size, input.head_dim, input.elem_size, input.q_buffer_elem_size,
        input.logits_buffer_elem_size, input.output_buffer_elem_size,
        max_num_q_per_iter, max_num_q_per_iter);
    const int32_t default_tile_token_num = default_tile_size / q_head_per_kv;
    const int32_t split_kv_q_token_num_threshold =
        input.enable_kv_split ? 1 : 0;
    const int32_t left_sliding_window_size = input.left_sliding_window_size;
    const int32_t right_sliding_window_size = input.right_sliding_window_size;
    TORCH_CHECK_LE(split_kv_q_token_num_threshold * q_head_per_kv, 16);

    // get total kv len
    int64_t total_kv_len = 0;
    for (int32_t req_id = 0; req_id < input.num_reqs; ++req_id) {
      const int32_t seq_len = input.seq_lens[req_id];
      const int32_t q_token_num =
          input.query_start_loc[req_id + 1] - input.query_start_loc[req_id];
      const int32_t q_start_pos = (casual ? (seq_len - q_token_num) : 0);
      const int32_t kv_start_pos = 0;
      const int32_t kv_end_pos = seq_len;

      for (int32_t token_id = 0; token_id < q_token_num;
           token_id += max_num_q_token_per_iter) {
        const int32_t q_tile_token_num =
            std::min(max_num_q_token_per_iter, q_token_num - token_id);
        const int32_t q_tile_pos_left = q_start_pos + token_id;
        const int32_t q_tile_pos_right = q_tile_pos_left + q_tile_token_num;
        const auto [kv_tile_pos_left, kv_tile_pos_right] = calcu_kv_tile_pos(
            kv_start_pos, kv_end_pos, q_tile_pos_left, q_tile_pos_right,
            left_sliding_window_size, right_sliding_window_size);
        const auto [aligned_kv_tile_pos_left, aligned_kv_tile_pos_right] =
            align_kv_tile_pos(kv_tile_pos_left, kv_tile_pos_right,
                              kv_len_alignment);

        int32_t curr_kv_len =
            aligned_kv_tile_pos_right - aligned_kv_tile_pos_left;
        total_kv_len += curr_kv_len;
      }
    }
    const int64_t kv_len_per_thread =
        (((total_kv_len / thread_num) + kv_len_alignment - 1) /
         kv_len_alignment) *
        kv_len_alignment * (use_gqa ? input.num_heads_kv : input.num_heads_q);
    std::vector<AttentionWorkItemGroup> workitems;
    std::vector<ReductionWorkItemGroup> reduce_workitems;
    workitems.reserve(1024);
    reduce_workitems.reserve(1024);
    std::vector<int32_t> workitem_num_per_thread(thread_num, 0);

    // split tasks
    int32_t curr_thread_id = 0;
    int64_t remaining_kv_len = kv_len_per_thread;
    int32_t cum_split_num = 0;
    for (int32_t req_id = 0; req_id < input.num_reqs; ++req_id) {
      const int32_t seq_len = input.seq_lens[req_id];
      const int32_t q_token_num =
          input.query_start_loc[req_id + 1] - input.query_start_loc[req_id];
      const int32_t q_start_pos = (casual ? (seq_len - q_token_num) : 0);
      const int32_t kv_start_pos = 0;
      const int32_t kv_end_pos = seq_len;
      int32_t local_split_id = 0;

      AttentionWorkItemGroup curr_workitem(req_id, 0, 0, seq_len);
      for (int32_t token_id = 0; token_id < q_token_num;
           token_id += max_num_q_token_per_iter) {
        const int32_t q_tile_token_num =
            std::min(max_num_q_token_per_iter, q_token_num - token_id);
        const int32_t q_tile_pos_left = q_start_pos + token_id;
        const int32_t q_tile_pos_right = q_tile_pos_left + q_tile_token_num;
        const auto [kv_tile_pos_left, kv_tile_pos_right] = calcu_kv_tile_pos(
            kv_start_pos, kv_end_pos, q_tile_pos_left, q_tile_pos_right,
            left_sliding_window_size, right_sliding_window_size);
        const auto [aligned_kv_tile_pos_left, aligned_kv_tile_pos_right] =
            align_kv_tile_pos(kv_tile_pos_left, kv_tile_pos_right,
                              kv_len_alignment);
        int32_t curr_kv_len =
            aligned_kv_tile_pos_right - aligned_kv_tile_pos_left;
        int32_t kv_token_pos_start = aligned_kv_tile_pos_left;

        while (curr_kv_len > 0) {
          if (curr_kv_len <= (remaining_kv_len + min_split_kv_len) ||
              curr_thread_id == (thread_num - 1)) {
            curr_workitem.q_token_num += q_tile_token_num;
            curr_workitem.total_kv_len += curr_kv_len;
            remaining_kv_len -= curr_kv_len;
            curr_kv_len = 0;

            if (remaining_kv_len < 0) {
              // stop accepting more workitems
              remaining_kv_len -= min_split_kv_len;
            }

            if (curr_workitem.kv_split_pos_start != 0) {
              // got a partial kv split, need to create a single workitem
              curr_workitem.split_id = cum_split_num;
              curr_workitem.local_split_id = local_split_id;
              workitems.emplace_back(curr_workitem);
              ++workitem_num_per_thread[curr_thread_id];
              ++reduce_workitems.back().split_num;
              ++cum_split_num;

              curr_workitem = AttentionWorkItemGroup(
                  req_id, token_id + max_num_q_token_per_iter, 0, seq_len);
            }

            break;
          }

          if (remaining_kv_len < min_split_kv_len &&
              (curr_workitem.total_kv_len > 0 ||
               workitem_num_per_thread[curr_thread_id] > 0)) {
            // remaining_kv_len is too short and workitems have already been
            // allocated, so leave the rest to the next thread
            if (curr_workitem.total_kv_len > 0) {
              workitems.emplace_back(curr_workitem);
              ++workitem_num_per_thread[curr_thread_id];
              curr_workitem =
                  AttentionWorkItemGroup(req_id, token_id, 0, seq_len);
            }

            // switch to next thread
            ++curr_thread_id;
            remaining_kv_len = kv_len_per_thread;

            // retry this iteration
            continue;
          }

          // only split tail splits with q_tile_token_num <=
          // split_kv_q_token_num_threshold
          if (token_id + max_num_q_token_per_iter < q_token_num ||
              q_tile_token_num > split_kv_q_token_num_threshold) {
            // if a new q tile iteration is required and workitems already
            // exist, leave this workitem to the next thread
            if (curr_workitem.q_token_num % default_tile_token_num == 0 &&
                (curr_workitem.total_kv_len > 0 ||
                 workitem_num_per_thread[curr_thread_id] > 0)) {
              if (curr_workitem.total_kv_len > 0) {
                workitems.emplace_back(curr_workitem);
                ++workitem_num_per_thread[curr_thread_id];
              }
              curr_workitem =
                  AttentionWorkItemGroup(req_id, token_id, 0, seq_len);

              // switch to next thread
              ++curr_thread_id;
              remaining_kv_len = kv_len_per_thread;
            }

            curr_workitem.q_token_num += q_tile_token_num;
            curr_workitem.total_kv_len += curr_kv_len;
            remaining_kv_len -= curr_kv_len;
            curr_kv_len = 0;
            break;
          }

          // split kv
          if (curr_workitem.total_kv_len > 0) {
            // write back curr workitem
            workitems.emplace_back(curr_workitem);
            ++workitem_num_per_thread[curr_thread_id];
          }

          if (kv_token_pos_start == aligned_kv_tile_pos_left) {
            // first split, init the workitem
            reduce_workitems.emplace_back(ReductionWorkItemGroup(
                req_id, token_id, q_tile_token_num, cum_split_num));
          }

          int32_t spilt_size =
              std::min(std::max(remaining_kv_len, (int64_t)min_split_kv_len),
                       (int64_t)curr_kv_len);
          curr_workitem =
              AttentionWorkItemGroup(req_id, token_id, kv_token_pos_start,
                                     kv_token_pos_start + spilt_size);
          curr_workitem.q_token_num += q_tile_token_num;
          curr_workitem.total_kv_len += spilt_size;
          curr_workitem.split_id = cum_split_num;
          curr_workitem.local_split_id = local_split_id;
          workitems.emplace_back(curr_workitem);
          ++workitem_num_per_thread[curr_thread_id];
          ++reduce_workitems.back().split_num;
          ++cum_split_num;
          ++local_split_id;

          kv_token_pos_start += spilt_size;
          curr_kv_len -= spilt_size;
          curr_workitem = AttentionWorkItemGroup(req_id, token_id,
                                                 kv_token_pos_start, seq_len);

          // switch to next thread
          ++curr_thread_id;
          remaining_kv_len = kv_len_per_thread;
        }
      }

      if (curr_workitem.total_kv_len > 0) {
        // write back curr workitem
        workitems.emplace_back(curr_workitem);
        ++workitem_num_per_thread[curr_thread_id];
      }
    }

    int64_t metadata_tensor_size =
        sizeof(AttentionMetadata) +
        workitems.size() * sizeof(AttentionWorkItemGroup) +
        reduce_workitems.size() * sizeof(ReductionWorkItemGroup);
    auto options =
        torch::TensorOptions().dtype(torch::kInt8).device(torch::kCPU);
    torch::Tensor metadata_tensor =
        torch::empty({metadata_tensor_size}, options);
    AttentionMetadata* metadata_ptr = new (metadata_tensor.data_ptr())
        AttentionMetadata(input.isa, workitems.size(), reduce_workitems.size(),
                          cum_split_num, split_kv_q_token_num_threshold);
    AttentionWorkItemGroup* workitem_groups_ptr =
        metadata_ptr->workitem_groups_ptr;
    ReductionWorkItemGroup* reduction_items_ptr =
        metadata_ptr->reduction_items_ptr;
    std::memcpy(workitem_groups_ptr, workitems.data(),
                workitems.size() * sizeof(AttentionWorkItemGroup));
    std::memcpy(reduction_items_ptr, reduce_workitems.data(),
                reduce_workitems.size() * sizeof(ReductionWorkItemGroup));

    int32_t effective_thread_num = 0;
    for (; effective_thread_num < thread_num; ++effective_thread_num) {
      if (workitem_num_per_thread[effective_thread_num] == 0) {
        break;
      }
    }

    std::memcpy(metadata_ptr->cu_workitem_num_per_thread + 1,
                workitem_num_per_thread.data(),
                workitem_num_per_thread.size() * sizeof(int32_t));
    for (int32_t i = 1; i <= thread_num; ++i) {
      metadata_ptr->cu_workitem_num_per_thread[i] +=
          metadata_ptr->cu_workitem_num_per_thread[i - 1];
    }
    metadata_ptr->effective_thread_num = effective_thread_num;

    {
      // when q_tile_size = max_num_q_per_iter, the max
      // attention_scratchpad_size is required
      AttentionScratchPad sc(0, *metadata_ptr, 0x0);
      int64_t n = AttentionScheduler::calcu_tile_size_with_constant_q(
          cache_size, input.head_dim, input.elem_size, input.q_buffer_elem_size,
          input.logits_buffer_elem_size, input.output_buffer_elem_size,
          max_num_q_per_iter, kv_len_alignment, max_num_q_per_iter, true);
      sc.update(input.head_dim, input.q_buffer_elem_size,
                input.logits_buffer_elem_size, input.output_buffer_elem_size,
                max_num_q_per_iter, max_num_q_per_iter, n);
      metadata_ptr->attention_scratchpad_size_per_thread =
          ((sc.get_thread_scratchpad_size() + 63) / 64) * 64;

      sc.update(0, metadata_ptr->reduction_split_num, input.head_dim,
                q_head_per_kv * split_kv_q_token_num_threshold,
                input.output_buffer_elem_size);
      metadata_ptr->reduction_scratchpad_size_per_kv_head =
          ((sc.get_reduction_scratchpad_size() + 63) / 64) * 64;
    }
    int64_t scratchpad_size =
        metadata_ptr->attention_scratchpad_size_per_thread *
            metadata_ptr->thread_num +
        metadata_ptr->reduction_scratchpad_size_per_kv_head *
            (use_gqa ? input.num_heads_kv : input.num_heads_q);
    DNNLScratchPadManager::get_dnnl_scratchpad_manager()->realloc(
        scratchpad_size);

    // metadata_ptr->print();

    // test out-of-bounds access
    // {
    //   float* cache_ptr =
    //       DNNLScratchPadManager::get_dnnl_scratchpad_manager()->get_data<float>();
    //   for (int64_t i = 0; i < scratchpad_size / sizeof(float); ++i) {
    //     cache_ptr[i] = std::numeric_limits<float>::quiet_NaN();
    //   }
    // }

    return metadata_tensor;
  }

  FORCE_INLINE static std::pair<int32_t, int32_t> calcu_kv_tile_pos(
      int32_t kv_left_pos, int32_t kv_right_pos, int32_t q_left_pos,
      int32_t q_right_pos, int32_t sliding_window_left,
      int32_t sliding_window_right) {
    if (sliding_window_left != -1) {
      kv_left_pos = std::max(kv_left_pos, q_left_pos - sliding_window_left);
    }
    if (sliding_window_right != -1) {
      kv_right_pos = std::min(kv_right_pos, q_right_pos + sliding_window_right);
    }
    return {kv_left_pos, kv_right_pos};
  }
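  // Illustrative example: with a left sliding window of 1024 and no right
  // window (-1), a query tile covering positions [4096, 4128) only needs KV
  // positions starting from 3072, so kv_left_pos is clamped up to 3072 while
  // kv_right_pos is left unchanged.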

  FORCE_INLINE static std::pair<int32_t, int32_t> align_kv_tile_pos(
      int32_t kv_left_pos, int32_t kv_right_pos, int32_t align_factor) {
    kv_left_pos = (kv_left_pos / align_factor) * align_factor;
    kv_right_pos =
        ((kv_right_pos + align_factor - 1) / align_factor) * align_factor;
    return {kv_left_pos, kv_right_pos};
  }
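  // Illustrative example: with align_factor = 32, the KV range [33, 70) is
  // widened to [32, 96) so that every tile starts and ends on an aligned
  // boundary.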

  static int64_t calcu_default_tile_size(int64_t cache_size, int64_t head_dim,
                                         int64_t elem_size,
                                         int64_t q_buffer_elem_size,
                                         int64_t logits_buffer_elem_size,
                                         int64_t output_buffer_elem_size,
                                         int64_t max_num_q_per_iter,
                                         int64_t round_size) {
    // On CPU, unlike CUDA, the Q@K^T results should also be held in cache,
    // using float32. Intermediate outputs should be float32 to be compatible
    // with AMX. The cache then includes:
    // - Q: q_tile_size * head_dim * q_buffer_elem_size
    // - K, V: 2 * k_tile_size * head_dim * elem_size
    // - Q@K^T: max_num_q_per_iter * k_tile_size * logits_buffer_elem_size
    // - Intermediate outputs: q_tile_size * head_dim * output_buffer_elem_size
    // By default, let tile_size = q_tile_size = k_tile_size. To record
    // is_first_iter states in a static array, the default tile must be <=
    // 128 * max_num_q_per_iter

    int64_t tile_size =
        cache_size / (head_dim * (q_buffer_elem_size + 2 * elem_size +
                                  output_buffer_elem_size) +
                      max_num_q_per_iter * logits_buffer_elem_size);
    tile_size = std::min(tile_size, MaxQTileIterNum * max_num_q_per_iter);
    int64_t rounded_tile_size = (tile_size / round_size) * round_size;
    return std::max(rounded_tile_size, round_size);
  }
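  // Worked example (illustrative numbers only): with a 1 MiB cache budget,
  // head_dim = 128, 2-byte KV/Q elements, 4-byte logits/output elements and
  // max_num_q_per_iter = 32, the denominator is
  // 128 * (2 + 2 * 2 + 4) + 32 * 4 = 1408, so tile_size ~= 744 before being
  // capped at MaxQTileIterNum * max_num_q_per_iter and rounded down to a
  // multiple of round_size.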

  static int64_t calcu_tile_size_with_constant_q(
      int64_t cache_size, int64_t head_dim, int64_t elem_size,
      int64_t q_buffer_elem_size, int64_t logits_buffer_elem_size,
      int64_t output_buffer_elem_size, int64_t max_num_q_per_iter,
      int64_t round_size, int64_t q_tile_size, bool one_round) {
    // calculate tile_size with known q_tile_size
    // If one_round is true, the outer Q tile loop runs only once, so K and V
    // are not included in the cache budget
    int64_t tile_size;
    if (one_round) {
      tile_size =
          (cache_size - q_tile_size * head_dim *
                            (q_buffer_elem_size + output_buffer_elem_size)) /
          (logits_buffer_elem_size * max_num_q_per_iter);
    } else {
      tile_size =
          (cache_size - q_tile_size * head_dim *
                            (q_buffer_elem_size + output_buffer_elem_size)) /
          (logits_buffer_elem_size * max_num_q_per_iter +
           2 * head_dim * elem_size);
    }
    int64_t rounded_tile_size = (tile_size / round_size) * round_size;
    return std::max(rounded_tile_size, round_size);
  }

  static int64_t get_available_l2_size() {
    static int64_t size = []() {
#if defined(__APPLE__)
      // macOS doesn't have _SC_LEVEL2_CACHE_SIZE. Use sysctlbyname.
      int64_t l2_cache_size = 0;
      size_t len = sizeof(l2_cache_size);
      if (sysctlbyname("hw.l2cachesize", &l2_cache_size, &len, NULL, 0) == 0 &&
          l2_cache_size > 0) {
        return l2_cache_size >> 1;  // use 50% of L2 cache
      }
      // Fallback if sysctlbyname fails
      return 128LL * 1024 >> 1;  // use 50% of 128KB
#else
      long l2_cache_size = sysconf(_SC_LEVEL2_CACHE_SIZE);
      TORCH_CHECK_NE(l2_cache_size, -1);
      return l2_cache_size >> 1;  // use 50% of L2 cache
#endif
    }();
    return size;
  }

 private:
  int64_t available_cache_size_;
};

struct AttentionInput {
  AttentionMetadata* metadata;
  int32_t num_tokens;
  int32_t num_heads;
  int32_t num_kv_heads;
  int32_t block_size;
  void* query;
  int64_t query_num_tokens_stride;
  int64_t query_num_heads_stride;
  int64_t cache_num_blocks_stride;
  int64_t cache_num_kv_heads_stride;
  int64_t blt_num_tokens_stride;
  void* key_cache;
  void* value_cache;
  void* output;
  int32_t* query_start_loc;
  int32_t* seq_lens;
  int32_t* block_table;
  float* alibi_slopes;
  c10::BFloat16* s_aux;
  float scale;
  bool causal;
  int32_t sliding_window_left;
  int32_t sliding_window_right;
  float softcap;
};

#define DEFINE_CPU_ATTENTION_PARAMS                                          \
  q_buffer_t *__restrict__ q_heads_buffer,                                   \
      kv_cache_t *__restrict__ k_head_cache_ptr,                             \
      kv_cache_t *__restrict__ v_head_cache_ptr,                             \
      logits_buffer_t *__restrict__ logits_buffer,                           \
      float *__restrict__ partial_q_buffer, float *__restrict__ max_buffer,  \
      float *__restrict__ sum_buffer, int32_t *__restrict__ block_table,     \
      const int32_t kv_tile_start_pos, const int32_t kv_tile_end_pos,        \
      const int32_t kv_tile_token_num,                                       \
      const int64_t kv_cache_num_blocks_stride, const int32_t q_head_num,    \
      const int32_t q_token_num, const int32_t q_tile_start_pos,             \
      const int32_t q_heads_per_kv, const int32_t block_size,                \
      const int32_t left_window_size, const int32_t right_window_size,       \
      float scale, const float softcap_scale,                                \
      const float *__restrict__ alibi_slopes, const bool is_first_iter,      \
      const bool use_sink, const bool debug_info

#define CPU_ATTENTION_PARAMS                                                 \
  q_heads_buffer, k_head_cache_ptr, v_head_cache_ptr, logits_buffer,         \
      partial_q_buffer, max_buffer, sum_buffer, block_table,                 \
      kv_tile_start_pos, kv_tile_end_pos, kv_tile_token_num,                 \
      kv_cache_num_blocks_stride, q_head_num, q_token_num, q_tile_start_pos, \
      q_heads_per_kv, block_size, left_window_size, right_window_size, scale, \
      softcap_scale, alibi_slopes, is_first_iter, use_sink, debug_info

enum class AttentionGemmPhase { QK, PV };

template <typename T>
struct VecTypeTrait {
  using vec_t = void;
};

template <>
struct VecTypeTrait<float> {
  using vec_t = vec_op::FP32Vec16;
};

// ARM only supports BF16 with the ARMv8.6-A extension
#if (defined(__aarch64__) && !defined(ARM_BF16_SUPPORT))
#else
template <>
struct VecTypeTrait<c10::BFloat16> {
  using vec_t = vec_op::BF16Vec16;
};
#endif

#if !defined(__powerpc__) && !defined(__s390x__)
template <>
struct VecTypeTrait<c10::Half> {
  using vec_t = vec_op::FP16Vec16;
};
#endif

template <typename T>
void print_logits(const char* name, T* ptr, int32_t row, int32_t col,
                  int32_t stride) {
  std::stringstream ss;
  ss << std::fixed << std::setprecision(5) << name << ": [\n";
  auto* curr_logits_buffer = ptr;
  for (int32_t m = 0; m < row; ++m) {
    for (int32_t n = 0; n < col; ++n) {
      ss << curr_logits_buffer[n] << ", ";
    }
    ss << "\n";
    curr_logits_buffer += stride;
  }
  ss << "]\n";
  std::printf("%s", ss.str().c_str());
}

template <typename attention_impl_t>
class AttentionMainLoop {
 public:
  using query_t = typename attention_impl_t::query_t;
  using q_buffer_t = typename attention_impl_t::q_buffer_t;
  using kv_cache_t = typename attention_impl_t::kv_cache_t;
  using logits_buffer_t = typename attention_impl_t::logits_buffer_t;
  using partial_output_buffer_t =
      typename attention_impl_t::partial_output_buffer_t;
  using prob_buffer_t = typename attention_impl_t::prob_buffer_t;

  static constexpr int64_t max_q_head_num_per_iter =
      attention_impl_t::MaxQHeadNumPerIteration;
  static constexpr int64_t blocksize_alignment =
      attention_impl_t::BlockSizeAlignment;
  static constexpr int64_t headdim_alignment =
      attention_impl_t::HeadDimAlignment;
  static constexpr int64_t head_dim = attention_impl_t::HeadDim;
  static constexpr ISA ISAType = attention_impl_t::ISAType;
  static constexpr bool scale_on_logits =
      attention_impl_t::scale_on_logits;  // apply scale on logits, otherwise
                                          // apply scale on q_buffer

  template <typename tile_gemm_t>
  class Attention {
   public:
    // Args:
    // - q_heads_buffer: [MaxQHeadNumPerIteration, head_dim]
    // - k_head_cache_ptr: [num_blocks, block_size * head_dim]
    // - v_head_cache_ptr: [num_blocks, block_size * head_dim]
    // - logits_buffer: [MaxQHeadNumPerIteration, kv_tile_token_num], store
    //   Q@K logits
    // - partial_q_buffer: [MaxQHeadNumPerIteration, head_dim], store
    //   partial output
    // - max_buffer: [MaxQHeadNumPerIteration, 1], store max logits
    // - sum_buffer: [MaxQHeadNumPerIteration, 1], store sum of exp
    // - block_table
    // - kv_tile_start_pos: start position of KV cache, aligned to
    //   BlockSizeAlignment
    // - kv_tile_end_pos: end position of KV cache, aligned to
    //   BlockSizeAlignment
    // - kv_tile_token_num: KV token num, aligned to BlockSizeAlignment
    // - kv_cache_num_blocks_stride
    // - q_head_num: head num of q_tile
    // - q_token_num: token num of q_tile, should be q_head_num /
    //   q_heads_per_kv
    // - q_tile_start_pos: start pos of the first token in q_heads_buffer
    // - q_heads_per_kv
    // - block_size
    // - left_window_size
    // - right_window_size
    // - scale
    // - softcap_scale
    // - alibi_slopes
    // - is_first_iter
    // - use_sink
    // - debug_info
    void operator()(DEFINE_CPU_ATTENTION_PARAMS) {
      // k_cache_token_group_stride: stride of K cache when moving to the next
      // BlockSizeAlignment tokens in a block
      const int64_t k_cache_token_group_stride =
          attention_impl_t::k_cache_token_group_stride(block_size);
      // v_cache_token_group_stride: stride of V cache when moving to the next
      // BlockSizeAlignment tokens in a block
      const int64_t v_cache_token_group_stride =
          attention_impl_t::v_cache_token_group_stride(block_size);
      // v_cache_head_group_stride: stride of V cache when moving to the next
      // HeadDimAlignment head dims in a block
      const int64_t v_cache_head_group_stride =
          attention_impl_t::v_cache_head_group_stride(block_size);
      const int32_t token_group_num = kv_tile_token_num / blocksize_alignment;
      const int32_t token_group_num_per_block =
          block_size / blocksize_alignment;
      const int32_t start_block_idx = kv_tile_start_pos / block_size;
      const int32_t start_block_offset = kv_tile_start_pos % block_size;
      const int32_t start_block_group_offset =
          start_block_offset / blocksize_alignment;
      const int32_t end_block_idx =
          (kv_tile_start_pos + kv_tile_token_num - 1) / block_size + 1;

      // compute Q@K logits
      {
        int32_t curr_group_offset =
            start_block_group_offset * k_cache_token_group_stride;
        int32_t curr_group_num_in_block =
            token_group_num_per_block - start_block_group_offset;
        int32_t remaining_group_num = token_group_num;
        logits_buffer_t* curr_logits_buffer = logits_buffer;
        for (int32_t block_idx = start_block_idx; block_idx < end_block_idx;
             ++block_idx) {
          int32_t physical_block_idx = block_table[block_idx];
          kv_cache_t* k_cache_block_ptr =
              k_head_cache_ptr +
              physical_block_idx * kv_cache_num_blocks_stride +
              curr_group_offset;
          curr_group_num_in_block =
              std::min(remaining_group_num, curr_group_num_in_block);

          for (int32_t block_group_idx = 0;
               block_group_idx < curr_group_num_in_block; ++block_group_idx) {
            // logits_tile = q_tile @ k_tile, [MaxQHeadNumPerIteration,
            // BlockSizeAlignment] = [MaxQHeadNumPerIteration, head_dim] @
            // [head_dim, BlockSizeAlignment]

            // By default, logits_buffer, q_buffer and k_cache are row-major,
            // but may be packed by the ISA implementation.
            tile_gemm_t::template gemm<AttentionGemmPhase::QK, head_dim>(
                q_head_num, q_heads_buffer, k_cache_block_ptr,
                curr_logits_buffer, head_dim, block_size, kv_tile_token_num,
                block_size, head_dim, false);

            if constexpr (scale_on_logits) {
              float* __restrict__ scale_curr_logits_buffer = curr_logits_buffer;
              vec_op::FP32Vec16 scale_vec(scale);
              for (int32_t i = 0; i < q_head_num; ++i) {
                static_assert(blocksize_alignment % 16 == 0);
                constexpr int32_t vec_num = blocksize_alignment / 16;
                vec_op::unroll_loop<int32_t, vec_num>([&](int32_t vec_idx) {
                  vec_op::FP32Vec16 vec(scale_curr_logits_buffer +
                                        vec_idx * 16);
                  vec = vec * scale_vec;
                  vec.save(scale_curr_logits_buffer + vec_idx * 16);
                });
                scale_curr_logits_buffer += kv_tile_token_num;
              }
            }

            // Move buffer ptrs
            k_cache_block_ptr += k_cache_token_group_stride;
            curr_logits_buffer += blocksize_alignment;
          }

          // Update
          remaining_group_num -= curr_group_num_in_block;
          curr_group_offset = 0;
          curr_group_num_in_block = token_group_num_per_block;
        }
      }

      // process logits
      {
        // if (debug_info){
        //   print_logits("raw logits", logits_buffer, q_head_num,
        //   kv_tile_token_num, kv_tile_token_num);
        // }

        if (softcap_scale != 0.0f) {
          apply_softcap(logits_buffer, kv_tile_token_num, q_head_num,
                        kv_tile_token_num, softcap_scale);
          // print_logits("softcap raw logits", logits_buffer, q_head_num,
          // kv_tile_token_num, kv_tile_token_num);
        }

        if (alibi_slopes != nullptr) {
          apply_alibi_slopes(logits_buffer, alibi_slopes, kv_tile_token_num,
                             q_tile_start_pos, kv_tile_start_pos, q_token_num,
                             kv_tile_token_num, q_heads_per_kv);

          // print_logits("alibi raw logits", logits_buffer, q_head_num,
          // kv_tile_token_num, kv_tile_token_num);
        }

        apply_mask(logits_buffer, kv_tile_token_num, q_tile_start_pos,
                   kv_tile_start_pos, kv_tile_end_pos, q_token_num,
                   q_heads_per_kv, left_window_size, right_window_size);

        // if (debug_info){
        //   print_logits("masked logits", logits_buffer, q_head_num,
        //   kv_tile_token_num, kv_tile_token_num);
        //   print_logits("old_max", max_buffer, 1, q_head_num, q_head_num);
        //   print_logits("old_sum", sum_buffer, 1, q_head_num, q_head_num);
        // }

        apply_softmax(logits_buffer, partial_q_buffer, max_buffer, sum_buffer,
                      kv_tile_token_num, q_head_num, kv_tile_token_num,
                      is_first_iter, use_sink);

        // if (debug_info){
        //   print_logits("softmax logits",
        //   reinterpret_cast<prob_buffer_t*>(logits_buffer), q_head_num,
        //   kv_tile_token_num, kv_tile_token_num * sizeof(logits_buffer_t) /
        //   sizeof(prob_buffer_t));
        //   print_logits("new_max", max_buffer, 1, q_head_num, q_head_num);
        //   print_logits("new_sum", sum_buffer, 1, q_head_num, q_head_num);
        // }
      }

      // compute P@V
      {
        int32_t curr_group_offset =
            start_block_group_offset * v_cache_token_group_stride;
        int32_t curr_group_num_in_block =
            token_group_num_per_block - start_block_group_offset;
        int32_t remaining_group_num = token_group_num;
        int32_t head_dim_group_num = head_dim / headdim_alignment;
        prob_buffer_t* curr_prob_buffer =
            reinterpret_cast<prob_buffer_t*>(logits_buffer);
        int64_t prob_buffer_stride =
            kv_tile_token_num *
            (sizeof(logits_buffer_t) / sizeof(prob_buffer_t));
        partial_output_buffer_t* curr_partial_q_buffer = partial_q_buffer;
        bool accum_c = !is_first_iter;
        for (int32_t block_idx = start_block_idx; block_idx < end_block_idx;
             ++block_idx) {
          int32_t physical_block_idx = block_table[block_idx];
          kv_cache_t* v_cache_block_ptr =
              v_head_cache_ptr +
              physical_block_idx * kv_cache_num_blocks_stride +
              curr_group_offset;
          curr_group_num_in_block =
              std::min(remaining_group_num, curr_group_num_in_block);
          int32_t curr_token_num =
              curr_group_num_in_block * blocksize_alignment;

          for (int32_t head_dim_group_idx = 0;
               head_dim_group_idx < head_dim_group_num; ++head_dim_group_idx) {
            // output_tile = p_tile @ v_tile, [MaxQHeadNumPerIteration,
            // HeadDimAlignment] = [MaxQHeadNumPerIteration, block_size] @
            // [block_size, HeadDimAlignment]
            tile_gemm_t::template gemm<AttentionGemmPhase::PV, -1>(
                q_head_num, curr_prob_buffer, v_cache_block_ptr,
                curr_partial_q_buffer, prob_buffer_stride, head_dim, head_dim,
                block_size, curr_token_num, accum_c);

            // Update
            curr_partial_q_buffer += headdim_alignment;
            v_cache_block_ptr += v_cache_head_group_stride;
          }

          // Update
          remaining_group_num -= curr_group_num_in_block;
          curr_group_offset = 0;
          curr_group_num_in_block = token_group_num_per_block;
          curr_prob_buffer += curr_token_num;
          curr_partial_q_buffer = partial_q_buffer;
          accum_c = true;
        }
      }
      // if (debug_info) {
      //   print_logits("output", partial_q_buffer, q_head_num, head_dim,
      //   head_dim);
      // }
    }

    void apply_mask(logits_buffer_t* __restrict__ logits_buffer,
                    const int64_t logits_buffer_stride,
                    const int32_t q_tile_start_pos,
                    const int32_t kv_tile_start_pos,
                    const int32_t kv_tile_end_pos, const int32_t q_token_num,
                    const int32_t q_heads_per_kv,
                    const int32_t sliding_window_left,
                    const int32_t sliding_window_right) {
      // Apply mask
      constexpr logits_buffer_t neg_inf =
          -std::numeric_limits<logits_buffer_t>::infinity();
      logits_buffer_t* __restrict__ curr_logits_buffer = logits_buffer;
      int32_t curr_token_pos = q_tile_start_pos;
      for (int32_t token_idx = 0; token_idx < q_token_num; ++token_idx) {
        int32_t left_kv_pos = [&]() {
          int32_t pos = kv_tile_start_pos;
          if (sliding_window_left != -1) {
            pos = std::max(pos, curr_token_pos - sliding_window_left);
          }
          return pos;
        }();

        int32_t right_kv_pos = [&]() {
          int32_t pos = kv_tile_end_pos;
          if (sliding_window_right != -1) {
            pos = std::min(pos,
                           std::max(kv_tile_start_pos,
                                    curr_token_pos + sliding_window_right + 1));
          }
          return pos;
        }();

        int32_t left_invalid_token_num = left_kv_pos - kv_tile_start_pos;
        int32_t right_invalid_token_num = kv_tile_end_pos - right_kv_pos;
        for (int32_t head_idx = 0; head_idx < q_heads_per_kv; ++head_idx) {
          logits_buffer_t* __restrict__ curr_logits_buffer_tail =
              curr_logits_buffer + right_kv_pos - kv_tile_start_pos;
          for (int32_t i = 0; i < left_invalid_token_num; ++i) {
            curr_logits_buffer[i] = neg_inf;
          }
          for (int32_t i = 0; i < right_invalid_token_num; ++i) {
            curr_logits_buffer_tail[i] = neg_inf;
          }

          curr_logits_buffer += logits_buffer_stride;
        }

        ++curr_token_pos;
      }
    }

    void apply_softmax(logits_buffer_t* __restrict__ logits_buffer,
                       float* __restrict__ partial_q_buffer,
                       float* __restrict__ max_buffer,
                       float* __restrict__ sum_buffer,
                       const int64_t logits_buffer_stride, int32_t q_head_num,
                       int32_t kv_tile_token_num, bool is_first_iter,
                       bool use_sink) {
#ifdef DEFINE_FAST_EXP
      DEFINE_FAST_EXP
#endif
      using prob_buffer_vec_t = typename VecTypeTrait<prob_buffer_t>::vec_t;
      static_assert(sizeof(prob_buffer_t) <= sizeof(logits_buffer_t));

      logits_buffer_t* __restrict__ curr_logits_buffer = logits_buffer;
      float* __restrict__ curr_partial_q_buffer = partial_q_buffer;
      const int32_t vec_num = kv_tile_token_num / 16;
      const int32_t head_vec_num = head_dim / 16;
      for (int32_t i = 0; i < q_head_num; ++i) {
        float init_max_val = max_buffer[i];
        float init_sum_val = sum_buffer[i];

        // apply scale and compute max
        vec_op::FP32Vec16 max_vec(init_max_val);
        {
          logits_buffer_t* __restrict__ curr_logits_buffer_iter =
              curr_logits_buffer;
          for (int32_t j = 0; j < vec_num; ++j) {
            vec_op::FP32Vec16 vec(curr_logits_buffer_iter);
            max_vec = vec.max(max_vec);

            curr_logits_buffer_iter += 16;
          }
        }
        float new_max_val = max_vec.reduce_max();
        float rescale_factor = init_max_val - new_max_val;

        // use the same rescale threshold as FA4.
        // https://github.com/Dao-AILab/flash-attention/blob/1b8e1e641c6a179be9a0538b7f40fd595050b735/flash_attn/cute/flash_fwd_sm100.py#L1271
        bool need_rescale = rescale_factor < -8.0;
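        // Rationale (informal): rescale_factor = old_max - new_max <= 0 here.
        // When the max grows by less than 8, the old max is kept, so the
        // exp(x - old_max) terms below are bounded by e^8 (~3e3), which is
        // safely representable in fp32; this skips rescaling the running sum
        // and the partial outputs on most iterations.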
        if (!need_rescale) {
          new_max_val = init_max_val;
        } else {
          max_buffer[i] = new_max_val;
        }

        // sub max, compute exp and sum
        max_vec = vec_op::FP32Vec16(new_max_val);
        vec_op::FP32Vec16 sum_vec(0.0);
        {
          logits_buffer_t* __restrict__ curr_logits_buffer_iter =
              curr_logits_buffer;
          prob_buffer_t* __restrict__ curr_prob_buffer_iter =
              reinterpret_cast<prob_buffer_t*>(curr_logits_buffer);
          for (int32_t j = 0; j < vec_num; ++j) {
            vec_op::FP32Vec16 vec(curr_logits_buffer_iter);
            vec = vec - max_vec;

            // compute exp
#ifdef DEFINE_FAST_EXP
            vec = fast_exp(vec);
            prob_buffer_vec_t output_vec(vec);
            output_vec.save(curr_prob_buffer_iter);
#else
            vec.save(curr_logits_buffer_iter);
            for (int32_t k = 0; k < 16; ++k) {
              curr_logits_buffer_iter[k] = std::exp(curr_logits_buffer_iter[k]);
            }
            vec = vec_op::FP32Vec16(curr_logits_buffer_iter);
#endif

            sum_vec = sum_vec + vec;

            curr_logits_buffer_iter += 16;
            curr_prob_buffer_iter += 16;
          }
        }
        float new_sum_val = sum_vec.reduce_sum();

        // rescale sum and partial outputs
        if (need_rescale) {
          // compute rescale factor
          rescale_factor = std::exp(rescale_factor);
          vec_op::FP32Vec16 rescale_factor_vec(rescale_factor);

          // rescale sum
          new_sum_val += rescale_factor * init_sum_val;

          // rescale output
          if (!is_first_iter) {
            float* __restrict__ curr_partial_q_buffer_iter =
                curr_partial_q_buffer;
            for (int32_t j = 0; j < head_vec_num; ++j) {
              vec_op::FP32Vec16 vec(curr_partial_q_buffer_iter);
              vec = vec * rescale_factor_vec;
              vec.save(curr_partial_q_buffer_iter);

              curr_partial_q_buffer_iter += 16;
            }
          }
        } else {
          new_sum_val += init_sum_val;
        }

        sum_buffer[i] = new_sum_val;

        curr_logits_buffer += logits_buffer_stride;
        curr_partial_q_buffer += head_dim;
      }
    }

    void apply_softcap(logits_buffer_t* __restrict__ logits_buffer,
                       const int64_t logits_buffer_stride, int32_t q_head_num,
                       int32_t kv_tile_token_num, float softcap_scale) {
#ifdef DEFINE_FAST_EXP
      DEFINE_FAST_EXP
#endif
      float inv_softcap_scale = 1.0 / softcap_scale;
      vec_op::FP32Vec16 softcap_scale_vec(softcap_scale);
      vec_op::FP32Vec16 inv_softcap_scale_vec(inv_softcap_scale);
      vec_op::FP32Vec16 ones_vec(1.0);
      logits_buffer_t* __restrict__ curr_logits_buffer = logits_buffer;
      const int32_t vec_num = kv_tile_token_num / 16;
      for (int32_t i = 0; i < q_head_num; ++i) {
        logits_buffer_t* __restrict__ curr_logits_buffer_iter =
            curr_logits_buffer;
        for (int32_t j = 0; j < vec_num; ++j) {
          vec_op::FP32Vec16 vec(curr_logits_buffer_iter);
          vec = vec * inv_softcap_scale_vec;

#ifdef DEFINE_FAST_EXP
          vec = fast_exp(vec);
          vec_op::FP32Vec16 inv_vec = ones_vec / vec;
          vec = (vec - inv_vec) / (vec + inv_vec);
#else
          vec.save(curr_logits_buffer_iter);
          for (int k = 0; k < 16; ++k) {
            curr_logits_buffer_iter[k] = std::tanh(curr_logits_buffer_iter[k]);
          }
          vec = vec_op::FP32Vec16(curr_logits_buffer_iter);
#endif
          vec = vec * softcap_scale_vec;
          vec.save(curr_logits_buffer_iter);

          curr_logits_buffer_iter += 16;
        }

        curr_logits_buffer += logits_buffer_stride;
      }
    }

    void apply_alibi_slopes(logits_buffer_t* __restrict__ logits_buffer,
                            const float* __restrict__ alibi_slopes,
                            const int64_t logits_buffer_stride,
                            const int32_t q_tile_start_pos,
                            const int32_t kv_tile_start_pos,
                            const int32_t q_token_num,
                            const int32_t kv_tile_token_num,
                            const int32_t q_heads_per_kv) {
      alignas(64) constexpr float initial_arange_vals[16] = {
          0.0f, 1.0f, 2.0f,  3.0f,  4.0f,  5.0f,  6.0f,  7.0f,
          8.0f, 9.0f, 10.0f, 11.0f, 12.0f, 13.0f, 14.0f, 15.0f};
      const int32_t vec_num = kv_tile_token_num / 16;

      vec_op::FP32Vec16 initial_arange_vals_vec(initial_arange_vals);
      initial_arange_vals_vec =
          initial_arange_vals_vec + vec_op::FP32Vec16((float)kv_tile_start_pos);
      vec_op::FP32Vec16 pos_offset_vec(16.0);
      logits_buffer_t* __restrict__ curr_logits_buffer = logits_buffer;
      for (int32_t i = 0; i < q_token_num; ++i) {
        vec_op::FP32Vec16 curr_q_pos_vec((float)(i + q_tile_start_pos));
        for (int32_t j = 0; j < q_heads_per_kv; ++j) {
          vec_op::FP32Vec16 alibi_scale_vec(alibi_slopes[j]);
          vec_op::FP32Vec16 curr_kv_pos_vec(initial_arange_vals_vec);
          logits_buffer_t* __restrict__ curr_logits_buffer_iter =
              curr_logits_buffer;
          for (int32_t k = 0; k < vec_num; ++k) {
            vec_op::FP32Vec16 alibi_bias_vec =
                alibi_scale_vec * (curr_kv_pos_vec - curr_q_pos_vec);
            vec_op::FP32Vec16 vec(curr_logits_buffer_iter);
            vec = vec + alibi_bias_vec;

            vec.save(curr_logits_buffer_iter);

            curr_kv_pos_vec = curr_kv_pos_vec + pos_offset_vec;
            curr_logits_buffer_iter += 16;
          }
          curr_logits_buffer += logits_buffer_stride;
        }
      }
    }
  };

 public:
  void operator()(const AttentionInput* input) {
    const int thread_num = omp_get_max_threads();
    TORCH_CHECK_EQ(input->metadata->thread_num, thread_num);
    std::atomic<int32_t> guard_counter(0);
    std::atomic<int32_t>* guard_counter_ptr = &guard_counter;

#pragma omp parallel for schedule(static, 1)
    for (int thread_id = 0; thread_id < thread_num; ++thread_id) {
      AttentionMetadata& metadata = *input->metadata;
      if (metadata.workitem_group_num == 0) {
        continue;
      }

      attention_impl_t attn_impl;

      // general information
      const int32_t q_head_num = input->num_heads;
      const int32_t kv_head_num = input->num_kv_heads;
      const int32_t q_heads_per_kv = q_head_num / kv_head_num;
      const bool use_gqa =
          (max_q_head_num_per_iter % q_heads_per_kv == 0) ? true : false;
      const int32_t actual_kv_head_num = use_gqa ? kv_head_num : q_head_num;
      const int32_t actual_q_heads_per_kv = use_gqa ? q_heads_per_kv : 1;
      TORCH_CHECK_LE(actual_q_heads_per_kv, max_q_head_num_per_iter);
      const int32_t max_q_token_num_per_iter =
          max_q_head_num_per_iter / actual_q_heads_per_kv;
      const int64_t q_token_num_stride = input->query_num_tokens_stride;
      const int64_t q_head_num_stride = input->query_num_heads_stride;
      const int64_t kv_cache_head_num_stride = input->cache_num_kv_heads_stride;
      const int64_t kv_cache_block_num_stride = input->cache_num_blocks_stride;
      const int32_t sliding_window_left = input->sliding_window_left;
      const int32_t sliding_window_right = input->sliding_window_right;
      const int32_t block_size = input->block_size;
      const float scale = input->scale;
      const float softcap_scale = input->softcap;
      const float* alibi_slopes = input->alibi_slopes;
      const c10::BFloat16* s_aux = input->s_aux;

      const bool casual = input->causal;
      int32_t* const block_table = input->block_table;
      const int64_t block_table_stride = input->blt_num_tokens_stride;

      // init buffers
      void* scratchpad_ptr =
          DNNLScratchPadManager::get_dnnl_scratchpad_manager()
              ->get_data<void>();
      AttentionScratchPad buffer_manager(thread_id, metadata, scratchpad_ptr);

      const int32_t total_reduction_split_num = metadata.reduction_split_num;
      if (metadata.reduction_split_num > 0) {
        // reset split flag
        for (int32_t head_idx = thread_id; head_idx < actual_kv_head_num;
             head_idx += thread_num) {
          buffer_manager.update(head_idx, total_reduction_split_num, head_dim,
                                0, sizeof(partial_output_buffer_t));
          volatile bool* __restrict__ curr_flag_ptr =
              buffer_manager.get_reduce_flag_buffer();
          for (int32_t split_idx = 0; split_idx < total_reduction_split_num;
               ++split_idx) {
            curr_flag_ptr[split_idx] = false;
          }
        }
      }

      const int64_t available_cache_size =
          AttentionScheduler::get_available_l2_size();
      const int32_t default_tile_size =
          AttentionScheduler::calcu_default_tile_size(
              available_cache_size, head_dim, sizeof(kv_cache_t),
              sizeof(q_buffer_t), sizeof(logits_buffer_t),
              sizeof(partial_output_buffer_t), max_q_head_num_per_iter,
              max_q_head_num_per_iter);
      const int32_t default_q_tile_token_num =
          default_tile_size / actual_q_heads_per_kv;

      AttentionWorkItemGroup* const workitem_groups =
          metadata.workitem_groups_ptr;
      const int32_t* cu_workitem_num_per_thread =
          metadata.cu_workitem_num_per_thread;
      ReductionWorkItemGroup* const reduction_items =
          metadata.reduction_items_ptr;

      const int32_t effective_thread_num = metadata.effective_thread_num;
      const int32_t reduction_item_num = metadata.reduction_item_num;
      const int32_t split_kv_q_token_num_threshold =
          metadata.split_kv_q_token_num_threshold;
      const int32_t workitem_groups_counter_num =
          actual_kv_head_num * effective_thread_num;
      const int32_t reduction_items_counter_num =
          actual_kv_head_num * reduction_item_num;
      const int32_t total_counter_num =
          workitem_groups_counter_num + reduction_items_counter_num;

      if (metadata.reduction_split_num > 0) {
        ++(*guard_counter_ptr);
        while (guard_counter_ptr->load() != thread_num) {
#ifdef FAST_SPINNING
          FAST_SPINNING
#else
          std::this_thread::yield();
#endif
        }
      }
|
|
|
|
      // main loop
      for (;;) {
        int64_t task_idx = metadata.acquire_counter();

        if (task_idx >= total_counter_num) {
          // no more tasks, leave loop
          break;
        }

        if (task_idx < workitem_groups_counter_num) {
          // attention task
          // map task_idx to workitem_groups
          const int32_t kv_head_idx = task_idx / effective_thread_num;
          const int32_t thread_offset = task_idx % effective_thread_num;
          AttentionWorkItemGroup* const curr_workitem_groups =
              workitem_groups + cu_workitem_num_per_thread[thread_offset];
          const int32_t curr_workitem_groups_num =
              cu_workitem_num_per_thread[thread_offset + 1] -
              cu_workitem_num_per_thread[thread_offset];

          const int32_t q_head_start_idx = kv_head_idx * actual_q_heads_per_kv;

          for (int32_t workitem_group_idx = 0;
               workitem_group_idx < curr_workitem_groups_num;
               ++workitem_group_idx) {
            AttentionWorkItemGroup* const current_workitem_group =
                &curr_workitem_groups[workitem_group_idx];

            const int32_t current_group_idx = current_workitem_group->req_id;
            const int32_t kv_start_pos =
                current_workitem_group->kv_split_pos_start;
            const int32_t kv_end_pos = current_workitem_group->kv_split_pos_end;
            const int32_t curr_split_id = current_workitem_group->split_id;
            const int32_t q_token_id_start =
                current_workitem_group->q_token_id_start;
            const int32_t q_token_num = current_workitem_group->q_token_num;

            // task group general information
            const int32_t q_end = input->query_start_loc[current_group_idx + 1];
            const int32_t q_start = input->query_start_loc[current_group_idx];
            const int32_t seq_len = input->seq_lens[current_group_idx];
            const int32_t q_start_pos =
                (casual ? seq_len - (q_end - q_start) : 0);
            const int32_t block_num = (seq_len + block_size - 1) / block_size;
            // Only apply sink for the first KV split
            bool use_sink = (s_aux != nullptr &&
                             current_workitem_group->local_split_id == 0);

            for (int32_t q_token_offset = 0; q_token_offset < q_token_num;
                 q_token_offset += default_q_tile_token_num) {
              bool first_iter_flag[AttentionScheduler::MaxQTileIterNum];
              for (int32_t i = 0; i < AttentionScheduler::MaxQTileIterNum;
                   ++i) {
                first_iter_flag[i] = true;
              }

              const int32_t q_token_start_idx =
                  q_start + q_token_offset + q_token_id_start;
              const int32_t actual_q_token_num = std::min(
                  default_q_tile_token_num, q_token_num - q_token_offset);
              const int32_t q_head_tile_size =
                  actual_q_token_num * actual_q_heads_per_kv;
              const int32_t rounded_q_head_tile_size =
                  ((q_head_tile_size + max_q_head_num_per_iter - 1) /
                   max_q_head_num_per_iter) *
                  max_q_head_num_per_iter;
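              // The KV tile length is chosen so that the Q tile, the logits
              // tile, and the fp32 partial outputs for this iteration fit
              // into the per-thread cache budget (available_cache_size).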
              const int32_t kv_tile_size =
                  AttentionScheduler::calcu_tile_size_with_constant_q(
                      available_cache_size, head_dim, sizeof(kv_cache_t),
                      sizeof(q_buffer_t), sizeof(logits_buffer_t),
                      sizeof(partial_output_buffer_t), max_q_head_num_per_iter,
                      blocksize_alignment, rounded_q_head_tile_size,
                      rounded_q_head_tile_size <= max_q_head_num_per_iter);

              // update buffers
              buffer_manager.update(
                  head_dim, sizeof(q_buffer_t), sizeof(logits_buffer_t),
                  sizeof(partial_output_buffer_t), max_q_head_num_per_iter,
                  rounded_q_head_tile_size, kv_tile_size);
              q_buffer_t* q_buffer = buffer_manager.get_q_buffer<q_buffer_t>();
              float* logits_buffer = buffer_manager.get_logits_buffer();
              float* partial_q_buffer = buffer_manager.get_output_buffer();
              float* max_buffer = buffer_manager.get_max_buffer();
              float* sum_buffer = buffer_manager.get_sum_buffer();

              const int32_t q_tile_start_pos =
                  q_start_pos + q_token_offset + q_token_id_start;
              const int32_t q_tile_end_pos =
                  q_tile_start_pos + actual_q_token_num;
              const auto [kv_tile_start_pos, kv_tile_end_pos] =
                  AttentionScheduler::calcu_kv_tile_pos(
                      kv_start_pos, kv_end_pos, q_tile_start_pos,
                      q_tile_end_pos, sliding_window_left,
                      sliding_window_right);
              const auto [rounded_kv_tile_start_pos, rounded_kv_tile_end_pos] =
                  AttentionScheduler::align_kv_tile_pos(
                      kv_tile_start_pos, kv_tile_end_pos, blocksize_alignment);

              // for GQA disabled case
              int32_t curr_kv_head_idx =
                  use_gqa ? kv_head_idx : (kv_head_idx / q_heads_per_kv);

              // std::printf("thread_id: %d, req_id: %d, q_token_start: %d,
              // q_token_end: %d, q_head_start: %d, q_head_end: %d, kv_head_idx:
              // %d, kv_pos_start: %d, kv_pos_end: %d\n",
              //             thread_id, current_group_idx,
              //             q_token_start_idx, q_token_start_idx +
              //             actual_q_token_num, q_head_start_idx,
              //             q_head_start_idx + actual_q_heads_per_kv,
              //             curr_kv_head_idx, kv_tile_start_pos,
              //             kv_tile_end_pos);

              // move buffers
              kv_cache_t* curr_k_cache =
                  reinterpret_cast<kv_cache_t*>(input->key_cache) +
                  curr_kv_head_idx * kv_cache_head_num_stride;
              kv_cache_t* curr_v_cache =
                  reinterpret_cast<kv_cache_t*>(input->value_cache) +
                  curr_kv_head_idx * kv_cache_head_num_stride;
              query_t* const q_tile_ptr =
                  reinterpret_cast<query_t*>(input->query) +
                  q_token_start_idx * q_token_num_stride +
                  q_head_start_idx * q_head_num_stride;
              size_t output_buffer_offset =
                  q_token_start_idx * q_head_num * head_dim +
                  q_head_start_idx * head_dim;
              int32_t* curr_block_table =
                  block_table + current_group_idx * block_table_stride;
              const float* curr_alibi_slopes =
                  (alibi_slopes != nullptr ? alibi_slopes + q_head_start_idx
                                           : nullptr);
              const c10::BFloat16* curr_s_aux =
                  (s_aux != nullptr ? s_aux + q_head_start_idx : nullptr);

              // copy the Q tile to q_buffer, the logical layout of q_buffer is
              // [actual_q_token_num, actual_q_heads_per_kv, head_dim]
              {
                attn_impl.copy_q_heads_tile(
                    q_tile_ptr, q_buffer, actual_q_token_num,
                    actual_q_heads_per_kv, q_token_num_stride,
                    q_head_num_stride, scale);
              }

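              // Initialize the running softmax state for every (token, head)
              // pair. With an attention sink (s_aux) the running max starts at
              // the sink logit and the running sum at 1.0, so the sink behaves
              // like one extra virtual key; otherwise use the standard
              // (-inf, 0) initialization. The fixed 16-wide conversion buffer
              // below assumes at most 16 query heads per KV group.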
              if (use_sink) {
                alignas(64) float s_aux_fp32[16];
#if defined(__aarch64__) && !defined(ARM_BF16_SUPPORT)
                // ARM without native BF16 support: manual conversion
                for (int i = 0; i < 16; ++i) {
                  s_aux_fp32[i] = static_cast<float>(curr_s_aux[i]);
                }
#else
                // All other platforms have BF16Vec16 available
                vec_op::BF16Vec16 vec_bf16(curr_s_aux);
                vec_op::FP32Vec16 vec_fp32(vec_bf16);
                vec_fp32.save(s_aux_fp32);
#endif

                float* __restrict__ curr_sum_buffer = sum_buffer;
                float* __restrict__ curr_max_buffer = max_buffer;
                for (int32_t token_idx = 0; token_idx < actual_q_token_num;
                     ++token_idx) {
                  for (int32_t head_idx = 0; head_idx < actual_q_heads_per_kv;
                       ++head_idx) {
                    curr_sum_buffer[head_idx] = 1.0f;
                    curr_max_buffer[head_idx] = s_aux_fp32[head_idx];
                  }

                  curr_sum_buffer += actual_q_heads_per_kv;
                  curr_max_buffer += actual_q_heads_per_kv;
                }
              } else {
                float* __restrict__ curr_sum_buffer = sum_buffer;
                float* __restrict__ curr_max_buffer = max_buffer;
                for (int32_t token_idx = 0; token_idx < actual_q_token_num;
                     ++token_idx) {
                  for (int32_t head_idx = 0; head_idx < actual_q_heads_per_kv;
                       ++head_idx) {
                    curr_sum_buffer[head_idx] = 0.0f;
                    curr_max_buffer[head_idx] =
                        std::numeric_limits<float>::lowest();
                  }

                  curr_sum_buffer += actual_q_heads_per_kv;
                  curr_max_buffer += actual_q_heads_per_kv;
                }
              }

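              // Two-level tiling: the outer loop walks the KV range in tiles
              // of kv_tile_size tokens, the inner loop walks the Q tile in
              // sub-tiles of at most max_q_token_num_per_iter tokens. Each
              // (Q sub-tile, KV tile) pair is clipped again for causal /
              // sliding-window masking before the kernel is invoked.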
              // compute loop
              for (int32_t kv_tile_pos = rounded_kv_tile_start_pos;
                   kv_tile_pos < rounded_kv_tile_end_pos;
                   kv_tile_pos += kv_tile_size) {
                const int32_t kv_tile_pos_left = kv_tile_pos;
                const int32_t kv_tile_pos_right = std::min(
                    kv_tile_pos_left + kv_tile_size, rounded_kv_tile_end_pos);
                for (int32_t q_head_tile_token_offset = 0;
                     q_head_tile_token_offset < actual_q_token_num;
                     q_head_tile_token_offset += max_q_token_num_per_iter) {
                  const int32_t q_tile_pos_left =
                      q_tile_start_pos + q_head_tile_token_offset;
                  const int32_t q_tile_token_num =
                      std::min(max_q_token_num_per_iter,
                               actual_q_token_num - q_head_tile_token_offset);
                  const int32_t q_tile_head_offset =
                      q_head_tile_token_offset * actual_q_heads_per_kv;
                  const int32_t q_tile_head_num =
                      q_tile_token_num * actual_q_heads_per_kv;
                  const int32_t q_tile_pos_right =
                      q_tile_pos_left + q_tile_token_num;
                  const auto [actual_kv_tile_pos_left,
                              actual_kv_tile_pos_right] =
                      AttentionScheduler::calcu_kv_tile_pos(
                          kv_tile_pos_left, kv_tile_pos_right, q_tile_pos_left,
                          q_tile_pos_right, sliding_window_left,
                          sliding_window_right);
                  const int32_t q_iter_idx =
                      q_head_tile_token_offset / max_q_token_num_per_iter;

                  if (actual_kv_tile_pos_right <= actual_kv_tile_pos_left) {
                    continue;
                  }

                  // align kv_pos to blocksize_alignment
                  const auto [aligned_actual_kv_tile_pos_left,
                              aligned_actual_kv_tile_pos_right] =
                      AttentionScheduler::align_kv_tile_pos(
                          actual_kv_tile_pos_left, actual_kv_tile_pos_right,
                          blocksize_alignment);
                  const int32_t actual_kv_token_num =
                      aligned_actual_kv_tile_pos_right -
                      aligned_actual_kv_tile_pos_left;

                  // std::printf("\tq_iter_idx: %d, q_token_start: %d,
                  // q_token_end: %d, q_token_num: %d, q_head_num: %d,
                  // q_pos_start: %d, q_pos_end: %d, kv_pos_start: %d,
                  // kv_pos_end: %d\n",
                  //             q_iter_idx, q_token_start_idx +
                  //             q_head_tile_token_offset, q_token_start_idx +
                  //             q_head_tile_token_offset + q_tile_token_num,
                  //             q_tile_token_num, q_tile_head_num,
                  //             q_tile_pos_left, q_tile_pos_right,
                  //             aligned_actual_kv_tile_pos_left,
                  //             aligned_actual_kv_tile_pos_right);

                  // Move buffers
                  q_buffer_t* curr_q_heads_buffer =
                      q_buffer + q_tile_head_offset * head_dim;
                  float* curr_partial_q_buffer =
                      partial_q_buffer + q_tile_head_offset * head_dim;
                  float* curr_max_buffer = max_buffer + q_tile_head_offset;
                  float* curr_sum_buffer = sum_buffer + q_tile_head_offset;

                  bool debug_info = false;
                  // bool debug_info =
                  //     (q_head_start_idx == 4 &&
                  //      (q_token_start_idx + q_head_tile_token_offset) <= 4 &&
                  //      (q_token_start_idx + q_head_tile_token_offset +
                  //       q_tile_token_num) > 4);
                  // if (debug_info) {
                  //   std::printf("\tq_iter_idx: %d, q_token_start: %d,"
                  //               "q_token_end: %d, q_token_num: %d, q_head_num: %d,"
                  //               "q_pos_start: %d, q_pos_end: %d, kv_pos_start: %d,"
                  //               "kv_pos_end: %d\n",
                  //               q_iter_idx,
                  //               q_token_start_idx + q_head_tile_token_offset,
                  //               q_token_start_idx + q_head_tile_token_offset +
                  //                   q_tile_token_num,
                  //               q_tile_token_num, q_tile_head_num,
                  //               q_tile_pos_left, q_tile_pos_right,
                  //               aligned_actual_kv_tile_pos_left,
                  //               aligned_actual_kv_tile_pos_right);
                  // }

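                  // One flash-attention-style tile step: compute the masked
                  // QK^T logits for this (Q sub-tile, KV tile) pair, update
                  // the running per-head max/sum, and accumulate logits * V
                  // into the fp32 partial output.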
                  attn_impl.template execute_attention<Attention>(
                      curr_q_heads_buffer, curr_k_cache, curr_v_cache,
                      logits_buffer, curr_partial_q_buffer, curr_max_buffer,
                      curr_sum_buffer, curr_block_table,
                      aligned_actual_kv_tile_pos_left,
                      aligned_actual_kv_tile_pos_right, actual_kv_token_num,
                      kv_cache_block_num_stride, q_tile_head_num,
                      q_tile_token_num, q_tile_pos_left, actual_q_heads_per_kv,
                      block_size, sliding_window_left, sliding_window_right,
                      scale, softcap_scale, curr_alibi_slopes,
                      first_iter_flag[q_iter_idx], use_sink, debug_info);
                  first_iter_flag[q_iter_idx] = false;
                }
              }

              // write back partial results to output buffer or reduction buffer
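              // split_id == -1 means this work item covers the whole KV range
              // for its queries, so the normalized result can be written to
              // the output tensor directly. Otherwise the unnormalized fp32
              // accumulator plus per-head max/sum are stored in the reduction
              // scratchpad and the split's ready flag is raised for the
              // matching reduction task.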
              {
                if (curr_split_id == -1) {
                  final_output(partial_q_buffer,
                               reinterpret_cast<query_t*>(input->output) +
                                   output_buffer_offset,
                               sum_buffer, actual_q_heads_per_kv,
                               actual_q_token_num, q_head_num);
                } else {
                  const int32_t stride =
                      actual_q_heads_per_kv * split_kv_q_token_num_threshold;
                  buffer_manager.update(kv_head_idx, total_reduction_split_num,
                                        head_dim, stride, sizeof(float));
                  volatile bool* split_flag_buffer =
                      buffer_manager.get_reduce_flag_buffer() + curr_split_id;
                  float* split_output_buffer =
                      buffer_manager.get_reduce_output_buffer() +
                      curr_split_id * stride * head_dim;
                  float* split_max_buffer =
                      buffer_manager.get_reduce_max_buffer() +
                      curr_split_id * stride;
                  float* split_sum_buffer =
                      buffer_manager.get_reduce_sum_buffer() +
                      curr_split_id * stride;

                  partial_output(partial_q_buffer, max_buffer, sum_buffer,
                                 q_head_tile_size, split_output_buffer,
                                 split_max_buffer, split_sum_buffer,
                                 split_flag_buffer);
                }
              }
            }
          }
        } else {
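          // Reduction task: merge the partial results produced by all KV
          // splits of one (request, query-token range, KV head) tuple and
          // emit the final normalized output.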
          task_idx -= workitem_groups_counter_num;
          const int32_t kv_head_idx = task_idx / reduction_item_num;
          const int32_t item_offset = task_idx % reduction_item_num;
          ReductionWorkItemGroup* const curr_workitem_groups =
              reduction_items + item_offset;
          const int32_t curr_output_token_idx =
              curr_workitem_groups->q_token_id_start;
          const int32_t curr_output_token_num =
              curr_workitem_groups->q_token_id_num;
          const int32_t curr_split_id = curr_workitem_groups->split_start_id;
          const int32_t curr_split_num = curr_workitem_groups->split_num;
          const int32_t current_group_idx = curr_workitem_groups->req_id;
          const int32_t curr_output_head_num =
              curr_output_token_num * actual_q_heads_per_kv;

          const int32_t q_start = input->query_start_loc[current_group_idx];
          const int32_t q_token_start_idx = q_start + curr_output_token_idx;
          const int32_t q_head_start_idx = kv_head_idx * actual_q_heads_per_kv;
          size_t output_buffer_offset =
              q_token_start_idx * q_head_num * head_dim +
              q_head_start_idx * head_dim;

          const int32_t stride =
              actual_q_heads_per_kv * split_kv_q_token_num_threshold;
          buffer_manager.update(kv_head_idx, total_reduction_split_num,
                                head_dim, stride, sizeof(float));
          volatile bool* split_flag_buffer =
              buffer_manager.get_reduce_flag_buffer() + curr_split_id;
          float* split_output_buffer =
              buffer_manager.get_reduce_output_buffer() +
              curr_split_id * stride * head_dim;
          float* split_max_buffer =
              buffer_manager.get_reduce_max_buffer() + curr_split_id * stride;
          float* split_sum_buffer =
              buffer_manager.get_reduce_sum_buffer() + curr_split_id * stride;

          reduce_splits(split_output_buffer, split_max_buffer, split_sum_buffer,
                        split_flag_buffer, stride, curr_output_head_num,
                        curr_split_num);
          final_output(
              split_output_buffer,
              reinterpret_cast<query_t*>(input->output) + output_buffer_offset,
              split_sum_buffer, actual_q_heads_per_kv, curr_output_token_num,
              q_head_num);
        }
      }
    }
    // Reset counter for next call
    input->metadata->reset_counter();
  }

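  // Merge `split_num` partial results that cover the same query rows. For
  // each head the running (max, sum, output) triple is combined with the
  // incoming split using the usual online-softmax rescaling:
  //   m_new = max(m_a, m_b)
  //   s_new = s_a * exp(m_a - m_new) + s_b * exp(m_b - m_new)
  //   o_new = o_a * exp(m_a - m_new) + o_b * exp(m_b - m_new)
  // The merged max/sum are written back into the first split's buffers so
  // the caller can normalize the merged output afterwards.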
  void reduce_splits(float* __restrict__ split_output_buffer,
                     float* __restrict__ split_max_buffer,
                     float* __restrict__ split_sum_buffer,
                     volatile bool* __restrict__ flags,
                     const int32_t head_num_per_split,
                     const int32_t curr_head_num, const int32_t split_num) {
#ifdef DEFINE_FAST_EXP
    DEFINE_FAST_EXP
#endif
    // the scheduler restricts curr_head_num <= 16
    // elems in split_max_buffer and split_sum_buffer are not cache-line
    // aligned, so use local buffers to reduce false sharing
    alignas(64) float local_max[16];
    alignas(64) float local_sum[16];

    float* __restrict__ curr_split_output_buffer = split_output_buffer;
    float* __restrict__ curr_split_max_buffer = split_max_buffer;
    float* __restrict__ curr_split_sum_buffer = split_sum_buffer;
    constexpr int32_t head_dim_group_num = head_dim / 16;
    for (int32_t split_idx = 0; split_idx < split_num; ++split_idx) {
      while (!flags[split_idx]) {
#ifdef FAST_SPINNING
        FAST_SPINNING
#else
        std::this_thread::yield();
#endif
      }
      std::atomic_thread_fence(std::memory_order_acquire);

      if (split_idx > 0) {
        float* __restrict__ curr_output_buffer = split_output_buffer;
        float* __restrict__ curr_split_output_buffer_iter =
            curr_split_output_buffer;
        for (int32_t head_idx = 0; head_idx < curr_head_num; ++head_idx) {
          float final_max = local_max[head_idx];
          float curr_max = curr_split_max_buffer[head_idx];
          float final_sum = local_sum[head_idx];
          float curr_sum = curr_split_sum_buffer[head_idx];
          float* __restrict__ non_scale_output_iter =
              final_max > curr_max ? curr_output_buffer
                                   : curr_split_output_buffer_iter;
          float* __restrict__ scale_output_iter =
              final_max > curr_max ? curr_split_output_buffer_iter
                                   : curr_output_buffer;
          float rescale_factor = final_max > curr_max ? curr_max - final_max
                                                      : final_max - curr_max;
          rescale_factor = std::exp(rescale_factor);
          vec_op::FP32Vec16 rescale_factor_vec(rescale_factor);

          local_sum[head_idx] = final_max > curr_max
                                    ? final_sum + rescale_factor * curr_sum
                                    : rescale_factor * final_sum + curr_sum;

          final_max = std::max(final_max, curr_max);
          local_max[head_idx] = final_max;
          for (int32_t i = 0; i < head_dim_group_num; ++i) {
            vec_op::FP32Vec16 non_scale_vec(non_scale_output_iter);
            vec_op::FP32Vec16 scale_vec(scale_output_iter);
            vec_op::FP32Vec16 final_vec =
                non_scale_vec + scale_vec * rescale_factor_vec;
            final_vec.save(curr_output_buffer);

            non_scale_output_iter += 16;
            scale_output_iter += 16;
            curr_output_buffer += 16;
          }
          curr_split_output_buffer_iter += head_dim;
        }
      } else {
        vec_op::FP32Vec16 final_max(split_max_buffer);
        final_max.save(local_max);
        vec_op::FP32Vec16 final_sum(split_sum_buffer);
        final_sum.save(local_sum);
      }

      curr_split_output_buffer += head_num_per_split * head_dim;
      curr_split_max_buffer += head_num_per_split;
      curr_split_sum_buffer += head_num_per_split;
    }
    // write back final max and sum
    for (int32_t i = 0; i < curr_head_num; ++i) {
      split_max_buffer[i] = local_max[i];
      split_sum_buffer[i] = local_sum[i];
    }
  }

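  // Publish one split's partial result: copy the unnormalized fp32
  // accumulator and the per-head max/sum into this split's slot of the
  // reduction scratchpad, then set the ready flag after a release fence so
  // that the consuming reduction task (which spins on the flag and issues an
  // acquire fence) observes the data.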
  void partial_output(float* __restrict__ partial_output_buffer,
                      float* __restrict__ partial_max_buffer,
                      float* __restrict__ partial_sum_buffer,
                      int32_t curr_head_num,
                      float* __restrict__ split_output_buffer,
                      float* __restrict__ split_max_buffer,
                      float* __restrict__ split_sum_buffer,
                      volatile bool* __restrict__ flag) {
    float* __restrict__ curr_partial_output_buffer = partial_output_buffer;
    float* __restrict__ curr_split_output_buffer = split_output_buffer;
    constexpr int32_t head_dim_group_num = head_dim / 16;
    for (int32_t i = 0; i < curr_head_num; ++i) {
      split_max_buffer[i] = partial_max_buffer[i];
      split_sum_buffer[i] = partial_sum_buffer[i];
      for (int32_t j = 0; j < head_dim_group_num; ++j) {
        vec_op::FP32Vec16 vec(curr_partial_output_buffer);
        vec.save(curr_split_output_buffer);

        curr_partial_output_buffer += 16;
        curr_split_output_buffer += 16;
      }
    }
    std::atomic_thread_fence(std::memory_order_release);
    *flag = true;
  }

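  // Normalize and emit the final attention output: divide each head's fp32
  // accumulator by its softmax sum and cast the result back to the query /
  // output dtype.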
  void final_output(float* __restrict__ partial_q_buffer,
                    query_t* __restrict__ curr_output_buffer,
                    float* __restrict__ sum_buffer,
                    const int32_t q_heads_per_kv,
                    const int32_t actual_q_token_num,
                    const int32_t q_head_num) {
    // final output
    using output_vec_t = typename VecTypeTrait<query_t>::vec_t;

    float* __restrict__ curr_partial_output_buffer = partial_q_buffer;
    float* __restrict__ curr_sum_buffer = sum_buffer;
    constexpr int32_t group_num_per_head = head_dim / 16;
    const int32_t partial_q_buffer_stride = q_heads_per_kv * head_dim;
    const int32_t output_buffer_stride = q_head_num * head_dim;
    for (int32_t token_idx = 0; token_idx < actual_q_token_num; ++token_idx) {
      float* __restrict__ curr_partial_output_buffer_iter =
          curr_partial_output_buffer;
      query_t* __restrict__ curr_output_buffer_iter = curr_output_buffer;
      for (int32_t head_idx = 0; head_idx < q_heads_per_kv; ++head_idx) {
        vec_op::FP32Vec16 inv_sum_scale_vec(1.0 / *curr_sum_buffer);

        for (int32_t i = 0; i < group_num_per_head; ++i) {
          vec_op::FP32Vec16 vec(curr_partial_output_buffer_iter);
          // divide by the final softmax sum here
          vec = inv_sum_scale_vec * vec;

          // cast to query type
          output_vec_t output_vec(vec);
          output_vec.save(curr_output_buffer_iter);

          // update
          curr_partial_output_buffer_iter += 16;
          curr_output_buffer_iter += 16;
        }

        // update
        curr_sum_buffer += 1;
      }

      // update
      curr_partial_output_buffer += partial_q_buffer_stride;
      curr_output_buffer += output_buffer_stride;
    }
  }
};

}  // namespace cpu_attention

#endif