mirror of
https://github.com/vllm-project/vllm.git
synced 2025-12-06 15:04:47 +08:00
[CPU][Perf] Add fast vectorized exp impl from Arm Optimized Routines (#30068)
Signed-off-by: Ubuntu <ubuntu@ip-10-252-30-150.eu-west-1.compute.internal> Signed-off-by: Elham Harirpoush <elham.harirpoush@arm.com> Co-authored-by: Ubuntu <ubuntu@ip-10-252-30-150.eu-west-1.compute.internal>
This commit is contained in:
@@ -1246,14 +1246,8 @@ class AttentionMainLoop {
|
||||
// rescale sum and partial outputs
|
||||
if (need_rescale) {
|
||||
// compute rescale factor
|
||||
#ifdef DEFINE_FAST_EXP
|
||||
vec_op::FP32Vec16 rescale_factor_vec(rescale_factor);
|
||||
rescale_factor_vec = fast_exp(rescale_factor_vec);
|
||||
rescale_factor = rescale_factor_vec.get_last_elem();
|
||||
#else
|
||||
rescale_factor = std::exp(rescale_factor);
|
||||
vec_op::FP32Vec16 rescale_factor_vec(rescale_factor);
|
||||
#endif
|
||||
|
||||
// rescale sum
|
||||
new_sum_val += rescale_factor * init_sum_val;
|
||||
@@ -1889,15 +1883,8 @@ class AttentionMainLoop {
|
||||
: curr_output_buffer;
|
||||
float rescale_factor = final_max > curr_max ? curr_max - final_max
|
||||
: final_max - curr_max;
|
||||
|
||||
#ifdef DEFINE_FAST_EXP
|
||||
vec_op::FP32Vec16 rescale_factor_vec(rescale_factor);
|
||||
rescale_factor_vec = fast_exp(rescale_factor_vec);
|
||||
rescale_factor = rescale_factor_vec.get_last_elem();
|
||||
#else
|
||||
rescale_factor = std::exp(rescale_factor);
|
||||
vec_op::FP32Vec16 rescale_factor_vec(rescale_factor);
|
||||
#endif
|
||||
|
||||
local_sum[head_idx] = final_max > curr_max
|
||||
? final_sum + rescale_factor * curr_sum
|
||||
|
||||
@@ -60,4 +60,54 @@
|
||||
|
||||
#endif
|
||||
|
||||
#ifdef __aarch64__
|
||||
// Implementation copied from Arm Optimized Routines (expf AdvSIMD)
|
||||
// https://github.com/ARM-software/optimized-routines/blob/master/math/aarch64/advsimd/expf.c
|
||||
#include <limits>
|
||||
#define DEFINE_FAST_EXP \
|
||||
const float32x4_t inv_ln2 = vdupq_n_f32(0x1.715476p+0f); \
|
||||
const float ln2_hi = 0x1.62e4p-1f; \
|
||||
const float ln2_lo = 0x1.7f7d1cp-20f; \
|
||||
const float c0 = 0x1.0e4020p-7f; \
|
||||
const float c2 = 0x1.555e66p-3f; \
|
||||
const float32x4_t ln2_c02 = {ln2_hi, ln2_lo, c0, c2}; \
|
||||
const uint32x4_t exponent_bias = vdupq_n_u32(0x3f800000); \
|
||||
const float32x4_t c1 = vdupq_n_f32(0x1.573e2ep-5f); \
|
||||
const float32x4_t c3 = vdupq_n_f32(0x1.fffdb6p-2f); \
|
||||
const float32x4_t c4 = vdupq_n_f32(0x1.ffffecp-1f); \
|
||||
const float32x4_t pos_special_bound = vdupq_n_f32(0x1.5d5e2ap+6f); \
|
||||
const float32x4_t neg_special_bound = vnegq_f32(pos_special_bound); \
|
||||
const float32x4_t inf = \
|
||||
vdupq_n_f32(std::numeric_limits<float>::infinity()); \
|
||||
const float32x4_t zero = vdupq_n_f32(0.0f); \
|
||||
auto neon_expf = [&](float32x4_t values) __attribute__((always_inline)) { \
|
||||
float32x4_t n = vrndaq_f32(vmulq_f32(values, inv_ln2)); \
|
||||
float32x4_t r = vfmsq_laneq_f32(values, n, ln2_c02, 0); \
|
||||
r = vfmsq_laneq_f32(r, n, ln2_c02, 1); \
|
||||
uint32x4_t e = vshlq_n_u32(vreinterpretq_u32_s32(vcvtq_s32_f32(n)), 23); \
|
||||
float32x4_t scale = vreinterpretq_f32_u32(vaddq_u32(e, exponent_bias)); \
|
||||
float32x4_t r2 = vmulq_f32(r, r); \
|
||||
float32x4_t p = vfmaq_laneq_f32(c1, r, ln2_c02, 2); \
|
||||
float32x4_t q = vfmaq_laneq_f32(c3, r, ln2_c02, 3); \
|
||||
q = vfmaq_f32(q, p, r2); \
|
||||
p = vmulq_f32(c4, r); \
|
||||
float32x4_t poly = vfmaq_f32(p, q, r2); \
|
||||
poly = vfmaq_f32(scale, poly, scale); \
|
||||
const uint32x4_t hi_mask = vcgeq_f32(values, pos_special_bound); \
|
||||
const uint32x4_t lo_mask = vcleq_f32(values, neg_special_bound); \
|
||||
poly = vbslq_f32(hi_mask, inf, poly); \
|
||||
return vbslq_f32(lo_mask, zero, poly); \
|
||||
}; \
|
||||
auto fast_exp = [&](vec_op::FP32Vec16& vec) \
|
||||
__attribute__((always_inline)) { \
|
||||
float32x4x4_t result; \
|
||||
result.val[0] = neon_expf(vec.reg.val[0]); \
|
||||
result.val[1] = neon_expf(vec.reg.val[1]); \
|
||||
result.val[2] = neon_expf(vec.reg.val[2]); \
|
||||
result.val[3] = neon_expf(vec.reg.val[3]); \
|
||||
return vec_op::FP32Vec16(result); \
|
||||
};
|
||||
|
||||
#endif // __aarch64__
|
||||
|
||||
#endif
|
||||
Reference in New Issue
Block a user