Files
vllm/csrc/cpu/cpu_attn_impl.hpp

2001 lines
83 KiB
C++

#ifndef CPU_ATTN_HPP
#define CPU_ATTN_HPP
#include <type_traits>
#include <cstddef>
#if defined(__APPLE__)
#include <sys/sysctl.h>
#endif
#include "cpu_types.hpp"
#include "scratchpad_manager.h"
#include "cpu_attn_macros.h"
#include "utils.hpp"
namespace cpu_attention {
enum class ISA { AMX, VEC, VEC16, NEON };
template <ISA isa, typename scalar_t, int64_t head_dim>
class AttentionImpl {};
struct AttentionWorkItemGroup {
int32_t req_id;
int32_t q_token_id_start;
int32_t q_token_num;
int32_t kv_split_pos_start;
int32_t kv_split_pos_end;
int64_t total_kv_len;
int32_t split_id;
int32_t local_split_id;
AttentionWorkItemGroup(const int32_t req_id, const int32_t q_token_id_start,
const int32_t kv_split_pos_start,
const int32_t kv_split_pos_end)
: req_id(req_id),
q_token_id_start(q_token_id_start),
q_token_num(0),
kv_split_pos_start(kv_split_pos_start),
kv_split_pos_end(kv_split_pos_end),
total_kv_len(0),
split_id(-1),
local_split_id(0) {}
std::string to_string() const {
std::stringstream ss;
ss << '[' << "req_id: " << req_id << ",\n";
ss << "q_token_id_start: " << q_token_id_start << ",\n";
ss << "q_token_num: " << q_token_num << ",\n";
ss << "kv_split_pos_start: " << kv_split_pos_start << ",\n";
ss << "kv_split_pos_end: " << kv_split_pos_end << ",\n";
ss << "total_kv_len: " << total_kv_len << ",\n";
ss << "split_id: " << split_id << ",\n";
ss << "local_split_id: " << local_split_id << ",\n";
ss << ']';
return ss.str();
}
};
struct ReductionWorkItemGroup {
int32_t req_id;
int32_t q_token_id_start;
int32_t q_token_id_num;
int32_t split_start_id;
int32_t split_num;
ReductionWorkItemGroup(const int32_t req_id, const int32_t q_token_id_start,
const int32_t q_token_id_num,
const int32_t split_start_id)
: req_id(req_id),
q_token_id_start(q_token_id_start),
q_token_id_num(q_token_id_num),
split_start_id(split_start_id),
split_num(0) {}
std::string to_string() const {
std::stringstream ss;
ss << '[' << "req_id: " << req_id << ",\n";
ss << "q_token_id_start: " << q_token_id_start << ",\n";
ss << "q_token_id_num: " << q_token_id_num << ",\n";
ss << "split_start_id: " << split_start_id << ",\n";
ss << "split_num: " << split_num << ",\n";
ss << ']';
return ss.str();
}
};
struct AttentionMetadata {
std::atomic_int64_t counter;
char _padding1[56];
ISA isa;
int32_t workitem_group_num;
int32_t reduction_item_num;
int32_t reduction_split_num;
int32_t thread_num;
int32_t effective_thread_num; // non-zero item num in workitem_num_per_thread
int32_t split_kv_q_token_num_threshold;
int64_t attention_scratchpad_size_per_thread;
int64_t reduction_scratchpad_size_per_kv_head;
AttentionWorkItemGroup* workitem_groups_ptr;
ReductionWorkItemGroup* reduction_items_ptr;
int32_t cu_workitem_num_per_thread[1025] = {
0}; // prefix sum of workitem_num_per_thread
char _padding2[56];
AttentionMetadata(ISA isa, int32_t workitem_group_num,
int32_t reduction_item_num, int32_t reduction_split_num,
int32_t split_kv_q_token_num_threshold)
: isa(isa),
workitem_group_num(workitem_group_num),
reduction_item_num(reduction_item_num),
reduction_split_num(reduction_split_num),
thread_num(omp_get_max_threads()),
effective_thread_num(thread_num),
split_kv_q_token_num_threshold(split_kv_q_token_num_threshold),
attention_scratchpad_size_per_thread(0),
reduction_scratchpad_size_per_kv_head(0),
workitem_groups_ptr(
(AttentionWorkItemGroup*)((char*)this + sizeof(AttentionMetadata))),
reduction_items_ptr(
(ReductionWorkItemGroup*)((char*)this + sizeof(AttentionMetadata) +
workitem_group_num *
sizeof(AttentionWorkItemGroup))),
counter(0) {
TORCH_CHECK_LE(thread_num, 1024);
static_assert(sizeof(AttentionMetadata) % 64 == 0);
TORCH_CHECK(reinterpret_cast<size_t>(this) % 64 == 0);
}
void reset_counter() { counter.store(0); }
int64_t acquire_counter() { return counter++; }
void print() const {
std::stringstream ss;
ss << "ISA: ";
switch (isa) {
case ISA::AMX:
ss << "AMX, ";
break;
case ISA::VEC:
ss << "VEC, ";
break;
case ISA::VEC16:
ss << "VEC16, ";
break;
case ISA::NEON:
ss << "NEON, ";
break;
}
ss << "workitem_group_num: " << workitem_group_num
<< ", reduction_item_num: " << reduction_item_num
<< ", reduction_split_num: " << reduction_split_num
<< ", thread_num: " << thread_num
<< ", effective_thread_num: " << effective_thread_num
<< ", attention_scratchpad_size_per_thread: "
<< attention_scratchpad_size_per_thread
<< ", reduction_scratchpad_size_per_kv_head: "
<< reduction_scratchpad_size_per_kv_head << ", workitem groups:\n";
for (int32_t i = 0; i < workitem_group_num; ++i) {
ss << (workitem_groups_ptr + i)->to_string() << ",\n";
}
ss << "cu_workitem_num_per_thread: [";
for (int32_t i = 0; i < thread_num + 1; ++i) {
ss << cu_workitem_num_per_thread[i] << ", ";
}
ss << "]\n";
ss << "reduction items: \n";
for (int32_t i = 0; i < reduction_item_num; ++i) {
ss << (reduction_items_ptr + i)->to_string() << ",\n";
}
std::printf("%s", ss.str().c_str());
}
};
// Thread attention scratchpad contains:
// - Q: q_tile_size * head_dim * q_buffer_elem_size, gather Q heads, especially
// for GQA
// - Q@K^T: max_num_q_per_iter * k_tile_size * logits_buffer_elem_size, logits
// - Intermediate outputs: q_tile_size * head_dim * output_buffer_elem_size + 2
// * q_tile_size * 4, partial output, max + sum (float)
// Reduction scratchpad contains:
// - flags: bool array to indicate wether the split is finished
// - outputs: split_num * q_tile_size * head_dim * output_buffer_elem_size
// - max, sum: 2 * split_num * q_tile_size * 4
class AttentionScratchPad {
public:
AttentionScratchPad(int64_t thread_id,
const AttentionMetadata& attention_metadata,
void* scratchpad_ptr)
: thread_scratchpad_ptr(
static_cast<int8_t*>(scratchpad_ptr) +
thread_id *
attention_metadata.attention_scratchpad_size_per_thread),
reduction_scratchpad_ptr(
static_cast<int8_t*>(scratchpad_ptr) +
attention_metadata.thread_num *
attention_metadata.attention_scratchpad_size_per_thread),
reduction_scratchpad_size_per_kv_head(
attention_metadata.reduction_scratchpad_size_per_kv_head) {}
// for attention
void update(const int64_t head_dim, const int64_t q_buffer_elem_size,
const int64_t logits_buffer_elem_size,
const int64_t output_buffer_elem_size,
const int64_t max_num_q_per_iter, const int64_t q_head_tile_size,
const int64_t kv_tile_size) {
int64_t buffer_offset = 0;
q_buffer_offset_ = buffer_offset;
buffer_offset +=
calcu_q_buffer_size(q_head_tile_size, head_dim, q_buffer_elem_size);
logits_buffer_offset_ = buffer_offset;
buffer_offset += calcu_logits_buffer_size(max_num_q_per_iter, kv_tile_size,
logits_buffer_elem_size);
output_buffer_offset_ = buffer_offset;
buffer_offset += calcu_partial_output_buffer_size(
q_head_tile_size, head_dim, output_buffer_elem_size);
max_buffer_offset_ = buffer_offset;
buffer_offset += calcu_partial_output_max_sum_buffer_size(q_head_tile_size);
sum_buffer_offset_ = buffer_offset;
}
// for reduction
void update(const int32_t kv_head_idx, const int32_t total_split_num,
const int64_t head_dim, const int64_t q_head_tile_size,
const int64_t output_buffer_elem_size) {
int64_t buffer_offset = kv_head_idx * reduction_scratchpad_size_per_kv_head;
reduce_flag_buffer_offset_ = buffer_offset;
buffer_offset += calcu_reduce_flag_buffer_size(total_split_num);
reduce_output_buffer_offset_ = buffer_offset;
buffer_offset += calcu_reduce_output_buffer_size(
total_split_num, q_head_tile_size, head_dim, output_buffer_elem_size);
reduce_max_buffer_offset_ = buffer_offset;
buffer_offset +=
calcu_reduce_max_sum_buffer_size(total_split_num, q_head_tile_size);
reduce_sum_buffer_offset_ = buffer_offset;
}
template <typename T>
T* get_q_buffer() {
return reinterpret_cast<T*>(thread_scratchpad_ptr + q_buffer_offset_);
}
float* get_logits_buffer() {
return reinterpret_cast<float*>(thread_scratchpad_ptr +
logits_buffer_offset_);
}
float* get_output_buffer() {
return reinterpret_cast<float*>(thread_scratchpad_ptr +
output_buffer_offset_);
}
float* get_max_buffer() {
return reinterpret_cast<float*>(thread_scratchpad_ptr + max_buffer_offset_);
}
float* get_sum_buffer() {
return reinterpret_cast<float*>(thread_scratchpad_ptr + sum_buffer_offset_);
}
volatile bool* get_reduce_flag_buffer() {
return reinterpret_cast<volatile bool*>(reduction_scratchpad_ptr +
reduce_flag_buffer_offset_);
}
float* get_reduce_output_buffer() {
return reinterpret_cast<float*>(reduction_scratchpad_ptr +
reduce_output_buffer_offset_);
}
float* get_reduce_max_buffer() {
return reinterpret_cast<float*>(reduction_scratchpad_ptr +
reduce_max_buffer_offset_);
}
float* get_reduce_sum_buffer() {
return reinterpret_cast<float*>(reduction_scratchpad_ptr +
reduce_sum_buffer_offset_);
}
int64_t get_thread_scratchpad_size() const {
return 2 * sum_buffer_offset_ - max_buffer_offset_;
}
int64_t get_reduction_scratchpad_size() const {
return 2 * reduce_sum_buffer_offset_ - reduce_max_buffer_offset_;
}
private:
static int64_t round_to_64(const int64_t num) {
return ((num + 63) >> 6) << 6;
}
static int64_t calcu_q_buffer_size(const int64_t q_tile_size,
const int64_t head_dim,
const int64_t elem_size) {
return round_to_64(q_tile_size * head_dim * elem_size);
}
static int64_t calcu_logits_buffer_size(const int64_t max_num_q_per_iter,
const int64_t k_tile_size,
const int64_t elem_size) {
return round_to_64(elem_size * max_num_q_per_iter * k_tile_size);
}
static int64_t calcu_partial_output_buffer_size(const int64_t q_tile_size,
const int64_t head_dim,
const int64_t elem_size) {
return round_to_64(q_tile_size * head_dim * elem_size);
}
static int64_t calcu_partial_output_max_sum_buffer_size(
const int64_t q_tile_size) {
return round_to_64(q_tile_size * sizeof(float));
}
static int64_t calcu_reduce_flag_buffer_size(const int64_t total_split_num) {
return round_to_64(total_split_num * sizeof(bool));
}
static int64_t calcu_reduce_max_sum_buffer_size(
const int64_t total_split_num, const int32_t q_head_tile_size) {
return round_to_64(total_split_num * q_head_tile_size * sizeof(float));
}
static int64_t calcu_reduce_output_buffer_size(
const int64_t total_split_num, const int64_t q_head_tile_size,
const int64_t head_dim, const int64_t output_buffer_elem_size) {
return round_to_64(total_split_num * q_head_tile_size * head_dim *
output_buffer_elem_size);
}
private:
int8_t* thread_scratchpad_ptr;
int8_t* reduction_scratchpad_ptr;
int64_t reduction_scratchpad_size_per_kv_head;
// attention buffers
int64_t q_buffer_offset_;
int64_t logits_buffer_offset_;
int64_t output_buffer_offset_;
int64_t max_buffer_offset_;
int64_t sum_buffer_offset_;
// reduction buffers
int64_t reduce_flag_buffer_offset_;
int64_t reduce_output_buffer_offset_;
int64_t reduce_max_buffer_offset_;
int64_t reduce_sum_buffer_offset_;
};
class AttentionScheduler {
public:
struct ScheduleInput {
int32_t num_reqs;
int32_t elem_size;
int32_t q_buffer_elem_size;
int32_t logits_buffer_elem_size;
int32_t output_buffer_elem_size;
int32_t num_heads_q;
int32_t num_heads_kv;
int32_t head_dim;
int32_t* query_start_loc;
int32_t* seq_lens;
int32_t left_sliding_window_size;
int32_t right_sliding_window_size;
bool casual;
cpu_attention::ISA isa;
int32_t max_num_q_per_iter; // max Q head num can be hold in registers
int32_t kv_block_alignment; // context length alignment requirement
bool enable_kv_split;
};
static constexpr int32_t MaxQTileIterNum = 128;
AttentionScheduler() : available_cache_size_(get_available_l2_size()) {}
torch::Tensor schedule(const ScheduleInput& input) const {
const bool casual = input.casual;
const int32_t thread_num = omp_get_max_threads();
const int64_t cache_size = get_available_l2_size();
const int32_t max_num_q_per_iter = input.max_num_q_per_iter;
const int32_t kv_len_alignment = input.kv_block_alignment;
int32_t q_head_per_kv = input.num_heads_q / input.num_heads_kv;
const bool use_gqa = (max_num_q_per_iter % q_head_per_kv == 0);
if (!use_gqa) {
q_head_per_kv = 1; // fallback to MHA
}
const int32_t min_split_kv_len =
((max_num_q_per_iter * 4 + kv_len_alignment - 1) / kv_len_alignment) *
kv_len_alignment;
const int32_t max_num_q_token_per_iter = max_num_q_per_iter / q_head_per_kv;
const int64_t default_tile_size = calcu_default_tile_size(
cache_size, input.head_dim, input.elem_size, input.q_buffer_elem_size,
input.logits_buffer_elem_size, input.output_buffer_elem_size,
max_num_q_per_iter, max_num_q_per_iter);
const int32_t default_tile_token_num = default_tile_size / q_head_per_kv;
const int32_t split_kv_q_token_num_threshold =
input.enable_kv_split ? 1 : 0;
const int32_t left_sliding_window_size = input.left_sliding_window_size;
const int32_t right_sliding_window_size = input.right_sliding_window_size;
TORCH_CHECK_LE(split_kv_q_token_num_threshold * q_head_per_kv, 16);
// get total kv len
int64_t total_kv_len = 0;
for (int32_t req_id = 0; req_id < input.num_reqs; ++req_id) {
const int32_t seq_len = input.seq_lens[req_id];
const int32_t q_token_num =
input.query_start_loc[req_id + 1] - input.query_start_loc[req_id];
const int32_t q_start_pos = (casual ? (seq_len - q_token_num) : 0);
const int32_t kv_start_pos = 0;
const int32_t kv_end_pos = seq_len;
for (int32_t token_id = 0; token_id < q_token_num;
token_id += max_num_q_token_per_iter) {
const int32_t q_tile_token_num =
std::min(max_num_q_token_per_iter, q_token_num - token_id);
const int32_t q_tile_pos_left = q_start_pos + token_id;
const int32_t q_tile_pos_right = q_tile_pos_left + q_tile_token_num;
const auto [kv_tile_pos_left, kv_tile_pos_right] = calcu_kv_tile_pos(
kv_start_pos, kv_end_pos, q_tile_pos_left, q_tile_pos_right,
left_sliding_window_size, right_sliding_window_size);
const auto [aligned_kv_tile_pos_left, aligned_kv_tile_pos_right] =
align_kv_tile_pos(kv_tile_pos_left, kv_tile_pos_right,
kv_len_alignment);
int32_t curr_kv_len =
aligned_kv_tile_pos_right - aligned_kv_tile_pos_left;
total_kv_len += curr_kv_len;
}
}
const int64_t kv_len_per_thread =
(((total_kv_len / thread_num) + kv_len_alignment - 1) /
kv_len_alignment) *
kv_len_alignment * (use_gqa ? input.num_heads_kv : input.num_heads_q);
std::vector<AttentionWorkItemGroup> workitems;
std::vector<ReductionWorkItemGroup> reduce_workitems;
workitems.reserve(1024);
reduce_workitems.reserve(1024);
std::vector<int32_t> workitem_num_per_thread(thread_num, 0);
// split tasks
int32_t curr_thread_id = 0;
int64_t remaining_kv_len = kv_len_per_thread;
int32_t cum_split_num = 0;
for (int32_t req_id = 0; req_id < input.num_reqs; ++req_id) {
const int32_t seq_len = input.seq_lens[req_id];
const int32_t q_token_num =
input.query_start_loc[req_id + 1] - input.query_start_loc[req_id];
const int32_t q_start_pos = (casual ? (seq_len - q_token_num) : 0);
const int32_t kv_start_pos = 0;
const int32_t kv_end_pos = seq_len;
int32_t local_split_id = 0;
AttentionWorkItemGroup curr_workitem(req_id, 0, 0, seq_len);
for (int32_t token_id = 0; token_id < q_token_num;
token_id += max_num_q_token_per_iter) {
const int32_t q_tile_token_num =
std::min(max_num_q_token_per_iter, q_token_num - token_id);
const int32_t q_tile_pos_left = q_start_pos + token_id;
const int32_t q_tile_pos_right = q_tile_pos_left + q_tile_token_num;
const auto [kv_tile_pos_left, kv_tile_pos_right] = calcu_kv_tile_pos(
kv_start_pos, kv_end_pos, q_tile_pos_left, q_tile_pos_right,
left_sliding_window_size, right_sliding_window_size);
const auto [aligned_kv_tile_pos_left, aligned_kv_tile_pos_right] =
align_kv_tile_pos(kv_tile_pos_left, kv_tile_pos_right,
kv_len_alignment);
int32_t curr_kv_len =
aligned_kv_tile_pos_right - aligned_kv_tile_pos_left;
int32_t kv_token_pos_start = aligned_kv_tile_pos_left;
while (curr_kv_len > 0) {
if (curr_kv_len <= (remaining_kv_len + min_split_kv_len) ||
curr_thread_id == (thread_num - 1)) {
curr_workitem.q_token_num += q_tile_token_num;
curr_workitem.total_kv_len += curr_kv_len;
remaining_kv_len -= curr_kv_len;
curr_kv_len = 0;
if (remaining_kv_len < 0) {
// stop to accept more workitems
remaining_kv_len -= min_split_kv_len;
}
if (curr_workitem.kv_split_pos_start != 0) {
// got a partial kv spilt, need to create a single workitem
curr_workitem.split_id = cum_split_num;
curr_workitem.local_split_id = local_split_id;
workitems.emplace_back(curr_workitem);
++workitem_num_per_thread[curr_thread_id];
++reduce_workitems.back().split_num;
++cum_split_num;
curr_workitem = AttentionWorkItemGroup(
req_id, token_id + max_num_q_token_per_iter, 0, seq_len);
}
break;
}
if (remaining_kv_len < min_split_kv_len &&
(curr_workitem.total_kv_len > 0 ||
workitem_num_per_thread[curr_thread_id] > 0)) {
// remaining_kv_len is too short, and have allocated workitems, just
// leave to next thread
if (curr_workitem.total_kv_len > 0) {
workitems.emplace_back(curr_workitem);
++workitem_num_per_thread[curr_thread_id];
curr_workitem =
AttentionWorkItemGroup(req_id, token_id, 0, seq_len);
}
// switch to next thread
++curr_thread_id;
remaining_kv_len = kv_len_per_thread;
// retry this iteration
continue;
}
// only split tail splits with q_tile_token_num <=
// split_kv_q_token_num_threshold
if (token_id + max_num_q_token_per_iter < q_token_num ||
q_tile_token_num > split_kv_q_token_num_threshold) {
// if requires a new q tile iteration and already has workitems,
// leave this workitem to next thread
if (curr_workitem.q_token_num % default_tile_token_num == 0 &&
(curr_workitem.total_kv_len > 0 ||
workitem_num_per_thread[curr_thread_id] > 0)) {
if (curr_workitem.total_kv_len > 0) {
workitems.emplace_back(curr_workitem);
++workitem_num_per_thread[curr_thread_id];
}
curr_workitem =
AttentionWorkItemGroup(req_id, token_id, 0, seq_len);
// switch to next thread
++curr_thread_id;
remaining_kv_len = kv_len_per_thread;
}
curr_workitem.q_token_num += q_tile_token_num;
curr_workitem.total_kv_len += curr_kv_len;
remaining_kv_len -= curr_kv_len;
curr_kv_len = 0;
break;
}
// split kv
if (curr_workitem.total_kv_len > 0) {
// write back curr workitem
workitems.emplace_back(curr_workitem);
++workitem_num_per_thread[curr_thread_id];
}
if (kv_token_pos_start == aligned_kv_tile_pos_left) {
// first split, init the workitem
reduce_workitems.emplace_back(ReductionWorkItemGroup(
req_id, token_id, q_tile_token_num, cum_split_num));
}
int32_t spilt_size =
std::min(std::max(remaining_kv_len, (int64_t)min_split_kv_len),
(int64_t)curr_kv_len);
curr_workitem =
AttentionWorkItemGroup(req_id, token_id, kv_token_pos_start,
kv_token_pos_start + spilt_size);
curr_workitem.q_token_num += q_tile_token_num;
curr_workitem.total_kv_len += spilt_size;
curr_workitem.split_id = cum_split_num;
curr_workitem.local_split_id = local_split_id;
workitems.emplace_back(curr_workitem);
++workitem_num_per_thread[curr_thread_id];
++reduce_workitems.back().split_num;
++cum_split_num;
++local_split_id;
kv_token_pos_start += spilt_size;
curr_kv_len -= spilt_size;
curr_workitem = AttentionWorkItemGroup(req_id, token_id,
kv_token_pos_start, seq_len);
// switch to next thread
++curr_thread_id;
remaining_kv_len = kv_len_per_thread;
}
}
if (curr_workitem.total_kv_len > 0) {
// write back curr workitem
workitems.emplace_back(curr_workitem);
++workitem_num_per_thread[curr_thread_id];
}
}
int64_t metadata_tensor_size =
sizeof(AttentionMetadata) +
workitems.size() * sizeof(AttentionWorkItemGroup) +
reduce_workitems.size() * sizeof(ReductionWorkItemGroup);
auto options =
torch::TensorOptions().dtype(torch::kInt8).device(torch::kCPU);
torch::Tensor metadata_tensor =
torch::empty({metadata_tensor_size}, options);
AttentionMetadata* metadata_ptr = new (metadata_tensor.data_ptr())
AttentionMetadata(input.isa, workitems.size(), reduce_workitems.size(),
cum_split_num, split_kv_q_token_num_threshold);
AttentionWorkItemGroup* workitem_groups_ptr =
metadata_ptr->workitem_groups_ptr;
ReductionWorkItemGroup* reduction_items_ptr =
metadata_ptr->reduction_items_ptr;
std::memcpy(workitem_groups_ptr, workitems.data(),
workitems.size() * sizeof(AttentionWorkItemGroup));
std::memcpy(reduction_items_ptr, reduce_workitems.data(),
reduce_workitems.size() * sizeof(ReductionWorkItemGroup));
int32_t effective_thread_num = 0;
for (; effective_thread_num < thread_num; ++effective_thread_num) {
if (workitem_num_per_thread[effective_thread_num] == 0) {
break;
}
}
std::memcpy(metadata_ptr->cu_workitem_num_per_thread + 1,
workitem_num_per_thread.data(),
workitem_num_per_thread.size() * sizeof(int32_t));
for (int32_t i = 1; i <= thread_num; ++i) {
metadata_ptr->cu_workitem_num_per_thread[i] +=
metadata_ptr->cu_workitem_num_per_thread[i - 1];
}
metadata_ptr->effective_thread_num = effective_thread_num;
{
// when q_tile_size = max_num_q_per_iter, requires max
// attention_scratchpad_size
AttentionScratchPad sc(0, *metadata_ptr, 0x0);
int64_t n = AttentionScheduler::calcu_tile_size_with_constant_q(
cache_size, input.head_dim, input.elem_size, input.q_buffer_elem_size,
input.logits_buffer_elem_size, input.output_buffer_elem_size,
max_num_q_per_iter, kv_len_alignment, max_num_q_per_iter, true);
sc.update(input.head_dim, input.q_buffer_elem_size,
input.logits_buffer_elem_size, input.output_buffer_elem_size,
max_num_q_per_iter, max_num_q_per_iter, n);
metadata_ptr->attention_scratchpad_size_per_thread =
((sc.get_thread_scratchpad_size() + 63) / 64) * 64;
sc.update(0, metadata_ptr->reduction_split_num, input.head_dim,
q_head_per_kv * split_kv_q_token_num_threshold,
input.output_buffer_elem_size);
metadata_ptr->reduction_scratchpad_size_per_kv_head =
((sc.get_reduction_scratchpad_size() + 63) / 64) * 64;
}
int64_t scratchpad_size =
metadata_ptr->attention_scratchpad_size_per_thread *
metadata_ptr->thread_num +
metadata_ptr->reduction_scratchpad_size_per_kv_head *
(use_gqa ? input.num_heads_kv : input.num_heads_q);
DNNLScratchPadManager::get_dnnl_scratchpad_manager()->realloc(
scratchpad_size);
// metadata_ptr->print();
// test out of boundary access
// {
// float* cache_ptr =
// DNNLScratchPadManager::get_dnnl_scratchpad_manager()->get_data<float>();
// for (int64_t i = 0; i < scratchpad_size / sizeof(float); ++i) {
// cache_ptr[i] = std::numeric_limits<float>::quiet_NaN();
// }
// }
return metadata_tensor;
}
FORCE_INLINE static std::pair<int32_t, int32_t> calcu_kv_tile_pos(
int32_t kv_left_pos, int32_t kv_right_pos, int32_t q_left_pos,
int32_t q_right_pos, int32_t sliding_window_left,
int32_t sliding_window_right) {
if (sliding_window_left != -1) {
kv_left_pos = std::max(kv_left_pos, q_left_pos - sliding_window_left);
}
if (sliding_window_right != -1) {
kv_right_pos = std::min(kv_right_pos, q_right_pos + sliding_window_right);
}
return {kv_left_pos, kv_right_pos};
}
FORCE_INLINE static std::pair<int32_t, int32_t> align_kv_tile_pos(
int32_t kv_left_pos, int32_t kv_right_pos, int32_t align_factor) {
kv_left_pos = (kv_left_pos / align_factor) * align_factor;
kv_right_pos =
((kv_right_pos + align_factor - 1) / align_factor) * align_factor;
return {kv_left_pos, kv_right_pos};
}
static int64_t calcu_default_tile_size(int64_t cache_size, int64_t head_dim,
int64_t elem_size,
int64_t q_buffer_elem_size,
int64_t logits_buffer_elem_size,
int64_t output_buffer_elem_size,
int64_t max_num_q_per_iter,
int64_t round_size) {
// For CPU, different from CUDA, Q@K^T results should also be hold in cache,
// using float32. Intermediate outputs should be float32 to be compatible
// with AMX Then the cache includes:
// - Q: q_tile_size * head_dim * q_buffer_elem_size
// - K, V: 2 * k_tile_size * head_dim * elem_size
// - Q@K^T: max_num_q_per_iter * k_tile_size * logits_buffer_elem_size
// - Intermediate outputs: q_tile_size * head_dim * output_buffer_elem_size
// By default, let tile_size = q_tile_size = k_tile_size. To record
// is_first_iter states in a static array, require the default tile <= 128 *
// max_num_q_per_iter
int64_t tile_size =
cache_size / (head_dim * (q_buffer_elem_size + 2 * elem_size +
output_buffer_elem_size) +
max_num_q_per_iter * logits_buffer_elem_size);
tile_size = std::min(tile_size, MaxQTileIterNum * max_num_q_per_iter);
int64_t rounded_tile_size = (tile_size / round_size) * round_size;
return std::max(rounded_tile_size, round_size);
}
static int64_t calcu_tile_size_with_constant_q(
int64_t cache_size, int64_t head_dim, int64_t elem_size,
int64_t q_buffer_elem_size, int64_t logits_buffer_elem_size,
int64_t output_buffer_elem_size, int64_t max_num_q_per_iter,
int64_t round_size, int64_t q_tile_size, bool one_round) {
// calculate tile_size with known q_tile_size
// If one_round is True, the outer Q tile loop time is 1, then the K,V will
// not be included in the cache
int64_t tile_size;
if (one_round) {
tile_size =
(cache_size - q_tile_size * head_dim *
(q_buffer_elem_size + output_buffer_elem_size)) /
(logits_buffer_elem_size * max_num_q_per_iter);
} else {
tile_size =
(cache_size - q_tile_size * head_dim *
(q_buffer_elem_size + output_buffer_elem_size)) /
(logits_buffer_elem_size * max_num_q_per_iter +
2 * head_dim * elem_size);
}
int64_t rounded_tile_size = (tile_size / round_size) * round_size;
return std::max(rounded_tile_size, round_size);
}
static int64_t get_available_l2_size() {
static int64_t size = []() {
#if defined(__APPLE__)
// macOS doesn't have _SC_LEVEL2_CACHE_SIZE. Use sysctlbyname.
int64_t l2_cache_size = 0;
size_t len = sizeof(l2_cache_size);
if (sysctlbyname("hw.l2cachesize", &l2_cache_size, &len, NULL, 0) == 0 &&
l2_cache_size > 0) {
return l2_cache_size >> 1; // use 50% of L2 cache
}
// Fallback if sysctlbyname fails
return 128LL * 1024 >> 1; // use 50% of 128KB
#else
long l2_cache_size = sysconf(_SC_LEVEL2_CACHE_SIZE);
TORCH_CHECK_NE(l2_cache_size, -1);
return l2_cache_size >> 1; // use 50% of L2 cache
#endif
}();
return size;
}
private:
int64_t available_cache_size_;
};
struct AttentionInput {
AttentionMetadata* metadata;
int32_t num_tokens;
int32_t num_heads;
int32_t num_kv_heads;
int32_t block_size;
void* query;
int64_t query_num_tokens_stride;
int64_t query_num_heads_stride;
int64_t cache_num_blocks_stride;
int64_t cache_num_kv_heads_stride;
int64_t blt_num_tokens_stride;
void* key_cache;
void* value_cache;
void* output;
int32_t* query_start_loc;
int32_t* seq_lens;
int32_t* block_table;
float* alibi_slopes;
c10::BFloat16* s_aux;
float scale;
bool causal;
int32_t sliding_window_left;
int32_t sliding_window_right;
float softcap;
};
#define DEFINE_CPU_ATTENTION_PARAMS \
q_buffer_t *__restrict__ q_heads_buffer, \
kv_cache_t *__restrict__ k_head_cache_ptr, \
kv_cache_t *__restrict__ v_head_cache_ptr, \
logits_buffer_t *__restrict__ logits_buffer, \
float *__restrict__ partial_q_buffer, float *__restrict__ max_buffer, \
float *__restrict__ sum_buffer, int32_t *__restrict__ block_table, \
const int32_t kv_tile_start_pos, const int32_t kv_tile_end_pos, \
const int32_t kv_tile_token_num, \
const int64_t kv_cache_num_blocks_stride, const int32_t q_head_num, \
const int32_t q_token_num, const int32_t q_tile_start_pos, \
const int32_t q_heads_per_kv, const int32_t block_size, \
const int32_t left_window_size, const int32_t right_window_size, \
float scale, const float softcap_scale, \
const float *__restrict__ alibi_slopes, const bool is_first_iter, \
const bool use_sink, const bool debug_info
#define CPU_ATTENTION_PARAMS \
q_heads_buffer, k_head_cache_ptr, v_head_cache_ptr, logits_buffer, \
partial_q_buffer, max_buffer, sum_buffer, block_table, \
kv_tile_start_pos, kv_tile_end_pos, kv_tile_token_num, \
kv_cache_num_blocks_stride, q_head_num, q_token_num, q_tile_start_pos, \
q_heads_per_kv, block_size, left_window_size, right_window_size, scale, \
softcap_scale, alibi_slopes, is_first_iter, use_sink, debug_info
enum class AttentionGemmPhase { QK, PV };
template <typename T>
struct VecTypeTrait {
using vec_t = void;
};
template <>
struct VecTypeTrait<float> {
using vec_t = vec_op::FP32Vec16;
};
// ARM only supports BF16 with ARMv8.6-A extension
#if (defined(__aarch64__) && !defined(ARM_BF16_SUPPORT))
#else
template <>
struct VecTypeTrait<c10::BFloat16> {
using vec_t = vec_op::BF16Vec16;
};
#endif
#if !defined(__powerpc__) && !defined(__s390x__)
template <>
struct VecTypeTrait<c10::Half> {
using vec_t = vec_op::FP16Vec16;
};
#endif
template <typename T>
void print_logits(const char* name, T* ptr, int32_t row, int32_t col,
int32_t stride) {
std::stringstream ss;
ss << std::fixed << std::setprecision(5) << name << ": [\n";
auto* curr_logits_buffer = ptr;
for (int32_t m = 0; m < row; ++m) {
for (int32_t n = 0; n < col; ++n) {
ss << curr_logits_buffer[n] << ", ";
}
ss << "\n";
curr_logits_buffer += stride;
}
ss << "]\n";
std::printf("%s", ss.str().c_str());
}
template <typename attention_impl_t>
class AttentionMainLoop {
public:
using query_t = typename attention_impl_t::query_t;
using q_buffer_t = typename attention_impl_t::q_buffer_t;
using kv_cache_t = typename attention_impl_t::kv_cache_t;
using logits_buffer_t = typename attention_impl_t::logits_buffer_t;
using partial_output_buffer_t =
typename attention_impl_t::partial_output_buffer_t;
using prob_buffer_t = typename attention_impl_t::prob_buffer_t;
static constexpr int64_t max_q_head_num_per_iter =
attention_impl_t::MaxQHeadNumPerIteration;
static constexpr int64_t blocksize_alignment =
attention_impl_t::BlockSizeAlignment;
static constexpr int64_t headdim_alignment =
attention_impl_t::HeadDimAlignment;
static constexpr int64_t head_dim = attention_impl_t::HeadDim;
static constexpr ISA ISAType = attention_impl_t::ISAType;
static constexpr bool scale_on_logits =
attention_impl_t::scale_on_logits; // apply scale on logits, otherwise
// apply scale on q_buffer
template <typename tile_gemm_t>
class Attention {
public:
// Args:
// - q_heads_buffer: [MaxQHeadNumPerIteration, head_dim]
// - k_head_cache_ptr: [num_blocks, block_size * head_dim]
// - v_head_cache_ptr: [num_blocks, block_size * head_dim]
// - logits_buffer: [MaxQHeadNumPerIteration, kv_tile_token_num], store Q@K
// - logits partial_q_buffer: [MaxQHeadNumPerIteration, head_dim], store
// partial output
// - max_buffer: [MaxQHeadNumPerIteration, 1], store max logits
// - sum_buffer: [MaxQHeadNumPerIteration, 1], store sum of exp
// - block_table
// - kv_tile_start_pos: start position of KV cache, aligned to
// BlockSizeAlignment
// - kv_tile_end_pos: end position of KV cache, aligned to
// BlockSizeAlignment
// - kv_tile_token_num: KV token num, aligned to BlockSizeAlignment
// - kv_cache_num_blocks_stride
// - q_head_num: head num of q_tile
// - q_token_num: token num of q_tile, should be q_head_num /
// q_heads_per_kv
// - q_tile_start_pos: start pos of the first token in q_heads_buffer
// - q_heads_per_kv
// - block_size
// - left_window_size
// - right_window_size
// - scale
// - softcap_scale
// - alibi_slopes
// - is_first_iter
// - use_sink
// - debug_info
void operator()(DEFINE_CPU_ATTENTION_PARAMS) {
// k_cache_token_group_stride: stride of K cache when move to next
// BlockSizeAlignment tokens in a block
const int64_t k_cache_token_group_stride =
attention_impl_t::k_cache_token_group_stride(block_size);
// v_cache_token_group_stride: stride of V cache when move to next
// BlockSizeAlignment tokens in a block
const int64_t v_cache_token_group_stride =
attention_impl_t::v_cache_token_group_stride(block_size);
// v_cache_head_group_stride: stride of V cache when move to next
// HeadDimAlignment head dims in a block
const int64_t v_cache_head_group_stride =
attention_impl_t::v_cache_head_group_stride(block_size);
const int32_t token_group_num = kv_tile_token_num / blocksize_alignment;
const int32_t token_group_num_per_block =
block_size / blocksize_alignment;
const int32_t start_block_idx = kv_tile_start_pos / block_size;
const int32_t start_block_offset = kv_tile_start_pos % block_size;
const int32_t start_block_group_offset =
start_block_offset / blocksize_alignment;
const int32_t end_block_idx =
(kv_tile_start_pos + kv_tile_token_num - 1) / block_size + 1;
// compute Q@K logits
{
int32_t curr_group_offset =
start_block_group_offset * k_cache_token_group_stride;
int32_t curr_group_num_in_block =
token_group_num_per_block - start_block_group_offset;
int32_t remaining_group_num = token_group_num;
logits_buffer_t* curr_logits_buffer = logits_buffer;
for (int32_t block_idx = start_block_idx; block_idx < end_block_idx;
++block_idx) {
int32_t physical_block_idx = block_table[block_idx];
kv_cache_t* k_cache_block_ptr =
k_head_cache_ptr +
physical_block_idx * kv_cache_num_blocks_stride +
curr_group_offset;
curr_group_num_in_block =
std::min(remaining_group_num, curr_group_num_in_block);
for (int32_t block_group_idx = 0;
block_group_idx < curr_group_num_in_block; ++block_group_idx) {
// logits_tile = q_tile @ k_tile, [MaxQHeadNumPerIteration,
// BlockSizeAlignment] = [MaxQHeadNumPerIteration, head_dim] @
// [head_dim, BlockSizeAlignment]
// By default, logits_buffer, q_buffer and k_cache are row-major,
// but may be packed by ISA implementation.
tile_gemm_t::template gemm<AttentionGemmPhase::QK, head_dim>(
q_head_num, q_heads_buffer, k_cache_block_ptr,
curr_logits_buffer, head_dim, block_size, kv_tile_token_num,
block_size, head_dim, false);
if constexpr (scale_on_logits) {
float* __restrict__ scale_curr_logits_buffer = curr_logits_buffer;
vec_op::FP32Vec16 scale_vec(scale);
for (int32_t i = 0; i < q_head_num; ++i) {
static_assert(blocksize_alignment % 16 == 0);
constexpr int32_t vec_num = blocksize_alignment / 16;
vec_op::unroll_loop<int32_t, vec_num>([&](int32_t vec_idx) {
vec_op::FP32Vec16 vec(scale_curr_logits_buffer +
vec_idx * 16);
vec = vec * scale_vec;
vec.save(scale_curr_logits_buffer + vec_idx * 16);
});
scale_curr_logits_buffer += kv_tile_token_num;
}
}
// Move buffer ptrs
k_cache_block_ptr += k_cache_token_group_stride;
curr_logits_buffer += blocksize_alignment;
}
// Update
remaining_group_num -= curr_group_num_in_block;
curr_group_offset = 0;
curr_group_num_in_block = token_group_num_per_block;
}
}
// process logits
{
// if (debug_info){
// print_logits("raw logits", logits_buffer, q_head_num,
// kv_tile_token_num, kv_tile_token_num);
// }
if (softcap_scale != 0.0f) {
apply_softcap(logits_buffer, kv_tile_token_num, q_head_num,
kv_tile_token_num, softcap_scale);
// print_logits("softcap raw logits", logits_buffer, q_head_num,
// kv_tile_token_num, kv_tile_token_num);
}
if (alibi_slopes != nullptr) {
apply_alibi_slopes(logits_buffer, alibi_slopes, kv_tile_token_num,
q_tile_start_pos, kv_tile_start_pos, q_token_num,
kv_tile_token_num, q_heads_per_kv);
// print_logits("alibi raw logits", logits_buffer, q_head_num,
// kv_tile_token_num, kv_tile_token_num);
}
apply_mask(logits_buffer, kv_tile_token_num, q_tile_start_pos,
kv_tile_start_pos, kv_tile_end_pos, q_token_num,
q_heads_per_kv, left_window_size, right_window_size);
// if (debug_info){
// print_logits("masked logits", logits_buffer, q_head_num,
// kv_tile_token_num, kv_tile_token_num);
// print_logits("old_max", max_buffer, 1, q_head_num, q_head_num);
// print_logits("old_sum", sum_buffer, 1, q_head_num, q_head_num);
// }
apply_softmax(logits_buffer, partial_q_buffer, max_buffer, sum_buffer,
kv_tile_token_num, q_head_num, kv_tile_token_num,
is_first_iter, use_sink);
// if (debug_info){
// print_logits("softmax logits",
// reinterpret_cast<prob_buffer_t*>(logits_buffer), q_head_num,
// kv_tile_token_num, kv_tile_token_num * sizeof(logits_buffer_t) /
// sizeof(prob_buffer_t));
// print_logits("new_max", max_buffer, 1, q_head_num, q_head_num);
// print_logits("new_sum", sum_buffer, 1, q_head_num, q_head_num);
// }
}
// compute P@V
{
int32_t curr_group_offset =
start_block_group_offset * v_cache_token_group_stride;
int32_t curr_group_num_in_block =
token_group_num_per_block - start_block_group_offset;
int32_t remaining_group_num = token_group_num;
int32_t head_dim_group_num = head_dim / headdim_alignment;
prob_buffer_t* curr_prob_buffer =
reinterpret_cast<prob_buffer_t*>(logits_buffer);
int64_t prob_buffer_stride =
kv_tile_token_num *
(sizeof(logits_buffer_t) / sizeof(prob_buffer_t));
partial_output_buffer_t* curr_partial_q_buffer = partial_q_buffer;
bool accum_c = !is_first_iter;
for (int32_t block_idx = start_block_idx; block_idx < end_block_idx;
++block_idx) {
int32_t physical_block_idx = block_table[block_idx];
kv_cache_t* v_cache_block_ptr =
v_head_cache_ptr +
physical_block_idx * kv_cache_num_blocks_stride +
curr_group_offset;
curr_group_num_in_block =
std::min(remaining_group_num, curr_group_num_in_block);
int32_t curr_token_num =
curr_group_num_in_block * blocksize_alignment;
for (int32_t head_dim_group_idx = 0;
head_dim_group_idx < head_dim_group_num; ++head_dim_group_idx) {
// output_tile = p_tile @ v_tile, [MaxQHeadNumPerIteration,
// HeadDimAlignment] = [MaxQHeadNumPerIteration, block_size] @
// [block_size, HeadDimAlignment]
tile_gemm_t::template gemm<AttentionGemmPhase::PV, -1>(
q_head_num, curr_prob_buffer, v_cache_block_ptr,
curr_partial_q_buffer, prob_buffer_stride, head_dim, head_dim,
block_size, curr_token_num, accum_c);
// Update
curr_partial_q_buffer += headdim_alignment;
v_cache_block_ptr += v_cache_head_group_stride;
}
// Update
remaining_group_num -= curr_group_num_in_block;
curr_group_offset = 0;
curr_group_num_in_block = token_group_num_per_block;
curr_prob_buffer += curr_token_num;
curr_partial_q_buffer = partial_q_buffer;
accum_c = true;
}
}
// if (debug_info) {
// print_logits("output", partial_q_buffer, q_head_num, head_dim,
// head_dim);
// }
}
void apply_mask(logits_buffer_t* __restrict__ logits_buffer,
const int64_t logits_buffer_stride,
const int32_t q_tile_start_pos,
const int32_t kv_tile_start_pos,
const int32_t kv_tile_end_pos, const int32_t q_token_num,
const int32_t q_heads_per_kv,
const int32_t sliding_window_left,
const int32_t sliding_window_right) {
// Apply mask
constexpr logits_buffer_t neg_inf =
-std::numeric_limits<logits_buffer_t>::infinity();
logits_buffer_t* __restrict__ curr_logits_buffer = logits_buffer;
int32_t curr_token_pos = q_tile_start_pos;
for (int32_t token_idx = 0; token_idx < q_token_num; ++token_idx) {
int32_t left_kv_pos = [&]() {
int32_t pos = kv_tile_start_pos;
if (sliding_window_left != -1) {
pos = std::max(pos, curr_token_pos - sliding_window_left);
}
return pos;
}();
int32_t right_kv_pos = [&]() {
int32_t pos = kv_tile_end_pos;
if (sliding_window_right != -1) {
pos = std::min(pos,
std::max(kv_tile_start_pos,
curr_token_pos + sliding_window_right + 1));
}
return pos;
}();
int32_t left_invalid_token_num = left_kv_pos - kv_tile_start_pos;
int32_t right_invalid_token_num = kv_tile_end_pos - right_kv_pos;
for (int32_t head_idx = 0; head_idx < q_heads_per_kv; ++head_idx) {
logits_buffer_t* __restrict__ curr_logits_buffer_tail =
curr_logits_buffer + right_kv_pos - kv_tile_start_pos;
for (int32_t i = 0; i < left_invalid_token_num; ++i) {
curr_logits_buffer[i] = neg_inf;
}
for (int32_t i = 0; i < right_invalid_token_num; ++i) {
curr_logits_buffer_tail[i] = neg_inf;
}
curr_logits_buffer += logits_buffer_stride;
}
++curr_token_pos;
}
}
void apply_softmax(logits_buffer_t* __restrict__ logits_buffer,
float* __restrict__ partial_q_buffer,
float* __restrict__ max_buffer,
float* __restrict__ sum_buffer,
const int64_t logits_buffer_stride, int32_t q_head_num,
int32_t kv_tile_token_num, bool is_first_iter,
bool use_sink) {
#ifdef DEFINE_FAST_EXP
DEFINE_FAST_EXP
#endif
using prob_buffer_vec_t = typename VecTypeTrait<prob_buffer_t>::vec_t;
static_assert(sizeof(prob_buffer_t) <= sizeof(logits_buffer_t));
logits_buffer_t* __restrict__ curr_logits_buffer = logits_buffer;
float* __restrict__ curr_partial_q_buffer = partial_q_buffer;
const int32_t vec_num = kv_tile_token_num / 16;
const int32_t head_vec_num = head_dim / 16;
for (int32_t i = 0; i < q_head_num; ++i) {
float init_max_val = max_buffer[i];
float init_sum_val = sum_buffer[i];
// apply scale and compute max
vec_op::FP32Vec16 max_vec(init_max_val);
{
logits_buffer_t* __restrict__ curr_logits_buffer_iter =
curr_logits_buffer;
for (int32_t j = 0; j < vec_num; ++j) {
vec_op::FP32Vec16 vec(curr_logits_buffer_iter);
max_vec = vec.max(max_vec);
curr_logits_buffer_iter += 16;
}
}
float new_max_val = max_vec.reduce_max();
float rescale_factor = init_max_val - new_max_val;
// use same rescale threshold with FA4.
// https://github.com/Dao-AILab/flash-attention/blob/1b8e1e641c6a179be9a0538b7f40fd595050b735/flash_attn/cute/flash_fwd_sm100.py#L1271
bool need_rescale = rescale_factor < -8.0;
if (!need_rescale) {
new_max_val = init_max_val;
} else {
max_buffer[i] = new_max_val;
}
// sub max, compute exp and sum
max_vec = vec_op::FP32Vec16(new_max_val);
vec_op::FP32Vec16 sum_vec(0.0);
{
logits_buffer_t* __restrict__ curr_logits_buffer_iter =
curr_logits_buffer;
prob_buffer_t* __restrict__ curr_prob_buffer_iter =
reinterpret_cast<prob_buffer_t*>(curr_logits_buffer);
for (int32_t j = 0; j < vec_num; ++j) {
vec_op::FP32Vec16 vec(curr_logits_buffer_iter);
vec = vec - max_vec;
// compute exp
#ifdef DEFINE_FAST_EXP
vec = fast_exp(vec);
prob_buffer_vec_t output_vec(vec);
output_vec.save(curr_prob_buffer_iter);
#else
vec.save(curr_logits_buffer_iter);
for (int32_t k = 0; k < 16; ++k) {
curr_logits_buffer_iter[k] = std::exp(curr_logits_buffer_iter[k]);
}
vec = vec_op::FP32Vec16(curr_logits_buffer_iter);
#endif
sum_vec = sum_vec + vec;
curr_logits_buffer_iter += 16;
curr_prob_buffer_iter += 16;
}
}
float new_sum_val = sum_vec.reduce_sum();
// rescale sum and partial outputs
if (need_rescale) {
// compute rescale factor
rescale_factor = std::exp(rescale_factor);
vec_op::FP32Vec16 rescale_factor_vec(rescale_factor);
// rescale sum
new_sum_val += rescale_factor * init_sum_val;
// rescale output
if (!is_first_iter) {
float* __restrict__ curr_partial_q_buffer_iter =
curr_partial_q_buffer;
for (int32_t j = 0; j < head_vec_num; ++j) {
vec_op::FP32Vec16 vec(curr_partial_q_buffer_iter);
vec = vec * rescale_factor_vec;
vec.save(curr_partial_q_buffer_iter);
curr_partial_q_buffer_iter += 16;
}
}
} else {
new_sum_val += init_sum_val;
}
sum_buffer[i] = new_sum_val;
curr_logits_buffer += logits_buffer_stride;
curr_partial_q_buffer += head_dim;
}
}
void apply_softcap(logits_buffer_t* __restrict__ logits_buffer,
const int64_t logits_buffer_stride, int32_t q_head_num,
int32_t kv_tile_token_num, float softcap_scale) {
#ifdef DEFINE_FAST_EXP
DEFINE_FAST_EXP
#endif
float inv_softcap_scale = 1.0 / softcap_scale;
vec_op::FP32Vec16 softcap_scale_vec(softcap_scale);
vec_op::FP32Vec16 inv_softcap_scale_vec(inv_softcap_scale);
vec_op::FP32Vec16 ones_vec(1.0);
logits_buffer_t* __restrict__ curr_logits_buffer = logits_buffer;
const int32_t vec_num = kv_tile_token_num / 16;
for (int32_t i = 0; i < q_head_num; ++i) {
logits_buffer_t* __restrict__ curr_logits_buffer_iter =
curr_logits_buffer;
for (int32_t j = 0; j < vec_num; ++j) {
vec_op::FP32Vec16 vec(curr_logits_buffer_iter);
vec = vec * inv_softcap_scale_vec;
#ifdef DEFINE_FAST_EXP
vec = fast_exp(vec);
vec_op::FP32Vec16 inv_vec = ones_vec / vec;
vec = (vec - inv_vec) / (vec + inv_vec);
#else
vec.save(curr_logits_buffer_iter);
for (int k = 0; k < 16; ++k) {
curr_logits_buffer_iter[k] = std::tanh(curr_logits_buffer_iter[k]);
}
vec = vec_op::FP32Vec16(curr_logits_buffer_iter);
#endif
vec = vec * softcap_scale_vec;
vec.save(curr_logits_buffer_iter);
curr_logits_buffer_iter += 16;
}
curr_logits_buffer += logits_buffer_stride;
}
}
void apply_alibi_slopes(logits_buffer_t* __restrict__ logits_buffer,
const float* __restrict__ alibi_slopes,
const int64_t logits_buffer_stride,
const int32_t q_tile_start_pos,
const int32_t kv_tile_start_pos,
const int32_t q_token_num,
const int32_t kv_tile_token_num,
const int32_t q_heads_per_kv) {
alignas(64) constexpr float initial_arange_vals[16] = {
0.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f,
8.0f, 9.0f, 10.0f, 11.0f, 12.0f, 13.0f, 14.0f, 15.0f};
const int32_t vec_num = kv_tile_token_num / 16;
vec_op::FP32Vec16 initial_arange_vals_vec(initial_arange_vals);
initial_arange_vals_vec =
initial_arange_vals_vec + vec_op::FP32Vec16((float)kv_tile_start_pos);
vec_op::FP32Vec16 pos_offset_vec(16.0);
logits_buffer_t* __restrict__ curr_logits_buffer = logits_buffer;
for (int32_t i = 0; i < q_token_num; ++i) {
vec_op::FP32Vec16 curr_q_pos_vec((float)(i + q_tile_start_pos));
for (int32_t j = 0; j < q_heads_per_kv; ++j) {
vec_op::FP32Vec16 alibi_scale_vec(alibi_slopes[j]);
vec_op::FP32Vec16 curr_kv_pos_vec(initial_arange_vals_vec);
logits_buffer_t* __restrict__ curr_logits_buffer_iter =
curr_logits_buffer;
for (int32_t k = 0; k < vec_num; ++k) {
vec_op::FP32Vec16 alibi_bias_vec =
alibi_scale_vec * (curr_kv_pos_vec - curr_q_pos_vec);
vec_op::FP32Vec16 vec(curr_logits_buffer_iter);
vec = vec + alibi_bias_vec;
vec.save(curr_logits_buffer_iter);
curr_kv_pos_vec = curr_kv_pos_vec + pos_offset_vec;
curr_logits_buffer_iter += 16;
}
curr_logits_buffer += logits_buffer_stride;
}
}
}
};
public:
void operator()(const AttentionInput* input) {
const int thread_num = omp_get_max_threads();
TORCH_CHECK_EQ(input->metadata->thread_num, thread_num);
std::atomic<int32_t> guard_counter(0);
std::atomic<int32_t>* guard_counter_ptr = &guard_counter;
#pragma omp parallel for schedule(static, 1)
for (int thread_id = 0; thread_id < thread_num; ++thread_id) {
AttentionMetadata& metadata = *input->metadata;
if (metadata.workitem_group_num == 0) {
continue;
}
attention_impl_t attn_impl;
// general information
const int32_t q_head_num = input->num_heads;
const int32_t kv_head_num = input->num_kv_heads;
const int32_t q_heads_per_kv = q_head_num / kv_head_num;
const bool use_gqa =
(max_q_head_num_per_iter % q_heads_per_kv == 0) ? true : false;
const int32_t actual_kv_head_num = use_gqa ? kv_head_num : q_head_num;
const int32_t actual_q_heads_per_kv = use_gqa ? q_heads_per_kv : 1;
TORCH_CHECK_LE(actual_q_heads_per_kv, max_q_head_num_per_iter);
const int32_t max_q_token_num_per_iter =
max_q_head_num_per_iter / actual_q_heads_per_kv;
const int64_t q_token_num_stride = input->query_num_tokens_stride;
const int64_t q_head_num_stride = input->query_num_heads_stride;
const int64_t kv_cache_head_num_stride = input->cache_num_kv_heads_stride;
const int64_t kv_cache_block_num_stride = input->cache_num_blocks_stride;
const int32_t sliding_window_left = input->sliding_window_left;
const int32_t sliding_window_right = input->sliding_window_right;
const int32_t block_size = input->block_size;
const float scale = input->scale;
const float softcap_scale = input->softcap;
const float* alibi_slopes = input->alibi_slopes;
const c10::BFloat16* s_aux = input->s_aux;
const bool casual = input->causal;
int32_t* const block_table = input->block_table;
const int64_t block_table_stride = input->blt_num_tokens_stride;
// init buffers
void* scratchpad_ptr =
DNNLScratchPadManager::get_dnnl_scratchpad_manager()
->get_data<void>();
AttentionScratchPad buffer_manager(thread_id, metadata, scratchpad_ptr);
const int32_t total_reduction_split_num = metadata.reduction_split_num;
if (metadata.reduction_split_num > 0) {
// reset split flag
for (int32_t head_idx = thread_id; head_idx < actual_kv_head_num;
head_idx += thread_num) {
buffer_manager.update(head_idx, total_reduction_split_num, head_dim,
0, sizeof(partial_output_buffer_t));
volatile bool* __restrict__ curr_flag_ptr =
buffer_manager.get_reduce_flag_buffer();
for (int32_t split_idx = 0; split_idx < total_reduction_split_num;
++split_idx) {
curr_flag_ptr[split_idx] = false;
}
}
}
const int64_t available_cache_size =
AttentionScheduler::get_available_l2_size();
const int32_t default_tile_size =
AttentionScheduler::calcu_default_tile_size(
available_cache_size, head_dim, sizeof(kv_cache_t),
sizeof(q_buffer_t), sizeof(logits_buffer_t),
sizeof(partial_output_buffer_t), max_q_head_num_per_iter,
max_q_head_num_per_iter);
const int32_t default_q_tile_token_num =
default_tile_size / actual_q_heads_per_kv;
AttentionWorkItemGroup* const workitem_groups =
metadata.workitem_groups_ptr;
const int32_t* cu_workitem_num_per_thread =
metadata.cu_workitem_num_per_thread;
ReductionWorkItemGroup* const reduction_items =
metadata.reduction_items_ptr;
const int32_t effective_thread_num = metadata.effective_thread_num;
const int32_t reduction_item_num = metadata.reduction_item_num;
const int32_t split_kv_q_token_num_threshold =
metadata.split_kv_q_token_num_threshold;
const int32_t workitem_groups_counter_num =
actual_kv_head_num * effective_thread_num;
const int32_t reduction_items_counter_num =
actual_kv_head_num * reduction_item_num;
const int32_t total_counter_num =
workitem_groups_counter_num + reduction_items_counter_num;
if (metadata.reduction_split_num > 0) {
++(*guard_counter_ptr);
while (guard_counter_ptr->load() != thread_num) {
#ifdef FAST_SPINNING
FAST_SPINNING
#else
std::this_thread::yield();
#endif
}
}
// main loop
for (;;) {
int64_t task_idx = metadata.acquire_counter();
if (task_idx >= total_counter_num) {
// no more tasks, leave loop
break;
}
if (task_idx < workitem_groups_counter_num) {
// attention task
// map task_idx to workitem_groups
const int32_t kv_head_idx = task_idx / effective_thread_num;
const int32_t thread_offset = task_idx % effective_thread_num;
AttentionWorkItemGroup* const curr_workitem_groups =
workitem_groups + cu_workitem_num_per_thread[thread_offset];
const int32_t curr_workitem_groups_num =
cu_workitem_num_per_thread[thread_offset + 1] -
cu_workitem_num_per_thread[thread_offset];
const int32_t q_head_start_idx = kv_head_idx * actual_q_heads_per_kv;
for (int32_t workitem_group_idx = 0;
workitem_group_idx < curr_workitem_groups_num;
++workitem_group_idx) {
AttentionWorkItemGroup* const current_workitem_group =
&curr_workitem_groups[workitem_group_idx];
const int32_t current_group_idx = current_workitem_group->req_id;
const int32_t kv_start_pos =
current_workitem_group->kv_split_pos_start;
const int32_t kv_end_pos = current_workitem_group->kv_split_pos_end;
const int32_t curr_spilt_id = current_workitem_group->split_id;
const int32_t q_token_id_start =
current_workitem_group->q_token_id_start;
const int32_t q_token_num = current_workitem_group->q_token_num;
// taskgroup general information
const int32_t q_end = input->query_start_loc[current_group_idx + 1];
const int32_t q_start = input->query_start_loc[current_group_idx];
const int32_t seq_len = input->seq_lens[current_group_idx];
const int32_t q_start_pos =
(casual ? seq_len - (q_end - q_start) : 0);
const int32_t block_num = (seq_len + block_size - 1) / block_size;
// Only apply sink for the first KV split
bool use_sink = (s_aux != nullptr &&
current_workitem_group->local_split_id == 0);
for (int32_t q_token_offset = 0; q_token_offset < q_token_num;
q_token_offset += default_q_tile_token_num) {
bool first_iter_flag[AttentionScheduler::MaxQTileIterNum];
for (int32_t i = 0; i < AttentionScheduler::MaxQTileIterNum;
++i) {
first_iter_flag[i] = true;
}
const int32_t q_token_start_idx =
q_start + q_token_offset + q_token_id_start;
const int32_t actual_q_token_num = std::min(
default_q_tile_token_num, q_token_num - q_token_offset);
const int32_t q_head_tile_size =
actual_q_token_num * actual_q_heads_per_kv;
const int32_t rounded_q_head_tile_size =
((q_head_tile_size + max_q_head_num_per_iter - 1) /
max_q_head_num_per_iter) *
max_q_head_num_per_iter;
const int32_t kv_tile_size =
AttentionScheduler::calcu_tile_size_with_constant_q(
available_cache_size, head_dim, sizeof(kv_cache_t),
sizeof(q_buffer_t), sizeof(logits_buffer_t),
sizeof(partial_output_buffer_t), max_q_head_num_per_iter,
blocksize_alignment, rounded_q_head_tile_size,
rounded_q_head_tile_size <= max_q_head_num_per_iter);
// update buffers
buffer_manager.update(
head_dim, sizeof(q_buffer_t), sizeof(logits_buffer_t),
sizeof(partial_output_buffer_t), max_q_head_num_per_iter,
rounded_q_head_tile_size, kv_tile_size);
q_buffer_t* q_buffer = buffer_manager.get_q_buffer<q_buffer_t>();
float* logits_buffer = buffer_manager.get_logits_buffer();
float* partial_q_buffer = buffer_manager.get_output_buffer();
float* max_buffer = buffer_manager.get_max_buffer();
float* sum_buffer = buffer_manager.get_sum_buffer();
const int32_t q_tile_start_pos =
q_start_pos + q_token_offset + q_token_id_start;
const int32_t q_tile_end_pos =
q_tile_start_pos + actual_q_token_num;
const auto [kv_tile_start_pos, kv_tile_end_pos] =
AttentionScheduler::calcu_kv_tile_pos(
kv_start_pos, kv_end_pos, q_tile_start_pos,
q_tile_end_pos, sliding_window_left,
sliding_window_right);
const auto [rounded_kv_tile_start_pos, rounded_kv_tile_end_pos] =
AttentionScheduler::align_kv_tile_pos(
kv_tile_start_pos, kv_tile_end_pos, blocksize_alignment);
int32_t curr_kv_head_idx =
use_gqa ? kv_head_idx
: (kv_head_idx /
q_heads_per_kv); // for GQA disabled case
// std::printf("thread_id: %d, req_id: %d, q_token_start: %d,
// q_token_end: %d, q_head_start: %d, q_head_end: %d, kv_head_idx:
// %d, kv_pos_start: %d, kv_pos_end: %d\n",
// thread_id, current_group_idx,
// q_token_start_idx, q_token_start_idx +
// actual_q_token_num, q_head_start_idx,
// q_head_start_idx + actual_q_heads_per_kv,
// curr_kv_head_idx, kv_tile_start_pos,
// kv_tile_end_pos);
// move buffers
kv_cache_t* curr_k_cache =
reinterpret_cast<kv_cache_t*>(input->key_cache) +
curr_kv_head_idx * kv_cache_head_num_stride;
kv_cache_t* curr_v_cache =
reinterpret_cast<kv_cache_t*>(input->value_cache) +
curr_kv_head_idx * kv_cache_head_num_stride;
query_t* const q_tile_ptr =
reinterpret_cast<query_t*>(input->query) +
q_token_start_idx * q_token_num_stride +
q_head_start_idx * q_head_num_stride;
size_t output_buffer_offset =
q_token_start_idx * q_head_num * head_dim +
q_head_start_idx * head_dim;
int32_t* curr_block_table =
block_table + current_group_idx * block_table_stride;
const float* curr_alibi_slopes =
(alibi_slopes != nullptr ? alibi_slopes + q_head_start_idx
: nullptr);
const c10::BFloat16* curr_s_aux =
(s_aux != nullptr ? s_aux + q_head_start_idx : nullptr);
// copy the Q tile to q_buffer, the logical layout of q_buffer is
// [actual_q_token_num, actual_q_heads_per_kv, head_dim]
{
attn_impl.copy_q_heads_tile(
q_tile_ptr, q_buffer, actual_q_token_num,
actual_q_heads_per_kv, q_token_num_stride,
q_head_num_stride, scale);
}
if (use_sink) {
alignas(64) float s_aux_fp32[16];
#if defined(__aarch64__) && !defined(ARM_BF16_SUPPORT)
// ARM without native BF16 support: manual conversion
for (int i = 0; i < 16; ++i) {
s_aux_fp32[i] = static_cast<float>(curr_s_aux[i]);
}
#else
// All other platforms have BF16Vec16 available
vec_op::BF16Vec16 vec_bf16(curr_s_aux);
vec_op::FP32Vec16 vec_fp32(vec_bf16);
vec_fp32.save(s_aux_fp32);
#endif
float* __restrict__ curr_sum_buffer = sum_buffer;
float* __restrict__ curr_max_buffer = max_buffer;
for (int32_t token_idx = 0; token_idx < actual_q_token_num;
++token_idx) {
for (int32_t head_idx = 0; head_idx < actual_q_heads_per_kv;
++head_idx) {
curr_sum_buffer[head_idx] = 1.0f;
curr_max_buffer[head_idx] = s_aux_fp32[head_idx];
}
curr_sum_buffer += actual_q_heads_per_kv;
curr_max_buffer += actual_q_heads_per_kv;
}
} else {
float* __restrict__ curr_sum_buffer = sum_buffer;
float* __restrict__ curr_max_buffer = max_buffer;
for (int32_t token_idx = 0; token_idx < actual_q_token_num;
++token_idx) {
for (int32_t head_idx = 0; head_idx < actual_q_heads_per_kv;
++head_idx) {
curr_sum_buffer[head_idx] = 0.0f;
curr_max_buffer[head_idx] =
std::numeric_limits<float>::lowest();
}
curr_sum_buffer += actual_q_heads_per_kv;
curr_max_buffer += actual_q_heads_per_kv;
}
}
// compute loop
for (int32_t kv_tile_pos = rounded_kv_tile_start_pos;
kv_tile_pos < rounded_kv_tile_end_pos;
kv_tile_pos += kv_tile_size) {
const int32_t kv_tile_pos_left = kv_tile_pos;
const int32_t kv_tile_pos_right = std::min(
kv_tile_pos_left + kv_tile_size, rounded_kv_tile_end_pos);
for (int32_t q_head_tile_token_offset = 0;
q_head_tile_token_offset < actual_q_token_num;
q_head_tile_token_offset += max_q_token_num_per_iter) {
const int32_t q_tile_pos_left =
q_tile_start_pos + q_head_tile_token_offset;
const int32_t q_tile_token_num =
std::min(max_q_token_num_per_iter,
actual_q_token_num - q_head_tile_token_offset);
const int32_t q_tile_head_offset =
q_head_tile_token_offset * actual_q_heads_per_kv;
const int32_t q_tile_head_num =
q_tile_token_num * actual_q_heads_per_kv;
const int32_t q_tile_pos_right =
q_tile_pos_left + q_tile_token_num;
const auto [actual_kv_tile_pos_left,
actual_kv_tile_pos_right] =
AttentionScheduler::calcu_kv_tile_pos(
kv_tile_pos_left, kv_tile_pos_right, q_tile_pos_left,
q_tile_pos_right, sliding_window_left,
sliding_window_right);
const int32_t q_iter_idx =
q_head_tile_token_offset / max_q_token_num_per_iter;
if (actual_kv_tile_pos_right <= actual_kv_tile_pos_left) {
continue;
}
// align kv_pos to blocksize_alignment
const auto [aligned_actual_kv_tile_pos_left,
aligned_actual_kv_tile_pos_right] =
AttentionScheduler::align_kv_tile_pos(
actual_kv_tile_pos_left, actual_kv_tile_pos_right,
blocksize_alignment);
const int32_t actual_kv_token_num =
aligned_actual_kv_tile_pos_right -
aligned_actual_kv_tile_pos_left;
// std::printf("\tq_iter_idx: %d, q_token_start: %d,
// q_token_end: %d, q_token_num: %d, q_head_num: %d,
// q_pos_start: %d, q_pos_end: %d, kv_pos_start: %d,
// kv_pos_end: %d\n",
// q_iter_idx, q_token_start_idx +
// q_head_tile_token_offset, q_token_start_idx +
// q_head_tile_token_offset + q_tile_token_num,
// q_tile_token_num, q_tile_head_num,
// q_tile_pos_left, q_tile_pos_right,
// aligned_actual_kv_tile_pos_left,
// aligned_actual_kv_tile_pos_right);
// Move buffers
q_buffer_t* curr_q_heads_buffer =
q_buffer + q_tile_head_offset * head_dim;
float* curr_partial_q_buffer =
partial_q_buffer + q_tile_head_offset * head_dim;
float* curr_max_buffer = max_buffer + q_tile_head_offset;
float* curr_sum_buffer = sum_buffer + q_tile_head_offset;
bool debug_info = false;
// bool debug_info = (
// q_head_start_idx == 4 &&
// (q_token_start_idx + q_head_tile_token_offset) <=
// 4
// && (q_token_start_idx + q_head_tile_token_offset +
// q_tile_token_num) > 4
// );
// if (debug_info) {
// std::printf("\tq_iter_idx: %d, q_token_start: %d,"
// "q_token_end: %d, q_token_num: %d, q_head_num: %d,"
// "q_pos_start: %d, q_pos_end: %d, kv_pos_start: %d,"
// "kv_pos_end: %d\n",
// q_iter_idx, q_token_start_idx +
// q_head_tile_token_offset, q_token_start_idx
// + q_head_tile_token_offset +
// q_tile_token_num, q_tile_token_num,
// q_tile_head_num, q_tile_pos_left,
// q_tile_pos_right,
// aligned_actual_kv_tile_pos_left,
// aligned_actual_kv_tile_pos_right);
// }
attn_impl.template execute_attention<Attention>(
curr_q_heads_buffer, curr_k_cache, curr_v_cache,
logits_buffer, curr_partial_q_buffer, curr_max_buffer,
curr_sum_buffer, curr_block_table,
aligned_actual_kv_tile_pos_left,
aligned_actual_kv_tile_pos_right, actual_kv_token_num,
kv_cache_block_num_stride, q_tile_head_num,
q_tile_token_num, q_tile_pos_left, actual_q_heads_per_kv,
block_size, sliding_window_left, sliding_window_right,
scale, softcap_scale, curr_alibi_slopes,
first_iter_flag[q_iter_idx], use_sink, debug_info);
first_iter_flag[q_iter_idx] = false;
}
}
// write back partial results to output buffer or reduction buffer
{
if (curr_spilt_id == -1) {
final_output(partial_q_buffer,
reinterpret_cast<query_t*>(input->output) +
output_buffer_offset,
sum_buffer, actual_q_heads_per_kv,
actual_q_token_num, q_head_num);
} else {
const int32_t stride =
actual_q_heads_per_kv * split_kv_q_token_num_threshold;
buffer_manager.update(kv_head_idx, total_reduction_split_num,
head_dim, stride, sizeof(float));
volatile bool* split_flag_buffer =
buffer_manager.get_reduce_flag_buffer() + curr_spilt_id;
float* split_output_buffer =
buffer_manager.get_reduce_output_buffer() +
curr_spilt_id * stride * head_dim;
float* split_max_buffer =
buffer_manager.get_reduce_max_buffer() +
curr_spilt_id * stride;
float* split_sum_buffer =
buffer_manager.get_reduce_sum_buffer() +
curr_spilt_id * stride;
partial_output(partial_q_buffer, max_buffer, sum_buffer,
q_head_tile_size, split_output_buffer,
split_max_buffer, split_sum_buffer,
split_flag_buffer);
}
}
}
}
} else {
task_idx -= workitem_groups_counter_num;
const int32_t kv_head_idx = task_idx / reduction_item_num;
const int32_t item_offset = task_idx % reduction_item_num;
ReductionWorkItemGroup* const curr_workitem_groups =
reduction_items + item_offset;
const int32_t curr_output_token_idx =
curr_workitem_groups->q_token_id_start;
const int32_t curr_output_token_num =
curr_workitem_groups->q_token_id_num;
const int32_t curr_split_id = curr_workitem_groups->split_start_id;
const int32_t curr_split_num = curr_workitem_groups->split_num;
const int32_t current_group_idx = curr_workitem_groups->req_id;
const int32_t curr_output_head_num =
curr_output_token_num * actual_q_heads_per_kv;
const int32_t q_start = input->query_start_loc[current_group_idx];
const int32_t q_token_start_idx = q_start + curr_output_token_idx;
const int32_t q_head_start_idx = kv_head_idx * actual_q_heads_per_kv;
size_t output_buffer_offset =
q_token_start_idx * q_head_num * head_dim +
q_head_start_idx * head_dim;
const int32_t stride =
actual_q_heads_per_kv * split_kv_q_token_num_threshold;
buffer_manager.update(kv_head_idx, total_reduction_split_num,
head_dim, stride, sizeof(float));
volatile bool* split_flag_buffer =
buffer_manager.get_reduce_flag_buffer() + curr_split_id;
float* split_output_buffer =
buffer_manager.get_reduce_output_buffer() +
curr_split_id * stride * head_dim;
float* split_max_buffer =
buffer_manager.get_reduce_max_buffer() + curr_split_id * stride;
float* split_sum_buffer =
buffer_manager.get_reduce_sum_buffer() + curr_split_id * stride;
reduce_splits(split_output_buffer, split_max_buffer, split_sum_buffer,
split_flag_buffer, stride, curr_output_head_num,
curr_split_num);
final_output(
split_output_buffer,
reinterpret_cast<query_t*>(input->output) + output_buffer_offset,
split_sum_buffer, actual_q_heads_per_kv, curr_output_token_num,
q_head_num);
}
}
}
// Reset counter for next call
input->metadata->reset_counter();
}
void reduce_splits(float* __restrict__ split_output_buffer,
float* __restrict__ split_max_buffer,
float* __restrict__ split_sum_buffer,
volatile bool* __restrict__ flags,
const int32_t head_num_per_split,
const int32_t curr_head_num, const int32_t split_num) {
#ifdef DEFINE_FAST_EXP
DEFINE_FAST_EXP
#endif
// restrict curr_head_num <= 16 in the scheduler
// elems in split_max_buffer, split_sum_buffer are not cache alignment, use
// local buffers to reduce false-sharing
alignas(64) float local_max[16];
alignas(64) float local_sum[16];
float* __restrict__ curr_split_output_buffer = split_output_buffer;
float* __restrict__ curr_split_max_buffer = split_max_buffer;
float* __restrict__ curr_split_sum_buffer = split_sum_buffer;
constexpr int32_t head_dim_group_num = head_dim / 16;
for (int32_t split_idx = 0; split_idx < split_num; ++split_idx) {
while (!flags[split_idx]) {
#ifdef FAST_SPINNING
FAST_SPINNING
#else
std::this_thread::yield();
#endif
}
std::atomic_thread_fence(std::memory_order_acquire);
if (split_idx > 0) {
float* __restrict__ curr_output_buffer = split_output_buffer;
float* __restrict__ curr_split_output_buffer_iter =
curr_split_output_buffer;
for (int32_t head_idx = 0; head_idx < curr_head_num; ++head_idx) {
float final_max = local_max[head_idx];
float curr_max = curr_split_max_buffer[head_idx];
float final_sum = local_sum[head_idx];
float curr_sum = curr_split_sum_buffer[head_idx];
float* __restrict__ non_scale_output_iter =
final_max > curr_max ? curr_output_buffer
: curr_split_output_buffer_iter;
float* __restrict__ scale_output_iter =
final_max > curr_max ? curr_split_output_buffer_iter
: curr_output_buffer;
float rescale_factor = final_max > curr_max ? curr_max - final_max
: final_max - curr_max;
rescale_factor = std::exp(rescale_factor);
vec_op::FP32Vec16 rescale_factor_vec(rescale_factor);
local_sum[head_idx] = final_max > curr_max
? final_sum + rescale_factor * curr_sum
: rescale_factor * final_sum + curr_sum;
final_max = std::max(final_max, curr_max);
local_max[head_idx] = final_max;
for (int32_t i = 0; i < head_dim_group_num; ++i) {
vec_op::FP32Vec16 non_scale_vec(non_scale_output_iter);
vec_op::FP32Vec16 scale_vec(scale_output_iter);
vec_op::FP32Vec16 final_vec =
non_scale_vec + scale_vec * rescale_factor_vec;
final_vec.save(curr_output_buffer);
non_scale_output_iter += 16;
scale_output_iter += 16;
curr_output_buffer += 16;
}
curr_split_output_buffer_iter += head_dim;
}
} else {
vec_op::FP32Vec16 final_max(split_max_buffer);
final_max.save(local_max);
vec_op::FP32Vec16 final_sum(split_sum_buffer);
final_sum.save(local_sum);
}
curr_split_output_buffer += head_num_per_split * head_dim;
curr_split_max_buffer += head_num_per_split;
curr_split_sum_buffer += head_num_per_split;
}
// write back final max and sum
for (int32_t i = 0; i < curr_head_num; ++i) {
split_max_buffer[i] = local_max[i];
split_sum_buffer[i] = local_sum[i];
}
}
void partial_output(float* __restrict__ partial_output_buffer,
float* __restrict__ partial_max_buffer,
float* __restrict__ partial_sum_buffer,
int32_t curr_head_num,
float* __restrict__ split_output_buffer,
float* __restrict__ split_max_buffer,
float* __restrict__ split_sum_buffer,
volatile bool* __restrict__ flag) {
float* __restrict__ curr_partial_output_buffer = partial_output_buffer;
float* __restrict__ curr_split_output_buffer = split_output_buffer;
constexpr int32_t head_dim_group_num = head_dim / 16;
for (int32_t i = 0; i < curr_head_num; ++i) {
split_max_buffer[i] = partial_max_buffer[i];
split_sum_buffer[i] = partial_sum_buffer[i];
for (int32_t j = 0; j < head_dim_group_num; ++j) {
vec_op::FP32Vec16 vec(curr_partial_output_buffer);
vec.save(curr_split_output_buffer);
curr_partial_output_buffer += 16;
curr_split_output_buffer += 16;
}
}
std::atomic_thread_fence(std::memory_order_release);
*flag = true;
}
void final_output(float* __restrict__ partial_q_buffer,
query_t* __restrict__ curr_output_buffer,
float* __restrict__ sum_buffer,
const int32_t q_heads_per_kv,
const int32_t actual_q_token_num,
const int32_t q_head_num) {
// final output
using output_vec_t = typename VecTypeTrait<query_t>::vec_t;
float* __restrict__ curr_partial_output_buffer = partial_q_buffer;
float* __restrict__ curr_sum_buffer = sum_buffer;
constexpr int32_t group_num_per_head = head_dim / 16;
const int32_t partial_q_buffer_stride = q_heads_per_kv * head_dim;
const int32_t output_buffer_stride = q_head_num * head_dim;
for (int32_t token_idx = 0; token_idx < actual_q_token_num; ++token_idx) {
float* __restrict__ curr_partial_output_buffer_iter =
curr_partial_output_buffer;
query_t* __restrict__ curr_output_buffer_iter = curr_output_buffer;
for (int32_t head_idx = 0; head_idx < q_heads_per_kv; ++head_idx) {
vec_op::FP32Vec16 inv_sum_scale_vec(1.0 / *curr_sum_buffer);
for (int32_t i = 0; i < group_num_per_head; ++i) {
vec_op::FP32Vec16 vec(curr_partial_output_buffer_iter);
// divide the final sum val of softmax here
vec = inv_sum_scale_vec * vec;
// cast to query type
output_vec_t output_vec(vec);
output_vec.save(curr_output_buffer_iter);
// update
curr_partial_output_buffer_iter += 16;
curr_output_buffer_iter += 16;
}
// update
curr_sum_buffer += 1;
}
// update
curr_partial_output_buffer += partial_q_buffer_stride;
curr_output_buffer += output_buffer_stride;
}
}
};
} // namespace cpu_attention
#endif