diff --git a/src/Native/include/nncase/ntt/arch/riscv64/primitive_ops.h b/src/Native/include/nncase/ntt/arch/riscv64/primitive_ops.h index a837b44905..70554956c6 100644 --- a/src/Native/include/nncase/ntt/arch/riscv64/primitive_ops.h +++ b/src/Native/include/nncase/ntt/arch/riscv64/primitive_ops.h @@ -14,6 +14,8 @@ */ #pragma once #include "../../primitive_ops.h" +#include "nncase/ntt/arch/riscv64/arch_types.h" +#include "nncase/ntt/vector.h" #include "rvv_mathfun.h" #ifdef __riscv_vector @@ -29,6 +31,15 @@ namespace nncase::ntt::ops { kernel(1, 32) kernel(2, 16) kernel(4, 8) kernel(8, 4) #endif +template <> +struct store, + ntt::vector> { + void operator()(ntt::vector &dest, + const ntt::vector &v) const noexcept { + __riscv_vse32_v_f32m1((float *)&dest, v, NTT_VLEN / 32); + } +}; + #define RVV_UNARY_OP(op, dtype, vl, kernel) \ template <> struct op> { \ ntt::vector \ @@ -610,6 +621,16 @@ REGISTER_RVV_UNARY_OP(square, float, square_float32) REGISTER_RVV_KERNEL(TANH_FLOAT32) REGISTER_RVV_UNARY_OP(tanh, float, tanh_float32) +// erf +#define ERF_FLOAT32(lmul, mlen) \ + inline vfloat32m##lmul##_t erf_float32(const vfloat32m##lmul##_t &v, \ + const size_t vl) { \ + return erf_ps(v, vl); \ + } + +REGISTER_RVV_KERNEL(ERF_FLOAT32) +REGISTER_RVV_UNARY_OP(erf, float, erf_float32) + // binary #define RVV_BINARY_OP(op, dtype, vl, kernel) \ template <> struct op, ntt::vector> { \ @@ -761,6 +782,16 @@ REGISTER_RVV_KERNEL(MOD_FLOAT32) REGISTER_RVV_BINARY_OP(mod, float, mod_float32) // min +template <> struct min { + auto operator()(const float &s1, const float &s2) const noexcept { + float ret; + __asm("fmin.s %[ret], %[s1], %[s2];" + : [ret] "=f"(ret) + : [s1] "f"(s1), [s2] "f"(s2)); + return ret; + } +}; + #define MIN_FLOAT32(lmul, mlen) \ inline vfloat32m##lmul##_t min_float32(const vfloat32m##lmul##_t &v1, \ const vfloat32m##lmul##_t &v2, \ @@ -782,6 +813,16 @@ REGISTER_RVV_KERNEL(MIN_FLOAT32) REGISTER_RVV_BINARY_OP(min, float, min_float32) // max +template <> struct max { + auto operator()(const float &s1, const float &s2) const noexcept { + float ret; + __asm("fmax.s %[ret], %[s1], %[s2];" + : [ret] "=f"(ret) + : [s1] "f"(s1), [s2] "f"(s2)); + return ret; + } +}; + #define MAX_FLOAT32(lmul, mlen) \ inline vfloat32m##lmul##_t max_float32(const vfloat32m##lmul##_t &v1, \ const vfloat32m##lmul##_t &v2, \ @@ -969,6 +1010,7 @@ REGISTER_RVV_KERNEL(INNER_PRODUCT_FLOAT32) REGISTER_RVV_INNER_PRODUCT_OP(float, inner_product_float32) // register mul_add kernel +#if 0 #define MUL_ADD_FLOAT32(lmul, mlen) \ inline vfloat32m##lmul##_t mul_add_float32( \ const vfloat32m##lmul##_t &v1, const vfloat32m##lmul##_t &v2, \ @@ -987,6 +1029,26 @@ REGISTER_RVV_INNER_PRODUCT_OP(float, inner_product_float32) const vfloat32m##lmul##_t &v3, const size_t vl) { \ return __riscv_vfmadd_vf_f32m##lmul(v2, s1, v3, vl); \ } +#else +#define MUL_ADD_FLOAT32(lmul, mlen) \ + inline vfloat32m##lmul##_t mul_add_float32( \ + const vfloat32m##lmul##_t &v1, const vfloat32m##lmul##_t &v2, \ + const vfloat32m##lmul##_t &v3, const size_t vl) { \ + return __riscv_vfmacc_vv_f32m##lmul(v3, v1, v2, vl); \ + } \ + \ + inline vfloat32m##lmul##_t mul_add_float32( \ + const vfloat32m##lmul##_t &v1, const float &s2, \ + const vfloat32m##lmul##_t &v3, const size_t vl) { \ + return __riscv_vfmacc_vf_f32m##lmul(v3, s2, v1, vl); \ + } \ + \ + inline vfloat32m##lmul##_t mul_add_float32( \ + const float &s1, const vfloat32m##lmul##_t &v2, \ + const vfloat32m##lmul##_t &v3, const size_t vl) { \ + return __riscv_vfmacc_vf_f32m##lmul(v3, s1, v2, vl); \ + } 
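// Note on the rewrite above: all three overloads now lower to a fused
// multiply-accumulate whose destination is the addend. vfmacc computes
// vd <- vs1 * vs2 + vd, whereas the vfmadd form it replaces computes
// vd <- vd * vs1 + vs2 and therefore overwrites a multiplicand; keeping the
// accumulator in the destination lets a chain of mul_add calls reuse one
// register for v3. A minimal standalone sketch of the contract (illustrative
// only; mul_add_sketch is not a name from this patch):
//
//     #include <riscv_vector.h>
//     // Lane-wise acc[i] += v1[i] * v2[i] for the first vl lanes.
//     inline vfloat32m1_t mul_add_sketch(vfloat32m1_t v1, vfloat32m1_t v2,
//                                        vfloat32m1_t acc, size_t vl) {
//         return __riscv_vfmacc_vv_f32m1(acc, v1, v2, vl);
//     }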
+#endif REGISTER_RVV_KERNEL(MUL_ADD_FLOAT32) @@ -1029,7 +1091,6 @@ REGISTER_RVV_KERNEL(MUL_ADD_FLOAT32) REGISTER_RVV_MUL_ADD_OP(float, mul_add_float32) -#if 1 template struct mma, ntt::vector, ntt::vector> { @@ -1038,11 +1099,67 @@ struct mma, ntt::vector, const ntt::vector &rhs, const ntt::vector &v3) const noexcept { auto output = v3; - for (size_t k = 0; k < 4; k++) { - output(0) = (k != 0 || AccC) - ? ntt::mul_add(lhs(0, k), rhs(k), output(0)) - : ntt::mul(lhs(0, k), rhs(k)); - } + auto t0 = AccC ? ntt::mul_add(lhs(0, 0), rhs(0), output(0)) + : ntt::mul(lhs(0, 0), rhs(0)); + auto t1 = ntt::mul(lhs(0, 1), rhs(1)); + t0 = ntt::mul_add(lhs(0, 2), rhs(2), t0); + t1 = ntt::mul_add(lhs(0, 3), rhs(3), t1); + output(0) = ntt::add(t0, t1); + return output; + } +}; + +template +struct mma, ntt::vector, + ntt::vector> { + ntt::vector + operator()(const ntt::vector &lhs, + const ntt::vector &rhs, + const ntt::vector &v3) const noexcept { + auto output = v3; + + auto t0 = AccC ? ntt::mul_add(lhs(0, 0), rhs(0), output(0)) + : ntt::mul(lhs(0, 0), rhs(0)); + auto t1 = ntt::mul(lhs(0, 1), rhs(1)); + t0 = ntt::mul_add(lhs(0, 2), rhs(2), t0); + t1 = ntt::mul_add(lhs(0, 3), rhs(3), t1); + + t0 = ntt::mul_add(lhs(0, 4), rhs(4), t0); + t1 = ntt::mul_add(lhs(0, 5), rhs(5), t1); + t0 = ntt::mul_add(lhs(0, 6), rhs(6), t0); + t1 = ntt::mul_add(lhs(0, 7), rhs(7), t1); + + t0 = ntt::mul_add(lhs(0, 8), rhs(8), t0); + t1 = ntt::mul_add(lhs(0, 9), rhs(9), t1); + t0 = ntt::mul_add(lhs(0, 10), rhs(10), t0); + t1 = ntt::mul_add(lhs(0, 11), rhs(11), t1); + + t0 = ntt::mul_add(lhs(0, 12), rhs(12), t0); + t1 = ntt::mul_add(lhs(0, 13), rhs(13), t1); + t0 = ntt::mul_add(lhs(0, 14), rhs(14), t0); + t1 = ntt::mul_add(lhs(0, 15), rhs(15), t1); + + t0 = ntt::mul_add(lhs(0, 16), rhs(16), t0); + t1 = ntt::mul_add(lhs(0, 17), rhs(17), t1); + t0 = ntt::mul_add(lhs(0, 18), rhs(18), t0); + t1 = ntt::mul_add(lhs(0, 19), rhs(19), t1); + + t0 = ntt::mul_add(lhs(0, 20), rhs(20), t0); + t1 = ntt::mul_add(lhs(0, 21), rhs(21), t1); + t0 = ntt::mul_add(lhs(0, 22), rhs(22), t0); + t1 = ntt::mul_add(lhs(0, 23), rhs(23), t1); + + t0 = ntt::mul_add(lhs(0, 24), rhs(24), t0); + t1 = ntt::mul_add(lhs(0, 25), rhs(25), t1); + t0 = ntt::mul_add(lhs(0, 26), rhs(26), t0); + t1 = ntt::mul_add(lhs(0, 27), rhs(27), t1); + + t0 = ntt::mul_add(lhs(0, 28), rhs(28), t0); + t1 = ntt::mul_add(lhs(0, 29), rhs(29), t1); + t0 = ntt::mul_add(lhs(0, 30), rhs(30), t0); + t1 = ntt::mul_add(lhs(0, 31), rhs(31), t1); + + output(0) = ntt::add(t0, t1); return output; } }; @@ -1055,33 +1172,134 @@ struct mma, ntt::vector, const ntt::vector &rhs, const ntt::vector &v3) const noexcept { auto output = v3; - for (size_t k = 0; k < 4; k++) { - output(0) = (k != 0 || AccC) - ? ntt::mul_add(lhs(0, k), rhs(k), output(0)) - : ntt::mul(lhs(0, k), rhs(k)); - } - - for (size_t k = 0; k < 4; k++) { - output(1) = (k != 0 || AccC) - ? ntt::mul_add(lhs(1, k), rhs(k), output(1)) - : ntt::mul(lhs(1, k), rhs(k)); - } - - for (size_t k = 0; k < 4; k++) { - output(2) = (k != 0 || AccC) - ? ntt::mul_add(lhs(2, k), rhs(k), output(2)) - : ntt::mul(lhs(2, k), rhs(k)); - } - - for (size_t k = 0; k < 4; k++) { - output(3) = (k != 0 || AccC) - ? 
ntt::mul_add(lhs(3, k), rhs(k), output(3)) - : ntt::mul(lhs(3, k), rhs(k)); - } + ntt::fixed_tensor_alike_t, 1, 4> lhs_2d[4]{ + {{lhs(0)}}, + {{lhs(1)}}, + {{lhs(2)}}, + {{lhs(3)}}, + }; + ntt::fixed_tensor_alike_t, 1, 4> output_2d[4]{ + {{v3(0)}}, + {{v3(1)}}, + {{v3(2)}}, + {{v3(3)}}, + }; + + output_2d[0] = ntt::mma(lhs_2d[0], rhs, output_2d[0]); + output_2d[1] = ntt::mma(lhs_2d[1], rhs, output_2d[1]); + output_2d[2] = ntt::mma(lhs_2d[2], rhs, output_2d[2]); + output_2d[3] = ntt::mma(lhs_2d[3], rhs, output_2d[3]); + + output(0) = output_2d[0](0); + output(1) = output_2d[1](0); + output(2) = output_2d[2](0); + output(3) = output_2d[3](0); + + return output; + } +}; + +template +struct mma, ntt::vector, + ntt::vector> { + ntt::vector + operator()(const ntt::vector &lhs, + const ntt::vector &rhs, + const ntt::vector &v3) const noexcept { + auto output = v3; + ntt::fixed_tensor_alike_t, 1, 32> lhs_2d[]{ + {{lhs(0)}}, {{lhs(1)}}, {{lhs(2)}}, {{lhs(3)}}, {{lhs(4)}}, + {{lhs(5)}}, {{lhs(6)}}, {{lhs(7)}}, {{lhs(8)}}, {{lhs(9)}}, + {{lhs(10)}}, {{lhs(11)}}, {{lhs(12)}}, {{lhs(13)}}, {{lhs(14)}}, + {{lhs(15)}}, {{lhs(16)}}, {{lhs(17)}}, {{lhs(18)}}, {{lhs(19)}}, + {{lhs(20)}}, {{lhs(21)}}, {{lhs(22)}}, {{lhs(23)}}, {{lhs(24)}}, + {{lhs(25)}}, {{lhs(26)}}, {{lhs(27)}}, {{lhs(28)}}, {{lhs(29)}}, + {{lhs(30)}}, {{lhs(31)}}}; + + ntt::fixed_tensor_alike_t, 1, 32> output_2d[]{ + {{v3(0)}}, {{v3(1)}}, {{v3(2)}}, {{v3(3)}}, {{v3(4)}}, + {{v3(5)}}, {{v3(6)}}, {{v3(7)}}, {{v3(8)}}, {{v3(9)}}, + {{v3(10)}}, {{v3(11)}}, {{v3(12)}}, {{v3(13)}}, {{v3(14)}}, + {{v3(15)}}, {{v3(16)}}, {{v3(17)}}, {{v3(18)}}, {{v3(19)}}, + {{v3(20)}}, {{v3(21)}}, {{v3(22)}}, {{v3(23)}}, {{v3(24)}}, + {{v3(25)}}, {{v3(26)}}, {{v3(27)}}, {{v3(28)}}, {{v3(29)}}, + {{v3(30)}}, {{v3(31)}}}; + + output_2d[0] = ntt::mma(lhs_2d[0], rhs, output_2d[0]); + output_2d[1] = ntt::mma(lhs_2d[1], rhs, output_2d[1]); + output_2d[2] = ntt::mma(lhs_2d[2], rhs, output_2d[2]); + output_2d[3] = ntt::mma(lhs_2d[3], rhs, output_2d[3]); + output_2d[4] = ntt::mma(lhs_2d[4], rhs, output_2d[4]); + output_2d[5] = ntt::mma(lhs_2d[5], rhs, output_2d[5]); + output_2d[6] = ntt::mma(lhs_2d[6], rhs, output_2d[6]); + output_2d[7] = ntt::mma(lhs_2d[7], rhs, output_2d[7]); + + output_2d[8] = ntt::mma(lhs_2d[8], rhs, output_2d[8]); + output_2d[9] = ntt::mma(lhs_2d[9], rhs, output_2d[9]); + output_2d[10] = ntt::mma(lhs_2d[10], rhs, output_2d[10]); + output_2d[11] = ntt::mma(lhs_2d[11], rhs, output_2d[11]); + output_2d[12] = ntt::mma(lhs_2d[12], rhs, output_2d[12]); + output_2d[13] = ntt::mma(lhs_2d[13], rhs, output_2d[13]); + output_2d[14] = ntt::mma(lhs_2d[14], rhs, output_2d[14]); + output_2d[15] = ntt::mma(lhs_2d[15], rhs, output_2d[15]); + + output_2d[16] = ntt::mma(lhs_2d[16], rhs, output_2d[16]); + output_2d[17] = ntt::mma(lhs_2d[17], rhs, output_2d[17]); + output_2d[18] = ntt::mma(lhs_2d[18], rhs, output_2d[18]); + output_2d[19] = ntt::mma(lhs_2d[19], rhs, output_2d[19]); + output_2d[20] = ntt::mma(lhs_2d[20], rhs, output_2d[20]); + output_2d[21] = ntt::mma(lhs_2d[21], rhs, output_2d[21]); + output_2d[22] = ntt::mma(lhs_2d[22], rhs, output_2d[22]); + output_2d[23] = ntt::mma(lhs_2d[23], rhs, output_2d[23]); + + output_2d[24] = ntt::mma(lhs_2d[24], rhs, output_2d[24]); + output_2d[25] = ntt::mma(lhs_2d[25], rhs, output_2d[25]); + output_2d[26] = ntt::mma(lhs_2d[26], rhs, output_2d[26]); + output_2d[27] = ntt::mma(lhs_2d[27], rhs, output_2d[27]); + output_2d[28] = ntt::mma(lhs_2d[28], rhs, output_2d[28]); + output_2d[29] = ntt::mma(lhs_2d[29], rhs, output_2d[29]); 
+ output_2d[30] = ntt::mma(lhs_2d[30], rhs, output_2d[30]); + output_2d[31] = ntt::mma(lhs_2d[31], rhs, output_2d[31]); + + output(0) = output_2d[0](0); + output(1) = output_2d[1](0); + output(2) = output_2d[2](0); + output(3) = output_2d[3](0); + output(4) = output_2d[4](0); + output(5) = output_2d[5](0); + output(6) = output_2d[6](0); + output(7) = output_2d[7](0); + + output(8) = output_2d[8](0); + output(9) = output_2d[9](0); + output(10) = output_2d[10](0); + output(11) = output_2d[11](0); + output(12) = output_2d[12](0); + output(13) = output_2d[13](0); + output(14) = output_2d[14](0); + output(15) = output_2d[15](0); + + output(16) = output_2d[16](0); + output(17) = output_2d[17](0); + output(18) = output_2d[18](0); + output(19) = output_2d[19](0); + output(20) = output_2d[20](0); + output(21) = output_2d[21](0); + output(22) = output_2d[22](0); + output(23) = output_2d[23](0); + + output(24) = output_2d[24](0); + output(25) = output_2d[25](0); + output(26) = output_2d[26](0); + output(27) = output_2d[27](0); + output(28) = output_2d[28](0); + output(29) = output_2d[29](0); + output(30) = output_2d[30](0); + output(31) = output_2d[31](0); + return output; } }; -#endif // register reduce_sum kernel #define REDUCE_ADD_FLOAT32(lmul, mlen) \ diff --git a/src/Native/include/nncase/ntt/arch/riscv64/rvv_mathfun.h b/src/Native/include/nncase/ntt/arch/riscv64/rvv_mathfun.h index 3596955c83..1cc4d39e4f 100644 --- a/src/Native/include/nncase/ntt/arch/riscv64/rvv_mathfun.h +++ b/src/Native/include/nncase/ntt/arch/riscv64/rvv_mathfun.h @@ -159,41 +159,58 @@ _RVV_FLOAT_EXP_OP(2, 16, 32, 0x7f, 23) _RVV_FLOAT_EXP_OP(4, 8, 32, 0x7f, 23) _RVV_FLOAT_EXP_OP(8, 4, 32, 0x7f, 23) -#define c_minus_cephes_DP1 -0.78515625 -#define c_minus_cephes_DP2 -2.4187564849853515625e-4 -#define c_minus_cephes_DP3 -3.77489497744594108e-8 -#define c_sincof_p0 -1.9515295891E-4 -#define c_sincof_p1 8.3321608736E-3 -#define c_sincof_p2 -1.6666654611E-1 -#define c_coscof_p0 2.443315711809948E-005 -#define c_coscof_p1 -1.388731625493765E-003 -#define c_coscof_p2 4.166664568298827E-002 -#define c_cephes_FOPI 1.27323954473516 // 4 / M_PI +#if 0 +// from glibc 2.40: max_ulp_error = 3 +// e^x -1 = x + 1/2!x^2 + 1/3!x^3 + 1/4!x^4 + 1/5!x^5 + 1/6!x^6 + 1/7!x^7 +#define _RVV_FLOAT_EXPM1F_OP(LMUL, MLEN, TLEN, E, M) \ + static inline vfloat##TLEN##m##LMUL##_t expm1f( \ + vfloat##TLEN##m##LMUL##_t x, size_t vl) { \ + /* Reduce argument to smaller range: \ + Let i = round(x / ln2) \ + and f = x - i * ln2, then f is in [-ln2/2, ln2/2]. \ + exp(x) - 1 = 2^i * (expm1(f) + 1) - 1 \ + where 2^i is exact because i is an integer. */ \ + auto shift = __riscv_vfmv_v_f_f##TLEN##m##LMUL(0x1.8p23f, vl); \ + auto j = __riscv_vmv_v_v_f32m##LMUL(x, vl); \ + j = __riscv_vfmadd_vf_f##TLEN##m##LMUL(j, 0x1.715476p+0f, shift, vl); \ + j = __riscv_vfsub_vv_f##TLEN##m##LMUL(j, shift, vl); \ + auto f = __riscv_vfmv_v_f_f##TLEN##m##LMUL(0x1.62e4p-1f, vl); \ + auto c0 = __riscv_vfmv_v_f_f##TLEN##m##LMUL(0x1.7f7d1cp-20f, vl); \ + f = __riscv_vfnmsub_vv_f##TLEN##m##LMUL(f, j, x, vl); \ + auto i = __riscv_vfcvt_x_f_v_i##TLEN##m##LMUL(j, vl); \ + f = __riscv_vfnmsac_vv_f##TLEN##m##LMUL(f, j, c0, vl); \ + /* Approximate expm1(f) using polynomial. \ + Taylor expansion for expm1(x) has the form: \ + x + ax^2 + bx^3 + cx^4 .... \ + So we calculate the polynomial P(f) = a + bf + cf^2 + ... \ + and assemble the approximation expm1(f) ~= f + f^2 * P(f). 
*/ \ + auto poly_0 = __riscv_vfmv_v_f_f##TLEN##m##LMUL(0x1.fffffep-2, vl); \ + auto poly_1 = __riscv_vfmv_v_f_f##TLEN##m##LMUL(0x1.5554aep-3, vl); \ + auto poly_2 = __riscv_vfmv_v_f_f##TLEN##m##LMUL(0x1.555736p-5, vl); \ + auto poly_3 = __riscv_vfmv_v_f_f##TLEN##m##LMUL(0x1.12287cp-7, vl); \ + auto poly_4 = __riscv_vfmv_v_f_f##TLEN##m##LMUL(0x1.6b55a2p-10, vl); \ + auto p = __riscv_vfmadd_vv_f##TLEN##m##LMUL(poly_4, f, poly_3, vl); \ + auto f2 = __riscv_vfmul_vv_f##TLEN##m##LMUL(f, f, vl); \ + p = __riscv_vfmadd_vv_f##TLEN##m##LMUL(p, f, poly_2, vl); \ + p = __riscv_vfmadd_vv_f##TLEN##m##LMUL(p, f, poly_1, vl); \ + p = __riscv_vfmadd_vv_f##TLEN##m##LMUL(p, f, poly_0, vl); \ + p = __riscv_vfmadd_vv_f##TLEN##m##LMUL(p, f2, f, vl); \ + auto u = __riscv_vsll_vx_i##TLEN##m##LMUL(i, 23, vl); \ + u = __riscv_vadd_vx_i32m##LMUL(u, 0x3f800000, vl); \ + /* expm1(x) ~= p * t + (t - 1). */ \ + auto t = \ + __riscv_vreinterpret_v_i##TLEN##m##LMUL##_f##TLEN##m##LMUL(u); \ + auto tmp = __riscv_vfsub_vf_f##TLEN##m##LMUL(t, 1.f, vl); \ + return __riscv_vfmadd_vv_f##TLEN##m##LMUL(p, t, tmp, vl); \ + } + +_RVV_FLOAT_EXPM1F_OP(1, 32, 32, 0x7f, 23) +_RVV_FLOAT_EXPM1F_OP(2, 16, 32, 0x7f, 23) +_RVV_FLOAT_EXPM1F_OP(4, 8, 32, 0x7f, 23) +_RVV_FLOAT_EXPM1F_OP(8, 4, 32, 0x7f, 23) #define c_tanh_tiny 1e-4f #define c_tanh_hi 9.0f -// The monomial coefficients of the numerator polynomial (odd). -#define c_tanh_alpha_1 4.89352455891786e-3f -#define c_tanh_alpha_3 6.37261928875436e-4f -#define c_tanh_alpha_5 1.48572235717979e-5f -#define c_tanh_alpha_7 5.12229709037114e-8f -#define c_tanh_alpha_9 -8.60467152213735e-11f -#define c_tanh_alpha_11 2.00018790482477e-13f -#define c_tanh_alpha_13 -2.76076847742355e-16f -// The monomial coefficients of the denominator polynomial (even). -#define c_tanh_beta_0 4.89352518554385e-3f -#define c_tanh_beta_2 2.26843463243900e-3f -#define c_tanh_beta_4 1.18534705686654e-4f -#define c_tanh_beta_6 1.19825839466702e-6f - -/* -y = p1 * x + p3 * x^3 + p5 * x^5 + p7 * x^7 + p9 * x^9 + p11 * x^11 + p13 * x^13 - = x * (p1 + p3 * x^2 + x^4 * (p5 + p7 * x^2 + x^4 * (p9 + p11 * x^2 + p13 * -x^4))) - -w = p0 + p2 * x^2 + p4 * x^4 + p6 * x^6 - = p0 + p2 * x^2 + x^4 * (p4 + p6 * x^2) -*/ #define _RVV_FLOAT_TANH_OP(LMUL, MLEN, TLEN) \ static inline vfloat##TLEN##m##LMUL##_t tanh_ps( \ vfloat##TLEN##m##LMUL##_t x, size_t vl) { \ @@ -203,39 +220,13 @@ w = p0 + p2 * x^2 + p4 * x^4 + p6 * x^6 /* this range is -/+1.0f in single-precision. */ \ abs = __riscv_vfmin_vf_f##TLEN##m##LMUL(abs, c_tanh_hi, vl); \ \ - /* since the polynomials are odd/even, we need x**2. */ \ - auto x2 = __riscv_vfmul_vv_f##TLEN##m##LMUL(abs, abs, vl); \ + auto q = expm1f(__riscv_vfmul_vf_f##TLEN##m##LMUL(x, 2.f, vl), vl); \ + auto y = __riscv_vfdiv_vv_f##TLEN##m##LMUL( \ + q, __riscv_vfadd_vf_f##TLEN##m##LMUL(q, 2.f, vl), vl); \ \ - /* evaluate the numerator polynomial y, denominator polynomial w. 
*/ \
-        auto c0 = __riscv_vfmv_v_f_f##TLEN##m##LMUL(c_tanh_beta_0, vl);       \
-        auto c1 = __riscv_vfmv_v_f_f##TLEN##m##LMUL(c_tanh_alpha_1, vl);      \
-        auto c4 = __riscv_vfmv_v_f_f##TLEN##m##LMUL(c_tanh_beta_4, vl);       \
-        auto c5 = __riscv_vfmv_v_f_f##TLEN##m##LMUL(c_tanh_alpha_5, vl);      \
-        auto c9 = __riscv_vfmv_v_f_f##TLEN##m##LMUL(c_tanh_alpha_9, vl);      \
-        auto y1 = __riscv_vmv_v_v_f##TLEN##m##LMUL(x2, vl);                   \
-        auto y2 = __riscv_vmv_v_v_f##TLEN##m##LMUL(x2, vl);                   \
-        auto y3 = __riscv_vmv_v_v_f##TLEN##m##LMUL(x2, vl);                   \
-        auto w1 = __riscv_vmv_v_v_f##TLEN##m##LMUL(x2, vl);                   \
-        auto w2 = __riscv_vmv_v_v_f##TLEN##m##LMUL(x2, vl);                   \
-        y1 = __riscv_vfmadd_vf_f##TLEN##m##LMUL(y1, c_tanh_alpha_11, c9, vl); \
-        w1 = __riscv_vfmadd_vf_f##TLEN##m##LMUL(w1, c_tanh_beta_6, c4, vl);   \
-        auto x4 = __riscv_vfmul_vv_f##TLEN##m##LMUL(x2, x2, vl);              \
-        y1 = __riscv_vfmacc_vf_f##TLEN##m##LMUL(y1, c_tanh_alpha_13, x4, vl); \
-        y2 = __riscv_vfmadd_vf_f##TLEN##m##LMUL(y2, c_tanh_alpha_7, c5, vl);  \
-        y1 = __riscv_vfmadd_vv_f##TLEN##m##LMUL(y1, x4, y2, vl);              \
-        w2 = __riscv_vfmadd_vf_f##TLEN##m##LMUL(w2, c_tanh_beta_2, c0, vl);   \
-        y3 = __riscv_vfmadd_vf_f##TLEN##m##LMUL(y3, c_tanh_alpha_3, c1, vl);  \
-        auto w = __riscv_vfmadd_vv_f##TLEN##m##LMUL(w1, x4, w2, vl);          \
-        y1 = __riscv_vfmadd_vv_f##TLEN##m##LMUL(y1, x4, y3, vl);              \
-        auto z = __riscv_vfsgnj_vv_f##TLEN##m##LMUL(abs, x, vl);              \
-        w = __riscv_vfrec7_v_f##TLEN##m##LMUL(w, vl);                         \
-        y1 = __riscv_vfmul_vv_f##TLEN##m##LMUL(y1, z, vl);                    \
         auto tiny_mask =                                                      \
             __riscv_vmfge_vf_f##TLEN##m##LMUL##_b##MLEN(abs, c_tanh_tiny, vl); \
                                                                               \
-        /* divide the numerator by the denominator. */                        \
-        auto y = __riscv_vfmul_vv_f##TLEN##m##LMUL(y1, w, vl);                \
-                                                                              \
         /* when the argument is very small in magnitude it's more accurate to \
          * just return it. */                                                 \
         y = __riscv_vmerge_vvm_f##TLEN##m##LMUL(x, y, tiny_mask, vl);         \
@@ -243,6 +234,149 @@ w = p0 + p2 * x^2 + p4 * x^4 + p6 * x^6
         return y;                                                             \
     }
+#else
+#define LOG2_INV 0x1.71547652b82fep+0
+#define LOG2_HI 0x1.62e42fefa39efp-1
+#define LOG2_LO 0x1.abc9e3b39803fp-56
+#define _RVV_FLOAT_TANH_OP(LMUL, MLEN, TLEN)                                  \
+    static inline vfloat##TLEN##m##LMUL##_t tanh_ps(                          \
+        vfloat##TLEN##m##LMUL##_t v, size_t vl) {                             \
+        constexpr float fp_posZero = 0.0f;                                    \
+        constexpr float fp_posOne = 1.f;                                      \
+        auto zero = __riscv_vfmv_v_f_f##TLEN##m##LMUL(fp_posZero, vl);        \
+        auto one = __riscv_vfmv_v_f_f##TLEN##m##LMUL(fp_posOne, vl);          \
+        /* tanh(x) = sign(x) * tanh(|x|); suffices to work on |x| for the main \
+         * part */                                                            \
+        auto vx = __riscv_vfsgnj_vf_f##TLEN##m##LMUL(v, 1.f, vl);             \
+        /* Suffices to clip |x| to 20, which is bigger than 28 log(2) */      \
+        vx = __riscv_vfmin_vf_f##TLEN##m##LMUL(vx, 0x1.4p4, vl);              \
+                                                                              \
+        /* tanh(x) = (1 - exp(-2x)) / (1 + exp(-2x)); so we compute exp(-2x) \
+         */                                                                   \
+        /* by replacing x by -2x */                                           \
+        vx = __riscv_vfmul_vf_f##TLEN##m##LMUL(vx, -2.f, vl);                 \
+        auto n_flt = __riscv_vfmul_vf_f##TLEN##m##LMUL(vx, LOG2_INV, vl);     \
+        auto n = __riscv_vfcvt_x_f_v_i##TLEN##m##LMUL(n_flt, vl);             \
+        n_flt = __riscv_vfcvt_f_x_v_f##TLEN##m##LMUL(n, vl);                  \
+        auto u = __riscv_vadd_vx_i##TLEN##m##LMUL(n, 127, vl);                \
+        auto r_delta =                                                        \
+            __riscv_vfnmsac_vf_f##TLEN##m##LMUL(vx, LOG2_HI, n_flt, vl);      \
+        u = __riscv_vsll_vx_i##TLEN##m##LMUL(u, 23, vl);                      \
+        auto r =                                                              \
+            __riscv_vfnmsac_vf_f##TLEN##m##LMUL(r_delta, LOG2_LO, n_flt, vl); \
+        auto s =                                                              \
+            __riscv_vreinterpret_v_i##TLEN##m##LMUL##_f##TLEN##m##LMUL(u);    \
+        auto s_is_small =                                                     \
+            __riscv_vmsle_vx_i##TLEN##m##LMUL##_b##MLEN(n, -(23 + 1), vl);    \
+        r_delta = __riscv_vfsub_vv_f##TLEN##m##LMUL(r_delta, r, vl);
\
+        auto s_head = __riscv_vfmerge_vfm_f##TLEN##m##LMUL(s, fp_posZero,     \
+                                                           s_is_small, vl);   \
+        r_delta =                                                             \
+            __riscv_vfnmsac_vf_f##TLEN##m##LMUL(r_delta, LOG2_LO, n_flt, vl); \
+        /* exp(x) = 2^n exp(r'), r' = r + r_delta and thus we compute 1 +/-   \
+        exp(x) as 1 +/- 2^(n)(1 + r' + (r')^2/2 + r^3 p(r)) (1 +/- s) +/- s(r' \
+        + (r')^2/2) +/- s r^3 p(r). To maintain good precision, 1 +/- s and r' \
+        + (r')^2/2 are computed to extra precision in a leading term and a    \
+        correctional term. This leads to representing 1 +/- exp(x) in a       \
+        leading and correctional term. */                                     \
+        /* 1 +/- s is exact when s is not small */                            \
+        auto rsq = __riscv_vfmul_vv_f##TLEN##m##LMUL(r, r, vl);               \
+        auto s_tail =                                                         \
+            __riscv_vmerge_vvm_f##TLEN##m##LMUL(zero, s, s_is_small, vl);     \
+        /* s_head + s_tail = s; and 1 +/- s is (1 +/- s_head) +/- s_tail */   \
+        /* exp(r') is approximated by 1 + r' + (r')^2/2 + r^3(p_even(r^2) +   \
+           r*p_odd(r^2)); using r without delta_r suffices from the third     \
+           order onwards */                                                   \
+        auto rcube = __riscv_vfmul_vv_f##TLEN##m##LMUL(rsq, r, vl);           \
+        auto c0 =                                                             \
+            __riscv_vfmv_v_f_f##TLEN##m##LMUL(0x1.71ddef82f4beep-19, vl);     \
+        auto c1 =                                                             \
+            __riscv_vfmv_v_f_f##TLEN##m##LMUL(0x1.a01a01b32b633p-13, vl);     \
+        auto c2 = __riscv_vfmv_v_f_f##TLEN##m##LMUL(0x1.111111110ef6ap-7, vl); \
+        auto c3 = __riscv_vfmv_v_f_f##TLEN##m##LMUL(0x1.555555555555ap-3, vl); \
+        auto c4 =                                                             \
+            __riscv_vfmv_v_f_f##TLEN##m##LMUL(0x1.a019b37a2b3dfp-16, vl);     \
+        auto c5 =                                                             \
+            __riscv_vfmv_v_f_f##TLEN##m##LMUL(0x1.6c16c17a09506p-10, vl);     \
+        auto c6 = __riscv_vfmv_v_f_f##TLEN##m##LMUL(0x1.5555555553aefp-5, vl); \
+                                                                              \
+        auto p_even = __riscv_vmv_v_v_f##TLEN##m##LMUL(rsq, vl);              \
+        p_even = __riscv_vfmadd_vf_f##TLEN##m##LMUL(                          \
+            p_even, 0x1.af6eacd796f0bp-26, c0, vl);                           \
+        p_even = __riscv_vfmadd_vv_f##TLEN##m##LMUL(p_even, rsq, c1, vl);     \
+        p_even = __riscv_vfmadd_vv_f##TLEN##m##LMUL(p_even, rsq, c2, vl);     \
+        p_even = __riscv_vfmadd_vv_f##TLEN##m##LMUL(p_even, rsq, c3, vl);     \
+                                                                              \
+        auto p_odd = __riscv_vmv_v_v_f##TLEN##m##LMUL(rsq, vl);               \
+        p_odd = __riscv_vfmadd_vf_f##TLEN##m##LMUL(                           \
+            p_odd, 0x1.289788d8bdadfp-22, c4, vl);                            \
+        p_odd = __riscv_vfmadd_vv_f##TLEN##m##LMUL(p_odd, rsq, c5, vl);       \
+        p_odd = __riscv_vfmadd_vv_f##TLEN##m##LMUL(p_odd, rsq, c6, vl);       \
+        auto poly = __riscv_vfmadd_vv_f##TLEN##m##LMUL(p_odd, r, p_even, vl); \
+                                                                              \
+        /* r^3 * poly will be r^3(...)                                        \
+           we delay this multiplication with r^3 for now */                   \
+                                                                              \
+        /* Compute r' + (r')^2/2 extra precisely */                           \
+        auto r_prime = __riscv_vfmul_vf_f##TLEN##m##LMUL(r, 0x1.0p-1, vl);    \
+        auto B = __riscv_vfmadd_vv_f##TLEN##m##LMUL(r, r_prime, r, vl);       \
+        auto b = __riscv_vfsub_vv_f##TLEN##m##LMUL(r, B, vl);                 \
+        b = __riscv_vfmacc_vv_f##TLEN##m##LMUL(b, r, r_prime, vl);            \
+        /* B + b is r' + (r')^2/2 extra precisely */                          \
+        /* incorporate r_delta in R + R^2/2 */                                \
+        auto c = __riscv_vfmadd_vv_f##TLEN##m##LMUL(r, r_delta, r_delta, vl); \
+        b = __riscv_vfadd_vv_f##TLEN##m##LMUL(b, c, vl);                      \
+        poly = __riscv_vfmadd_vv_f##TLEN##m##LMUL(poly, rcube, b, vl);        \
+        /* B + poly is r' + (r')^2/2 + r^3(.....)
*/ \
+        /* and exp(r') is well approximated by s*(1 + B + poly) */           \
+                                                                              \
+        /* We compute the denominator 1 + exp(R) first as                     \
+           we will need to reciprocate afterwards, the latency of which       \
+           can be hidden somewhat by proceeding with the numerator            \
+           at that time */                                                    \
+        auto Z = __riscv_vfadd_vf_f##TLEN##m##LMUL(s_head, fp_posOne, vl);    \
+        auto D_tmp = __riscv_vfmadd_vv_f##TLEN##m##LMUL(B, s, Z, vl);         \
+        auto d_tmp = __riscv_vfsub_vv_f##TLEN##m##LMUL(Z, D_tmp, vl);         \
+        d_tmp = __riscv_vfmacc_vv_f##TLEN##m##LMUL(d_tmp, s, B, vl);          \
+        d_tmp = __riscv_vfadd_vv_f##TLEN##m##LMUL(d_tmp, s_tail, vl);         \
+        d_tmp = __riscv_vfmacc_vv_f##TLEN##m##LMUL(d_tmp, s, poly, vl);       \
+        /* D_tmp + d_tmp is 1 + exp(R) to high precision, but we have to      \
+           normalize this representation so that the leading term             \
+           has full working precision of this sum */                          \
+        auto D = __riscv_vfadd_vv_f##TLEN##m##LMUL(D_tmp, d_tmp, vl);         \
+        auto d = __riscv_vfsub_vv_f##TLEN##m##LMUL(D_tmp, D, vl);             \
+        d = __riscv_vfadd_vv_f##TLEN##m##LMUL(d, d_tmp, vl);                  \
+                                                                              \
+        /* Now start to compute 1/(D+d) as E + e */                           \
+        auto E = __riscv_vfrdiv_vf_f##TLEN##m##LMUL(D, fp_posOne, vl);        \
+        auto e = __riscv_vfnmsub_vv_f##TLEN##m##LMUL(E, D, one, vl);          \
+        e = __riscv_vfnmsac_vv_f##TLEN##m##LMUL(e, E, d, vl);                 \
+        e = __riscv_vfmul_vv_f##TLEN##m##LMUL(                                \
+            e, __riscv_vfrec7_v_f##TLEN##m##LMUL(D, vl), vl);                 \
+        /* E + e is 1/(D+d) to extra precision */                             \
+                                                                              \
+        /* Overlap much of the 1/(D+d) computation with */                    \
+        /* computing 1 - s(1 + B + poly) */                                   \
+        Z = __riscv_vfrsub_vf_f##TLEN##m##LMUL(s_head, fp_posOne, vl);        \
+                                                                              \
+        auto Numer = __riscv_vfnmsub_vv_f##TLEN##m##LMUL(B, s, Z, vl);        \
+        auto numer = __riscv_vfsub_vv_f##TLEN##m##LMUL(Z, Numer, vl);         \
+        numer = __riscv_vfnmsac_vv_f##TLEN##m##LMUL(numer, s, B, vl);         \
+                                                                              \
+        /* Numer + numer = Z - s * B accurately */                            \
+        numer = __riscv_vfsub_vv_f##TLEN##m##LMUL(numer, s_tail, vl);         \
+        numer = __riscv_vfnmsac_vv_f##TLEN##m##LMUL(numer, s, poly, vl);      \
+                                                                              \
+        /* (Numer + numer) * (E + e) */                                       \
+        /* Numer * E + ( numer * E + (Numer * e + (e*numer)) ) */             \
+        auto vy = __riscv_vfmul_vv_f##TLEN##m##LMUL(e, numer, vl);            \
+        vy = __riscv_vfmacc_vv_f##TLEN##m##LMUL(vy, Numer, e, vl);            \
+        vy = __riscv_vfmacc_vv_f##TLEN##m##LMUL(vy, numer, E, vl);            \
+        vy = __riscv_vfmacc_vv_f##TLEN##m##LMUL(vy, Numer, E, vl);            \
+        return __riscv_vfsgnj_vv_f##TLEN##m##LMUL(vy, v, vl);                 \
+    }
+#endif
+
 _RVV_FLOAT_TANH_OP(1, 32, 32)
 _RVV_FLOAT_TANH_OP(2, 16, 32)
 _RVV_FLOAT_TANH_OP(4, 8, 32)
@@ -261,4 +395,331 @@ _RVV_FLOAT_POW_OP(2, 16, 32)
 _RVV_FLOAT_POW_OP(4, 8, 32)
 _RVV_FLOAT_POW_OP(8, 4, 32)
+struct sv_erff_data {
+    float erf[513];
+    float scale[513];
+};
+
+/* Lookup table used in SVE erff.
+   For each possible rounded input r (multiples of 1/128), between
+   r = 0.0 and r = 4.0 (513 values):
+   - __erff_data.erf contains the values of erf(r),
+   - __erff_data.scale contains the values of 2/sqrt(pi)*exp(-r^2).
+   Note that indices 0 and 1 are never hit by the algorithm, since lookup is
+   performed only for x >= 1/64-1/512.
*/ +const struct sv_erff_data __sv_erff_data = { + .erf = + { + 0x0.000000p+0, 0x0.000000p+0, 0x1.20d770p-6, 0x1.b137e0p-6, + 0x1.20c564p-5, 0x1.68e5d4p-5, 0x1.b0fafep-5, 0x1.f902a8p-5, + 0x1.207d48p-4, 0x1.44703ep-4, 0x1.68591ap-4, 0x1.8c36bep-4, + 0x1.b00812p-4, 0x1.d3cbf8p-4, 0x1.f7815ap-4, 0x1.0d9390p-3, + 0x1.1f5e1ap-3, 0x1.311fc2p-3, 0x1.42d7fcp-3, 0x1.548642p-3, + 0x1.662a0cp-3, 0x1.77c2d2p-3, 0x1.895010p-3, 0x1.9ad142p-3, + 0x1.ac45e4p-3, 0x1.bdad72p-3, 0x1.cf076ep-3, 0x1.e05354p-3, + 0x1.f190aap-3, 0x1.015f78p-2, 0x1.09eed6p-2, 0x1.127632p-2, + 0x1.1af54ep-2, 0x1.236bf0p-2, 0x1.2bd9dcp-2, 0x1.343ed6p-2, + 0x1.3c9aa8p-2, 0x1.44ed18p-2, 0x1.4d35f0p-2, 0x1.5574f4p-2, + 0x1.5da9f4p-2, 0x1.65d4b8p-2, 0x1.6df50ap-2, 0x1.760abap-2, + 0x1.7e1594p-2, 0x1.861566p-2, 0x1.8e0a02p-2, 0x1.95f336p-2, + 0x1.9dd0d2p-2, 0x1.a5a2acp-2, 0x1.ad6896p-2, 0x1.b52264p-2, + 0x1.bccfecp-2, 0x1.c47104p-2, 0x1.cc0584p-2, 0x1.d38d44p-2, + 0x1.db081cp-2, 0x1.e275eap-2, 0x1.e9d68ap-2, 0x1.f129d4p-2, + 0x1.f86faap-2, 0x1.ffa7eap-2, 0x1.03693ap-1, 0x1.06f794p-1, + 0x1.0a7ef6p-1, 0x1.0dff50p-1, 0x1.117894p-1, 0x1.14eab4p-1, + 0x1.1855a6p-1, 0x1.1bb95cp-1, 0x1.1f15ccp-1, 0x1.226ae8p-1, + 0x1.25b8a8p-1, 0x1.28ff02p-1, 0x1.2c3decp-1, 0x1.2f755cp-1, + 0x1.32a54cp-1, 0x1.35cdb4p-1, 0x1.38ee8ap-1, 0x1.3c07cap-1, + 0x1.3f196ep-1, 0x1.42236ep-1, 0x1.4525c8p-1, 0x1.482074p-1, + 0x1.4b1372p-1, 0x1.4dfebap-1, 0x1.50e24cp-1, 0x1.53be26p-1, + 0x1.569244p-1, 0x1.595ea6p-1, 0x1.5c2348p-1, 0x1.5ee02ep-1, + 0x1.619556p-1, 0x1.6442c0p-1, 0x1.66e86ep-1, 0x1.69865ep-1, + 0x1.6c1c98p-1, 0x1.6eab18p-1, 0x1.7131e6p-1, 0x1.73b102p-1, + 0x1.762870p-1, 0x1.789836p-1, 0x1.7b0058p-1, 0x1.7d60d8p-1, + 0x1.7fb9c0p-1, 0x1.820b12p-1, 0x1.8454d6p-1, 0x1.869712p-1, + 0x1.88d1cep-1, 0x1.8b050ep-1, 0x1.8d30dep-1, 0x1.8f5544p-1, + 0x1.91724ap-1, 0x1.9387f6p-1, 0x1.959652p-1, 0x1.979d68p-1, + 0x1.999d42p-1, 0x1.9b95e8p-1, 0x1.9d8768p-1, 0x1.9f71cap-1, + 0x1.a1551ap-1, 0x1.a33162p-1, 0x1.a506b0p-1, 0x1.a6d50cp-1, + 0x1.a89c86p-1, 0x1.aa5d26p-1, 0x1.ac16fcp-1, 0x1.adca14p-1, + 0x1.af767ap-1, 0x1.b11c3cp-1, 0x1.b2bb68p-1, 0x1.b4540ap-1, + 0x1.b5e630p-1, 0x1.b771e8p-1, 0x1.b8f742p-1, 0x1.ba764ap-1, + 0x1.bbef10p-1, 0x1.bd61a2p-1, 0x1.bece0ep-1, 0x1.c03464p-1, + 0x1.c194b2p-1, 0x1.c2ef08p-1, 0x1.c44376p-1, 0x1.c5920ap-1, + 0x1.c6dad2p-1, 0x1.c81de2p-1, 0x1.c95b46p-1, 0x1.ca930ep-1, + 0x1.cbc54cp-1, 0x1.ccf20cp-1, 0x1.ce1962p-1, 0x1.cf3b5cp-1, + 0x1.d0580cp-1, 0x1.d16f7ep-1, 0x1.d281c4p-1, 0x1.d38ef0p-1, + 0x1.d49710p-1, 0x1.d59a34p-1, 0x1.d6986cp-1, 0x1.d791cap-1, + 0x1.d8865ep-1, 0x1.d97636p-1, 0x1.da6162p-1, 0x1.db47f4p-1, + 0x1.dc29fcp-1, 0x1.dd0788p-1, 0x1.dde0aap-1, 0x1.deb570p-1, + 0x1.df85eap-1, 0x1.e0522ap-1, 0x1.e11a3ep-1, 0x1.e1de36p-1, + 0x1.e29e22p-1, 0x1.e35a12p-1, 0x1.e41214p-1, 0x1.e4c638p-1, + 0x1.e5768cp-1, 0x1.e62322p-1, 0x1.e6cc08p-1, 0x1.e7714ap-1, + 0x1.e812fcp-1, 0x1.e8b12ap-1, 0x1.e94be4p-1, 0x1.e9e336p-1, + 0x1.ea7730p-1, 0x1.eb07e2p-1, 0x1.eb9558p-1, 0x1.ec1fa2p-1, + 0x1.eca6ccp-1, 0x1.ed2ae6p-1, 0x1.edabfcp-1, 0x1.ee2a1ep-1, + 0x1.eea556p-1, 0x1.ef1db4p-1, 0x1.ef9344p-1, 0x1.f00614p-1, + 0x1.f07630p-1, 0x1.f0e3a6p-1, 0x1.f14e82p-1, 0x1.f1b6d0p-1, + 0x1.f21ca0p-1, 0x1.f27ff8p-1, 0x1.f2e0eap-1, 0x1.f33f7ep-1, + 0x1.f39bc2p-1, 0x1.f3f5c2p-1, 0x1.f44d88p-1, 0x1.f4a31ep-1, + 0x1.f4f694p-1, 0x1.f547f2p-1, 0x1.f59742p-1, 0x1.f5e490p-1, + 0x1.f62fe8p-1, 0x1.f67952p-1, 0x1.f6c0dcp-1, 0x1.f7068cp-1, + 0x1.f74a6ep-1, 0x1.f78c8cp-1, 0x1.f7cceep-1, 0x1.f80ba2p-1, + 0x1.f848acp-1, 0x1.f8841ap-1, 0x1.f8bdf2p-1, 0x1.f8f63ep-1, + 0x1.f92d08p-1, 
0x1.f96256p-1, 0x1.f99634p-1, 0x1.f9c8a8p-1, + 0x1.f9f9bap-1, 0x1.fa2974p-1, 0x1.fa57dep-1, 0x1.fa84fep-1, + 0x1.fab0dep-1, 0x1.fadb84p-1, 0x1.fb04f6p-1, 0x1.fb2d40p-1, + 0x1.fb5464p-1, 0x1.fb7a6cp-1, 0x1.fb9f60p-1, 0x1.fbc344p-1, + 0x1.fbe61ep-1, 0x1.fc07fap-1, 0x1.fc28d8p-1, 0x1.fc48c2p-1, + 0x1.fc67bcp-1, 0x1.fc85d0p-1, 0x1.fca2fep-1, 0x1.fcbf52p-1, + 0x1.fcdaccp-1, 0x1.fcf576p-1, 0x1.fd0f54p-1, 0x1.fd286ap-1, + 0x1.fd40bep-1, 0x1.fd5856p-1, 0x1.fd6f34p-1, 0x1.fd8562p-1, + 0x1.fd9ae2p-1, 0x1.fdafb8p-1, 0x1.fdc3e8p-1, 0x1.fdd77ap-1, + 0x1.fdea6ep-1, 0x1.fdfcccp-1, 0x1.fe0e96p-1, 0x1.fe1fd0p-1, + 0x1.fe3080p-1, 0x1.fe40a6p-1, 0x1.fe504cp-1, 0x1.fe5f70p-1, + 0x1.fe6e18p-1, 0x1.fe7c46p-1, 0x1.fe8a00p-1, 0x1.fe9748p-1, + 0x1.fea422p-1, 0x1.feb090p-1, 0x1.febc96p-1, 0x1.fec836p-1, + 0x1.fed374p-1, 0x1.fede52p-1, 0x1.fee8d4p-1, 0x1.fef2fep-1, + 0x1.fefccep-1, 0x1.ff064cp-1, 0x1.ff0f76p-1, 0x1.ff1852p-1, + 0x1.ff20e0p-1, 0x1.ff2924p-1, 0x1.ff3120p-1, 0x1.ff38d6p-1, + 0x1.ff4048p-1, 0x1.ff4778p-1, 0x1.ff4e68p-1, 0x1.ff551ap-1, + 0x1.ff5b90p-1, 0x1.ff61ccp-1, 0x1.ff67d0p-1, 0x1.ff6d9ep-1, + 0x1.ff7338p-1, 0x1.ff789ep-1, 0x1.ff7dd4p-1, 0x1.ff82dap-1, + 0x1.ff87b2p-1, 0x1.ff8c5cp-1, 0x1.ff90dcp-1, 0x1.ff9532p-1, + 0x1.ff9960p-1, 0x1.ff9d68p-1, 0x1.ffa14ap-1, 0x1.ffa506p-1, + 0x1.ffa8a0p-1, 0x1.ffac18p-1, 0x1.ffaf6ep-1, 0x1.ffb2a6p-1, + 0x1.ffb5bep-1, 0x1.ffb8b8p-1, 0x1.ffbb98p-1, 0x1.ffbe5ap-1, + 0x1.ffc102p-1, 0x1.ffc390p-1, 0x1.ffc606p-1, 0x1.ffc862p-1, + 0x1.ffcaa8p-1, 0x1.ffccd8p-1, 0x1.ffcef4p-1, 0x1.ffd0fap-1, + 0x1.ffd2eap-1, 0x1.ffd4cap-1, 0x1.ffd696p-1, 0x1.ffd84ep-1, + 0x1.ffd9f8p-1, 0x1.ffdb90p-1, 0x1.ffdd18p-1, 0x1.ffde90p-1, + 0x1.ffdffap-1, 0x1.ffe154p-1, 0x1.ffe2a2p-1, 0x1.ffe3e2p-1, + 0x1.ffe514p-1, 0x1.ffe63cp-1, 0x1.ffe756p-1, 0x1.ffe866p-1, + 0x1.ffe96ap-1, 0x1.ffea64p-1, 0x1.ffeb54p-1, 0x1.ffec3ap-1, + 0x1.ffed16p-1, 0x1.ffedeap-1, 0x1.ffeeb4p-1, 0x1.ffef76p-1, + 0x1.fff032p-1, 0x1.fff0e4p-1, 0x1.fff18ep-1, 0x1.fff232p-1, + 0x1.fff2d0p-1, 0x1.fff366p-1, 0x1.fff3f6p-1, 0x1.fff480p-1, + 0x1.fff504p-1, 0x1.fff582p-1, 0x1.fff5fcp-1, 0x1.fff670p-1, + 0x1.fff6dep-1, 0x1.fff74ap-1, 0x1.fff7aep-1, 0x1.fff810p-1, + 0x1.fff86cp-1, 0x1.fff8c6p-1, 0x1.fff91cp-1, 0x1.fff96cp-1, + 0x1.fff9bap-1, 0x1.fffa04p-1, 0x1.fffa4cp-1, 0x1.fffa90p-1, + 0x1.fffad0p-1, 0x1.fffb0ep-1, 0x1.fffb4ap-1, 0x1.fffb82p-1, + 0x1.fffbb8p-1, 0x1.fffbecp-1, 0x1.fffc1ep-1, 0x1.fffc4ep-1, + 0x1.fffc7ap-1, 0x1.fffca6p-1, 0x1.fffccep-1, 0x1.fffcf6p-1, + 0x1.fffd1ap-1, 0x1.fffd3ep-1, 0x1.fffd60p-1, 0x1.fffd80p-1, + 0x1.fffda0p-1, 0x1.fffdbep-1, 0x1.fffddap-1, 0x1.fffdf4p-1, + 0x1.fffe0ep-1, 0x1.fffe26p-1, 0x1.fffe3ep-1, 0x1.fffe54p-1, + 0x1.fffe68p-1, 0x1.fffe7ep-1, 0x1.fffe90p-1, 0x1.fffea2p-1, + 0x1.fffeb4p-1, 0x1.fffec4p-1, 0x1.fffed4p-1, 0x1.fffee4p-1, + 0x1.fffef2p-1, 0x1.ffff00p-1, 0x1.ffff0cp-1, 0x1.ffff18p-1, + 0x1.ffff24p-1, 0x1.ffff30p-1, 0x1.ffff3ap-1, 0x1.ffff44p-1, + 0x1.ffff4ep-1, 0x1.ffff56p-1, 0x1.ffff60p-1, 0x1.ffff68p-1, + 0x1.ffff70p-1, 0x1.ffff78p-1, 0x1.ffff7ep-1, 0x1.ffff84p-1, + 0x1.ffff8cp-1, 0x1.ffff92p-1, 0x1.ffff98p-1, 0x1.ffff9cp-1, + 0x1.ffffa2p-1, 0x1.ffffa6p-1, 0x1.ffffacp-1, 0x1.ffffb0p-1, + 0x1.ffffb4p-1, 0x1.ffffb8p-1, 0x1.ffffbcp-1, 0x1.ffffc0p-1, + 0x1.ffffc4p-1, 0x1.ffffc6p-1, 0x1.ffffcap-1, 0x1.ffffccp-1, + 0x1.ffffd0p-1, 0x1.ffffd2p-1, 0x1.ffffd4p-1, 0x1.ffffd6p-1, + 0x1.ffffd8p-1, 0x1.ffffdcp-1, 0x1.ffffdep-1, 0x1.ffffdep-1, + 0x1.ffffe0p-1, 0x1.ffffe2p-1, 0x1.ffffe4p-1, 0x1.ffffe6p-1, + 0x1.ffffe8p-1, 0x1.ffffe8p-1, 0x1.ffffeap-1, 0x1.ffffeap-1, + 0x1.ffffecp-1, 0x1.ffffeep-1, 
0x1.ffffeep-1, 0x1.fffff0p-1, + 0x1.fffff0p-1, 0x1.fffff2p-1, 0x1.fffff2p-1, 0x1.fffff2p-1, + 0x1.fffff4p-1, 0x1.fffff4p-1, 0x1.fffff4p-1, 0x1.fffff6p-1, + 0x1.fffff6p-1, 0x1.fffff6p-1, 0x1.fffff8p-1, 0x1.fffff8p-1, + 0x1.fffff8p-1, 0x1.fffff8p-1, 0x1.fffffap-1, 0x1.fffffap-1, + 0x1.fffffap-1, 0x1.fffffap-1, 0x1.fffffap-1, 0x1.fffffap-1, + 0x1.fffffcp-1, 0x1.fffffcp-1, 0x1.fffffcp-1, 0x1.fffffcp-1, + 0x1.fffffcp-1, 0x1.fffffcp-1, 0x1.fffffcp-1, 0x1.fffffcp-1, + 0x1.fffffep-1, 0x1.fffffep-1, 0x1.fffffep-1, 0x1.fffffep-1, + 0x1.fffffep-1, 0x1.fffffep-1, 0x1.fffffep-1, 0x1.fffffep-1, + 0x1.fffffep-1, 0x1.fffffep-1, 0x1.fffffep-1, 0x1.fffffep-1, + 0x1.fffffep-1, 0x1.fffffep-1, 0x1.fffffep-1, 0x1.fffffep-1, + 0x1.fffffep-1, 0x1.fffffep-1, 0x1.000000p+0, 0x1.000000p+0, + 0x1.000000p+0, 0x1.000000p+0, 0x1.000000p+0, 0x1.000000p+0, + 0x1.000000p+0, 0x1.000000p+0, 0x1.000000p+0, 0x1.000000p+0, + 0x1.000000p+0, + }, + .scale = + { + 0x1.20dd76p+0, 0x1.20dd76p+0, 0x1.20cb68p+0, 0x1.20b4d8p+0, + 0x1.209546p+0, 0x1.206cb4p+0, 0x1.203b26p+0, 0x1.2000a0p+0, + 0x1.1fbd28p+0, 0x1.1f70c4p+0, 0x1.1f1b7ap+0, 0x1.1ebd56p+0, + 0x1.1e565cp+0, 0x1.1de698p+0, 0x1.1d6e14p+0, 0x1.1cecdcp+0, + 0x1.1c62fap+0, 0x1.1bd07cp+0, 0x1.1b3572p+0, 0x1.1a91e6p+0, + 0x1.19e5eap+0, 0x1.19318cp+0, 0x1.1874dep+0, 0x1.17aff0p+0, + 0x1.16e2d8p+0, 0x1.160da4p+0, 0x1.153068p+0, 0x1.144b3cp+0, + 0x1.135e30p+0, 0x1.12695ep+0, 0x1.116cd8p+0, 0x1.1068bap+0, + 0x1.0f5d16p+0, 0x1.0e4a08p+0, 0x1.0d2fa6p+0, 0x1.0c0e0ap+0, + 0x1.0ae550p+0, 0x1.09b590p+0, 0x1.087ee4p+0, 0x1.07416cp+0, + 0x1.05fd3ep+0, 0x1.04b27cp+0, 0x1.036140p+0, 0x1.0209a6p+0, + 0x1.00abd0p+0, 0x1.fe8fb0p-1, 0x1.fbbbbep-1, 0x1.f8dc0ap-1, + 0x1.f5f0cep-1, 0x1.f2fa4cp-1, 0x1.eff8c4p-1, 0x1.ecec78p-1, + 0x1.e9d5a8p-1, 0x1.e6b498p-1, 0x1.e38988p-1, 0x1.e054bep-1, + 0x1.dd167cp-1, 0x1.d9cf06p-1, 0x1.d67ea2p-1, 0x1.d32592p-1, + 0x1.cfc41ep-1, 0x1.cc5a8ap-1, 0x1.c8e91cp-1, 0x1.c5701ap-1, + 0x1.c1efcap-1, 0x1.be6872p-1, 0x1.bada5ap-1, 0x1.b745c6p-1, + 0x1.b3aafcp-1, 0x1.b00a46p-1, 0x1.ac63e8p-1, 0x1.a8b828p-1, + 0x1.a5074ep-1, 0x1.a1519ep-1, 0x1.9d9762p-1, 0x1.99d8dap-1, + 0x1.961650p-1, 0x1.925008p-1, 0x1.8e8646p-1, 0x1.8ab950p-1, + 0x1.86e96ap-1, 0x1.8316d6p-1, 0x1.7f41dcp-1, 0x1.7b6abcp-1, + 0x1.7791b8p-1, 0x1.73b714p-1, 0x1.6fdb12p-1, 0x1.6bfdf0p-1, + 0x1.681ff2p-1, 0x1.644156p-1, 0x1.60625cp-1, 0x1.5c8342p-1, + 0x1.58a446p-1, 0x1.54c5a6p-1, 0x1.50e79ep-1, 0x1.4d0a68p-1, + 0x1.492e42p-1, 0x1.455366p-1, 0x1.417a0cp-1, 0x1.3da26ep-1, + 0x1.39ccc2p-1, 0x1.35f940p-1, 0x1.32281ep-1, 0x1.2e5992p-1, + 0x1.2a8dcep-1, 0x1.26c508p-1, 0x1.22ff72p-1, 0x1.1f3d3cp-1, + 0x1.1b7e98p-1, 0x1.17c3b6p-1, 0x1.140cc4p-1, 0x1.1059eep-1, + 0x1.0cab62p-1, 0x1.09014cp-1, 0x1.055bd6p-1, 0x1.01bb2cp-1, + 0x1.fc3ee6p-2, 0x1.f511aap-2, 0x1.edeeeep-2, 0x1.e6d700p-2, + 0x1.dfca26p-2, 0x1.d8c8aap-2, 0x1.d1d2d0p-2, 0x1.cae8dap-2, + 0x1.c40b08p-2, 0x1.bd3998p-2, 0x1.b674c8p-2, 0x1.afbcd4p-2, + 0x1.a911f0p-2, 0x1.a27456p-2, 0x1.9be438p-2, 0x1.9561c8p-2, + 0x1.8eed36p-2, 0x1.8886b2p-2, 0x1.822e66p-2, 0x1.7be47ap-2, + 0x1.75a91ap-2, 0x1.6f7c6ap-2, 0x1.695e8cp-2, 0x1.634fa6p-2, + 0x1.5d4fd4p-2, 0x1.575f34p-2, 0x1.517de6p-2, 0x1.4bac00p-2, + 0x1.45e99cp-2, 0x1.4036d0p-2, 0x1.3a93b2p-2, 0x1.350052p-2, + 0x1.2f7cc4p-2, 0x1.2a0916p-2, 0x1.24a554p-2, 0x1.1f518ap-2, + 0x1.1a0dc6p-2, 0x1.14da0ap-2, 0x1.0fb662p-2, 0x1.0aa2d0p-2, + 0x1.059f5ap-2, 0x1.00ac00p-2, 0x1.f79184p-3, 0x1.edeb40p-3, + 0x1.e46530p-3, 0x1.daff4ap-3, 0x1.d1b982p-3, 0x1.c893cep-3, + 0x1.bf8e1cp-3, 0x1.b6a856p-3, 0x1.ade26cp-3, 0x1.a53c42p-3, + 
0x1.9cb5bep-3, 0x1.944ec2p-3, 0x1.8c0732p-3, 0x1.83deeap-3, + 0x1.7bd5c8p-3, 0x1.73eba4p-3, 0x1.6c2056p-3, 0x1.6473b6p-3, + 0x1.5ce596p-3, 0x1.5575c8p-3, 0x1.4e241ep-3, 0x1.46f066p-3, + 0x1.3fda6cp-3, 0x1.38e1fap-3, 0x1.3206dcp-3, 0x1.2b48dap-3, + 0x1.24a7b8p-3, 0x1.1e233ep-3, 0x1.17bb2cp-3, 0x1.116f48p-3, + 0x1.0b3f52p-3, 0x1.052b0cp-3, 0x1.fe6460p-4, 0x1.f2a902p-4, + 0x1.e72372p-4, 0x1.dbd32ap-4, 0x1.d0b7a0p-4, 0x1.c5d04ap-4, + 0x1.bb1c98p-4, 0x1.b09bfcp-4, 0x1.a64de6p-4, 0x1.9c31c6p-4, + 0x1.92470ap-4, 0x1.888d1ep-4, 0x1.7f036cp-4, 0x1.75a960p-4, + 0x1.6c7e64p-4, 0x1.6381e2p-4, 0x1.5ab342p-4, 0x1.5211ecp-4, + 0x1.499d48p-4, 0x1.4154bcp-4, 0x1.3937b2p-4, 0x1.31458ep-4, + 0x1.297dbap-4, 0x1.21df9ap-4, 0x1.1a6a96p-4, 0x1.131e14p-4, + 0x1.0bf97ep-4, 0x1.04fc3ap-4, 0x1.fc4b5ep-5, 0x1.eeea8cp-5, + 0x1.e1d4d0p-5, 0x1.d508fap-5, 0x1.c885e0p-5, 0x1.bc4a54p-5, + 0x1.b05530p-5, 0x1.a4a54ap-5, 0x1.99397ap-5, 0x1.8e109cp-5, + 0x1.83298ep-5, 0x1.78832cp-5, 0x1.6e1c58p-5, 0x1.63f3f6p-5, + 0x1.5a08e8p-5, 0x1.505a18p-5, 0x1.46e66cp-5, 0x1.3dacd2p-5, + 0x1.34ac36p-5, 0x1.2be38cp-5, 0x1.2351c2p-5, 0x1.1af5d2p-5, + 0x1.12ceb4p-5, 0x1.0adb60p-5, 0x1.031ad6p-5, 0x1.f7182ap-6, + 0x1.e85c44p-6, 0x1.da0006p-6, 0x1.cc0180p-6, 0x1.be5ecep-6, + 0x1.b1160ap-6, 0x1.a4255ap-6, 0x1.978ae8p-6, 0x1.8b44e6p-6, + 0x1.7f5188p-6, 0x1.73af0cp-6, 0x1.685bb6p-6, 0x1.5d55ccp-6, + 0x1.529b9ep-6, 0x1.482b84p-6, 0x1.3e03d8p-6, 0x1.3422fep-6, + 0x1.2a875cp-6, 0x1.212f62p-6, 0x1.181984p-6, 0x1.0f443ep-6, + 0x1.06ae14p-6, 0x1.fcab14p-7, 0x1.ec7262p-7, 0x1.dcaf36p-7, + 0x1.cd5ecap-7, 0x1.be7e5ap-7, 0x1.b00b38p-7, 0x1.a202bep-7, + 0x1.94624ep-7, 0x1.87275ep-7, 0x1.7a4f6ap-7, 0x1.6dd7fep-7, + 0x1.61beaep-7, 0x1.56011cp-7, 0x1.4a9cf6p-7, 0x1.3f8ff6p-7, + 0x1.34d7dcp-7, 0x1.2a727ap-7, 0x1.205dacp-7, 0x1.169756p-7, + 0x1.0d1d6ap-7, 0x1.03ede2p-7, 0x1.f60d8ap-8, 0x1.e4cc4ap-8, + 0x1.d4143ap-8, 0x1.c3e1a6p-8, 0x1.b430ecp-8, 0x1.a4fe84p-8, + 0x1.9646f4p-8, 0x1.8806d8p-8, 0x1.7a3adep-8, 0x1.6cdfccp-8, + 0x1.5ff276p-8, 0x1.536fc2p-8, 0x1.4754acp-8, 0x1.3b9e40p-8, + 0x1.30499cp-8, 0x1.2553eep-8, 0x1.1aba78p-8, 0x1.107a8cp-8, + 0x1.06918cp-8, 0x1.f9f9d0p-9, 0x1.e77448p-9, 0x1.d58da6p-9, + 0x1.c4412cp-9, 0x1.b38a3ap-9, 0x1.a36454p-9, 0x1.93cb12p-9, + 0x1.84ba30p-9, 0x1.762d84p-9, 0x1.682100p-9, 0x1.5a90b0p-9, + 0x1.4d78bcp-9, 0x1.40d564p-9, 0x1.34a306p-9, 0x1.28de12p-9, + 0x1.1d8318p-9, 0x1.128ebap-9, 0x1.07fdb4p-9, 0x1.fb99b8p-10, + 0x1.e7f232p-10, 0x1.d4fed8p-10, 0x1.c2b9d0p-10, 0x1.b11d70p-10, + 0x1.a02436p-10, 0x1.8fc8c8p-10, 0x1.8005f0p-10, 0x1.70d6a4p-10, + 0x1.6235fcp-10, 0x1.541f34p-10, 0x1.468daep-10, 0x1.397ceep-10, + 0x1.2ce898p-10, 0x1.20cc76p-10, 0x1.15246ep-10, 0x1.09ec86p-10, + 0x1.fe41cep-11, 0x1.e97ba4p-11, 0x1.d57f52p-11, 0x1.c245d4p-11, + 0x1.afc85ep-11, 0x1.9e0058p-11, 0x1.8ce75ep-11, 0x1.7c7744p-11, + 0x1.6caa0ep-11, 0x1.5d79ecp-11, 0x1.4ee142p-11, 0x1.40daa4p-11, + 0x1.3360ccp-11, 0x1.266ea8p-11, 0x1.19ff46p-11, 0x1.0e0de8p-11, + 0x1.0295f0p-11, 0x1.ef25d4p-12, 0x1.da0110p-12, 0x1.c5b542p-12, + 0x1.b23a5ap-12, 0x1.9f8894p-12, 0x1.8d986ap-12, 0x1.7c629ap-12, + 0x1.6be022p-12, 0x1.5c0a38p-12, 0x1.4cda54p-12, 0x1.3e4a24p-12, + 0x1.305390p-12, 0x1.22f0b4p-12, 0x1.161be4p-12, 0x1.09cfa4p-12, + 0x1.fc0d56p-13, 0x1.e577bcp-13, 0x1.cfd4a6p-13, 0x1.bb1a96p-13, + 0x1.a74068p-13, 0x1.943d4ap-13, 0x1.8208bcp-13, 0x1.709a8ep-13, + 0x1.5feadap-13, 0x1.4ff208p-13, 0x1.40a8c2p-13, 0x1.3207fcp-13, + 0x1.2408eap-13, 0x1.16a502p-13, 0x1.09d5f8p-13, 0x1.fb2b7ap-14, + 0x1.e3bcf4p-14, 0x1.cd5528p-14, 0x1.b7e946p-14, 0x1.a36eecp-14, + 
0x1.8fdc1cp-14, 0x1.7d2738p-14, 0x1.6b4702p-14, 0x1.5a329cp-14, + 0x1.49e178p-14, 0x1.3a4b60p-14, 0x1.2b6876p-14, 0x1.1d3120p-14, + 0x1.0f9e1cp-14, 0x1.02a868p-14, 0x1.ec929ap-15, 0x1.d4f4b4p-15, + 0x1.be6abcp-15, 0x1.a8e8ccp-15, 0x1.94637ep-15, 0x1.80cfdcp-15, + 0x1.6e2368p-15, 0x1.5c540cp-15, 0x1.4b581cp-15, 0x1.3b2652p-15, + 0x1.2bb5ccp-15, 0x1.1cfe02p-15, 0x1.0ef6c4p-15, 0x1.019842p-15, + 0x1.e9b5e8p-16, 0x1.d16f58p-16, 0x1.ba4f04p-16, 0x1.a447b8p-16, + 0x1.8f4cccp-16, 0x1.7b5224p-16, 0x1.684c22p-16, 0x1.562facp-16, + 0x1.44f21ep-16, 0x1.34894ap-16, 0x1.24eb72p-16, 0x1.160f44p-16, + 0x1.07ebd2p-16, 0x1.f4f12ep-17, 0x1.db5ad0p-17, 0x1.c304f0p-17, + 0x1.abe09ep-17, 0x1.95df98p-17, 0x1.80f43ap-17, 0x1.6d1178p-17, + 0x1.5a2ae0p-17, 0x1.483488p-17, 0x1.372310p-17, 0x1.26eb9ep-17, + 0x1.1783cep-17, 0x1.08e1bap-17, 0x1.f5f7d8p-18, 0x1.db92b6p-18, + 0x1.c282cep-18, 0x1.aab7acp-18, 0x1.94219cp-18, 0x1.7eb1a2p-18, + 0x1.6a5972p-18, 0x1.570b6ap-18, 0x1.44ba86p-18, 0x1.335a62p-18, + 0x1.22df2ap-18, 0x1.133d96p-18, 0x1.046aeap-18, 0x1.ecb9d0p-19, + 0x1.d21398p-19, 0x1.b8d094p-19, 0x1.a0df10p-19, 0x1.8a2e26p-19, + 0x1.74adc8p-19, 0x1.604ea8p-19, 0x1.4d0232p-19, 0x1.3aba86p-19, + 0x1.296a70p-19, 0x1.190562p-19, 0x1.097f62p-19, 0x1.f59a20p-20, + 0x1.d9c736p-20, 0x1.bf716cp-20, 0x1.a6852cp-20, 0x1.8eefd8p-20, + 0x1.789fb8p-20, 0x1.6383f8p-20, 0x1.4f8c96p-20, 0x1.3caa62p-20, + 0x1.2acee2p-20, 0x1.19ec60p-20, 0x1.09f5d0p-20, 0x1.f5bd96p-21, + 0x1.d9371ep-21, 0x1.be41dep-21, 0x1.a4c89ep-21, 0x1.8cb738p-21, + 0x1.75fa8ep-21, 0x1.608078p-21, 0x1.4c37c0p-21, 0x1.39100ep-21, + 0x1.26f9e0p-21, 0x1.15e682p-21, 0x1.05c804p-21, 0x1.ed2254p-22, + 0x1.d06ad6p-22, 0x1.b551c8p-22, 0x1.9bc0a0p-22, 0x1.83a200p-22, + 0x1.6ce1aap-22, 0x1.576c72p-22, 0x1.43302cp-22, 0x1.301ba2p-22, + 0x1.1e1e86p-22, 0x1.0d2966p-22, 0x1.fa5b50p-23, 0x1.dc3ae4p-23, + 0x1.bfd756p-23, 0x1.a517dap-23, 0x1.8be4f8p-23, 0x1.74287ep-23, + 0x1.5dcd66p-23, 0x1.48bfd4p-23, 0x1.34ecf8p-23, 0x1.224310p-23, + 0x1.10b148p-23, + }, +}; + +#define _RVV_FLOAT_ERF_OP(LMUL, MLEN, TLEN) \ + static inline vfloat##TLEN##m##LMUL##_t erf_ps( \ + fixed_vfloat##TLEN##m##LMUL##_t x, size_t vl) { \ + auto zero = __riscv_vmv_v_x_u##TLEN##m##LMUL(0, vl); \ + auto a = __riscv_vfabs_v_f##TLEN##m##LMUL(x, vl); \ + \ + /* |x| > 1/64 - 1/512. */ \ + auto gt_min_mask = \ + __riscv_vmfgt_vf_f##TLEN##m##LMUL##_b##MLEN(a, 0x1.cp-7f, vl); \ + \ + auto tmp_i = __riscv_vfmul_vf_f##TLEN##m##LMUL(a, 128.f, vl); \ + auto i = __riscv_vfcvt_xu_f_v_u##TLEN##m##LMUL(tmp_i, vl); \ + \ + /* Saturate lookup index. */ \ + i = __riscv_vmerge_vvm_u##TLEN##m##LMUL(zero, i, gt_min_mask, vl); \ + i = __riscv_vminu_vx_u##TLEN##m##LMUL(i, 512, vl); \ + auto tmp_r = __riscv_vfcvt_f_xu_v_f##TLEN##m##LMUL(i, vl); \ + i = __riscv_vmul_vx_u##TLEN##m##LMUL(i, TLEN / 8, vl); \ + \ + /* r and erf(r) set to 0 for |x| below min. */ \ + auto r = __riscv_vfmul_vf_f##TLEN##m##LMUL(tmp_r, 1.f / 128, vl); \ + auto erfr = __riscv_vluxei##TLEN##_v_f##TLEN##m##LMUL( \ + __sv_erff_data.erf, i, vl); \ + auto scale = __riscv_vluxei##TLEN##_v_f##TLEN##m##LMUL( \ + __sv_erff_data.scale, i, vl); \ + \ + /* |x| >= 4.0 - 8/128. */ \ + auto ge_max_mask = \ + __riscv_vmfge_vf_f##TLEN##m##LMUL##_b##MLEN(a, 3.9375f, vl); \ + \ + /* erf(x) ~ erf(r) + scale * d * (1 - r * d - 1/3 * d^2). 
*/ \ + auto d = __riscv_vfsub_vv_f##TLEN##m##LMUL(a, r, vl); \ + auto d2 = __riscv_vfmul_vv_f##TLEN##m##LMUL(d, d, vl); \ + auto y = __riscv_vfmacc_vf_f##TLEN##m##LMUL(r, 0x1.555556p-2f, d, vl); \ + y = __riscv_vfnmsub_vv_f##TLEN##m##LMUL(y, d2, d, vl); \ + y = __riscv_vfmadd_vv_f##TLEN##m##LMUL(y, scale, erfr, vl); \ + \ + /* Solves the |x| = inf case. */ \ + y = __riscv_vfmerge_vfm_f##TLEN##m##LMUL(y, 1.f, ge_max_mask, vl); \ + \ + /* Copy sign. */ \ + return __riscv_vfsgnj_vv_f##TLEN##m##LMUL(y, x, vl); \ + } + +_RVV_FLOAT_ERF_OP(1, 32, 32) +_RVV_FLOAT_ERF_OP(2, 16, 32) +_RVV_FLOAT_ERF_OP(4, 8, 32) +_RVV_FLOAT_ERF_OP(8, 4, 32) #endif \ No newline at end of file diff --git a/src/Native/include/nncase/ntt/arch/riscv64/ukernels.h b/src/Native/include/nncase/ntt/arch/riscv64/ukernels.h new file mode 100644 index 0000000000..c9649881da --- /dev/null +++ b/src/Native/include/nncase/ntt/arch/riscv64/ukernels.h @@ -0,0 +1,101 @@ +/* Copyright 2019-2021 Canaan Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once +#include "../../ukernels.h" +#include "nncase/ntt/arch/riscv64/arch_types.h" +#include "nncase/ntt/vector.h" +#include + +namespace nncase::ntt::ukernels { +template struct u_reduce_policy { + static constexpr size_t unroll = 8; +}; + +template <> +struct u_matmul_policy { + static constexpr size_t m0_tile = 1; + static constexpr size_t n0_tile = 1; + static constexpr size_t m0_subtile = 0; +}; + +// Pack M +template <> +struct u_matmul_policy, + float, vector, true> { + static constexpr size_t m0_tile = 2; + static constexpr size_t n0_tile = 4; + static constexpr size_t m0_subtile = 0; +}; + +// Pack K +template <> +struct u_matmul_policy, + vector, float, true> { + static constexpr size_t m0_tile = 2; + static constexpr size_t n0_tile = 2; + static constexpr size_t m0_subtile = 0; +}; + +// Pack N +template <> +struct u_matmul_policy, + vector, true> { + static constexpr size_t m0_tile = 4; + static constexpr size_t n0_tile = 2; + static constexpr size_t m0_subtile = 0; +}; + +// Pack MN +template <> +struct u_matmul_policy, + vector, + vector, true> { + static constexpr size_t m0_tile = 1; + static constexpr size_t n0_tile = 2; + static constexpr size_t m0_subtile = 4; +}; + +// Pack MK +template <> +struct u_matmul_policy< + mamtul_pack_kind::pack_mk, vector, + vector, vector, true> { + static constexpr size_t m0_tile = 1; + static constexpr size_t n0_tile = 1; + static constexpr size_t m0_subtile = 0; +}; + +// Pack KN +template <> +struct u_matmul_policy, + vector, + vector, true> { + static constexpr size_t m0_tile = 4; + static constexpr size_t n0_tile = 2; + static constexpr size_t m0_subtile = 0; +}; + +// Pack MKN +template <> +struct u_matmul_policy, + vector, + vector, true> { + static constexpr size_t m0_tile = 1; + static constexpr size_t n0_tile = 2; + static constexpr size_t m0_subtile = 4; +}; +} // namespace nncase::ntt::ukernels diff --git a/src/Native/include/nncase/ntt/arch/x86_64/avx_mathfun.h 
b/src/Native/include/nncase/ntt/arch/x86_64/avx_mathfun.h index 221e0b3cae..0a23248cb6 100644 --- a/src/Native/include/nncase/ntt/arch/x86_64/avx_mathfun.h +++ b/src/Native/include/nncase/ntt/arch/x86_64/avx_mathfun.h @@ -992,6 +992,324 @@ static inline __m256 atan256_ps(__m256 x) { negative_mask); } +struct sv_erff_data { + float erf[513]; + float scale[513]; +}; + +/* Lookup table used in SVE erff. + For each possible rounded input r (multiples of 1/128), between + r = 0.0 and r = 4.0 (513 values): + - __erff_data.erf contains the values of erf(r), + - __erff_data.scale contains the values of 2/sqrt(pi)*exp(-r^2). + Note that indices 0 and 1 are never hit by the algorithm, since lookup is + performed only for x >= 1/64-1/512. */ +const struct sv_erff_data __sv_erff_data = { + .erf = + { + 0x0.000000p+0, 0x0.000000p+0, 0x1.20d770p-6, 0x1.b137e0p-6, + 0x1.20c564p-5, 0x1.68e5d4p-5, 0x1.b0fafep-5, 0x1.f902a8p-5, + 0x1.207d48p-4, 0x1.44703ep-4, 0x1.68591ap-4, 0x1.8c36bep-4, + 0x1.b00812p-4, 0x1.d3cbf8p-4, 0x1.f7815ap-4, 0x1.0d9390p-3, + 0x1.1f5e1ap-3, 0x1.311fc2p-3, 0x1.42d7fcp-3, 0x1.548642p-3, + 0x1.662a0cp-3, 0x1.77c2d2p-3, 0x1.895010p-3, 0x1.9ad142p-3, + 0x1.ac45e4p-3, 0x1.bdad72p-3, 0x1.cf076ep-3, 0x1.e05354p-3, + 0x1.f190aap-3, 0x1.015f78p-2, 0x1.09eed6p-2, 0x1.127632p-2, + 0x1.1af54ep-2, 0x1.236bf0p-2, 0x1.2bd9dcp-2, 0x1.343ed6p-2, + 0x1.3c9aa8p-2, 0x1.44ed18p-2, 0x1.4d35f0p-2, 0x1.5574f4p-2, + 0x1.5da9f4p-2, 0x1.65d4b8p-2, 0x1.6df50ap-2, 0x1.760abap-2, + 0x1.7e1594p-2, 0x1.861566p-2, 0x1.8e0a02p-2, 0x1.95f336p-2, + 0x1.9dd0d2p-2, 0x1.a5a2acp-2, 0x1.ad6896p-2, 0x1.b52264p-2, + 0x1.bccfecp-2, 0x1.c47104p-2, 0x1.cc0584p-2, 0x1.d38d44p-2, + 0x1.db081cp-2, 0x1.e275eap-2, 0x1.e9d68ap-2, 0x1.f129d4p-2, + 0x1.f86faap-2, 0x1.ffa7eap-2, 0x1.03693ap-1, 0x1.06f794p-1, + 0x1.0a7ef6p-1, 0x1.0dff50p-1, 0x1.117894p-1, 0x1.14eab4p-1, + 0x1.1855a6p-1, 0x1.1bb95cp-1, 0x1.1f15ccp-1, 0x1.226ae8p-1, + 0x1.25b8a8p-1, 0x1.28ff02p-1, 0x1.2c3decp-1, 0x1.2f755cp-1, + 0x1.32a54cp-1, 0x1.35cdb4p-1, 0x1.38ee8ap-1, 0x1.3c07cap-1, + 0x1.3f196ep-1, 0x1.42236ep-1, 0x1.4525c8p-1, 0x1.482074p-1, + 0x1.4b1372p-1, 0x1.4dfebap-1, 0x1.50e24cp-1, 0x1.53be26p-1, + 0x1.569244p-1, 0x1.595ea6p-1, 0x1.5c2348p-1, 0x1.5ee02ep-1, + 0x1.619556p-1, 0x1.6442c0p-1, 0x1.66e86ep-1, 0x1.69865ep-1, + 0x1.6c1c98p-1, 0x1.6eab18p-1, 0x1.7131e6p-1, 0x1.73b102p-1, + 0x1.762870p-1, 0x1.789836p-1, 0x1.7b0058p-1, 0x1.7d60d8p-1, + 0x1.7fb9c0p-1, 0x1.820b12p-1, 0x1.8454d6p-1, 0x1.869712p-1, + 0x1.88d1cep-1, 0x1.8b050ep-1, 0x1.8d30dep-1, 0x1.8f5544p-1, + 0x1.91724ap-1, 0x1.9387f6p-1, 0x1.959652p-1, 0x1.979d68p-1, + 0x1.999d42p-1, 0x1.9b95e8p-1, 0x1.9d8768p-1, 0x1.9f71cap-1, + 0x1.a1551ap-1, 0x1.a33162p-1, 0x1.a506b0p-1, 0x1.a6d50cp-1, + 0x1.a89c86p-1, 0x1.aa5d26p-1, 0x1.ac16fcp-1, 0x1.adca14p-1, + 0x1.af767ap-1, 0x1.b11c3cp-1, 0x1.b2bb68p-1, 0x1.b4540ap-1, + 0x1.b5e630p-1, 0x1.b771e8p-1, 0x1.b8f742p-1, 0x1.ba764ap-1, + 0x1.bbef10p-1, 0x1.bd61a2p-1, 0x1.bece0ep-1, 0x1.c03464p-1, + 0x1.c194b2p-1, 0x1.c2ef08p-1, 0x1.c44376p-1, 0x1.c5920ap-1, + 0x1.c6dad2p-1, 0x1.c81de2p-1, 0x1.c95b46p-1, 0x1.ca930ep-1, + 0x1.cbc54cp-1, 0x1.ccf20cp-1, 0x1.ce1962p-1, 0x1.cf3b5cp-1, + 0x1.d0580cp-1, 0x1.d16f7ep-1, 0x1.d281c4p-1, 0x1.d38ef0p-1, + 0x1.d49710p-1, 0x1.d59a34p-1, 0x1.d6986cp-1, 0x1.d791cap-1, + 0x1.d8865ep-1, 0x1.d97636p-1, 0x1.da6162p-1, 0x1.db47f4p-1, + 0x1.dc29fcp-1, 0x1.dd0788p-1, 0x1.dde0aap-1, 0x1.deb570p-1, + 0x1.df85eap-1, 0x1.e0522ap-1, 0x1.e11a3ep-1, 0x1.e1de36p-1, + 0x1.e29e22p-1, 0x1.e35a12p-1, 0x1.e41214p-1, 0x1.e4c638p-1, + 0x1.e5768cp-1, 
0x1.e62322p-1, 0x1.e6cc08p-1, 0x1.e7714ap-1, + 0x1.e812fcp-1, 0x1.e8b12ap-1, 0x1.e94be4p-1, 0x1.e9e336p-1, + 0x1.ea7730p-1, 0x1.eb07e2p-1, 0x1.eb9558p-1, 0x1.ec1fa2p-1, + 0x1.eca6ccp-1, 0x1.ed2ae6p-1, 0x1.edabfcp-1, 0x1.ee2a1ep-1, + 0x1.eea556p-1, 0x1.ef1db4p-1, 0x1.ef9344p-1, 0x1.f00614p-1, + 0x1.f07630p-1, 0x1.f0e3a6p-1, 0x1.f14e82p-1, 0x1.f1b6d0p-1, + 0x1.f21ca0p-1, 0x1.f27ff8p-1, 0x1.f2e0eap-1, 0x1.f33f7ep-1, + 0x1.f39bc2p-1, 0x1.f3f5c2p-1, 0x1.f44d88p-1, 0x1.f4a31ep-1, + 0x1.f4f694p-1, 0x1.f547f2p-1, 0x1.f59742p-1, 0x1.f5e490p-1, + 0x1.f62fe8p-1, 0x1.f67952p-1, 0x1.f6c0dcp-1, 0x1.f7068cp-1, + 0x1.f74a6ep-1, 0x1.f78c8cp-1, 0x1.f7cceep-1, 0x1.f80ba2p-1, + 0x1.f848acp-1, 0x1.f8841ap-1, 0x1.f8bdf2p-1, 0x1.f8f63ep-1, + 0x1.f92d08p-1, 0x1.f96256p-1, 0x1.f99634p-1, 0x1.f9c8a8p-1, + 0x1.f9f9bap-1, 0x1.fa2974p-1, 0x1.fa57dep-1, 0x1.fa84fep-1, + 0x1.fab0dep-1, 0x1.fadb84p-1, 0x1.fb04f6p-1, 0x1.fb2d40p-1, + 0x1.fb5464p-1, 0x1.fb7a6cp-1, 0x1.fb9f60p-1, 0x1.fbc344p-1, + 0x1.fbe61ep-1, 0x1.fc07fap-1, 0x1.fc28d8p-1, 0x1.fc48c2p-1, + 0x1.fc67bcp-1, 0x1.fc85d0p-1, 0x1.fca2fep-1, 0x1.fcbf52p-1, + 0x1.fcdaccp-1, 0x1.fcf576p-1, 0x1.fd0f54p-1, 0x1.fd286ap-1, + 0x1.fd40bep-1, 0x1.fd5856p-1, 0x1.fd6f34p-1, 0x1.fd8562p-1, + 0x1.fd9ae2p-1, 0x1.fdafb8p-1, 0x1.fdc3e8p-1, 0x1.fdd77ap-1, + 0x1.fdea6ep-1, 0x1.fdfcccp-1, 0x1.fe0e96p-1, 0x1.fe1fd0p-1, + 0x1.fe3080p-1, 0x1.fe40a6p-1, 0x1.fe504cp-1, 0x1.fe5f70p-1, + 0x1.fe6e18p-1, 0x1.fe7c46p-1, 0x1.fe8a00p-1, 0x1.fe9748p-1, + 0x1.fea422p-1, 0x1.feb090p-1, 0x1.febc96p-1, 0x1.fec836p-1, + 0x1.fed374p-1, 0x1.fede52p-1, 0x1.fee8d4p-1, 0x1.fef2fep-1, + 0x1.fefccep-1, 0x1.ff064cp-1, 0x1.ff0f76p-1, 0x1.ff1852p-1, + 0x1.ff20e0p-1, 0x1.ff2924p-1, 0x1.ff3120p-1, 0x1.ff38d6p-1, + 0x1.ff4048p-1, 0x1.ff4778p-1, 0x1.ff4e68p-1, 0x1.ff551ap-1, + 0x1.ff5b90p-1, 0x1.ff61ccp-1, 0x1.ff67d0p-1, 0x1.ff6d9ep-1, + 0x1.ff7338p-1, 0x1.ff789ep-1, 0x1.ff7dd4p-1, 0x1.ff82dap-1, + 0x1.ff87b2p-1, 0x1.ff8c5cp-1, 0x1.ff90dcp-1, 0x1.ff9532p-1, + 0x1.ff9960p-1, 0x1.ff9d68p-1, 0x1.ffa14ap-1, 0x1.ffa506p-1, + 0x1.ffa8a0p-1, 0x1.ffac18p-1, 0x1.ffaf6ep-1, 0x1.ffb2a6p-1, + 0x1.ffb5bep-1, 0x1.ffb8b8p-1, 0x1.ffbb98p-1, 0x1.ffbe5ap-1, + 0x1.ffc102p-1, 0x1.ffc390p-1, 0x1.ffc606p-1, 0x1.ffc862p-1, + 0x1.ffcaa8p-1, 0x1.ffccd8p-1, 0x1.ffcef4p-1, 0x1.ffd0fap-1, + 0x1.ffd2eap-1, 0x1.ffd4cap-1, 0x1.ffd696p-1, 0x1.ffd84ep-1, + 0x1.ffd9f8p-1, 0x1.ffdb90p-1, 0x1.ffdd18p-1, 0x1.ffde90p-1, + 0x1.ffdffap-1, 0x1.ffe154p-1, 0x1.ffe2a2p-1, 0x1.ffe3e2p-1, + 0x1.ffe514p-1, 0x1.ffe63cp-1, 0x1.ffe756p-1, 0x1.ffe866p-1, + 0x1.ffe96ap-1, 0x1.ffea64p-1, 0x1.ffeb54p-1, 0x1.ffec3ap-1, + 0x1.ffed16p-1, 0x1.ffedeap-1, 0x1.ffeeb4p-1, 0x1.ffef76p-1, + 0x1.fff032p-1, 0x1.fff0e4p-1, 0x1.fff18ep-1, 0x1.fff232p-1, + 0x1.fff2d0p-1, 0x1.fff366p-1, 0x1.fff3f6p-1, 0x1.fff480p-1, + 0x1.fff504p-1, 0x1.fff582p-1, 0x1.fff5fcp-1, 0x1.fff670p-1, + 0x1.fff6dep-1, 0x1.fff74ap-1, 0x1.fff7aep-1, 0x1.fff810p-1, + 0x1.fff86cp-1, 0x1.fff8c6p-1, 0x1.fff91cp-1, 0x1.fff96cp-1, + 0x1.fff9bap-1, 0x1.fffa04p-1, 0x1.fffa4cp-1, 0x1.fffa90p-1, + 0x1.fffad0p-1, 0x1.fffb0ep-1, 0x1.fffb4ap-1, 0x1.fffb82p-1, + 0x1.fffbb8p-1, 0x1.fffbecp-1, 0x1.fffc1ep-1, 0x1.fffc4ep-1, + 0x1.fffc7ap-1, 0x1.fffca6p-1, 0x1.fffccep-1, 0x1.fffcf6p-1, + 0x1.fffd1ap-1, 0x1.fffd3ep-1, 0x1.fffd60p-1, 0x1.fffd80p-1, + 0x1.fffda0p-1, 0x1.fffdbep-1, 0x1.fffddap-1, 0x1.fffdf4p-1, + 0x1.fffe0ep-1, 0x1.fffe26p-1, 0x1.fffe3ep-1, 0x1.fffe54p-1, + 0x1.fffe68p-1, 0x1.fffe7ep-1, 0x1.fffe90p-1, 0x1.fffea2p-1, + 0x1.fffeb4p-1, 0x1.fffec4p-1, 0x1.fffed4p-1, 0x1.fffee4p-1, + 0x1.fffef2p-1, 0x1.ffff00p-1, 
0x1.ffff0cp-1, 0x1.ffff18p-1, + 0x1.ffff24p-1, 0x1.ffff30p-1, 0x1.ffff3ap-1, 0x1.ffff44p-1, + 0x1.ffff4ep-1, 0x1.ffff56p-1, 0x1.ffff60p-1, 0x1.ffff68p-1, + 0x1.ffff70p-1, 0x1.ffff78p-1, 0x1.ffff7ep-1, 0x1.ffff84p-1, + 0x1.ffff8cp-1, 0x1.ffff92p-1, 0x1.ffff98p-1, 0x1.ffff9cp-1, + 0x1.ffffa2p-1, 0x1.ffffa6p-1, 0x1.ffffacp-1, 0x1.ffffb0p-1, + 0x1.ffffb4p-1, 0x1.ffffb8p-1, 0x1.ffffbcp-1, 0x1.ffffc0p-1, + 0x1.ffffc4p-1, 0x1.ffffc6p-1, 0x1.ffffcap-1, 0x1.ffffccp-1, + 0x1.ffffd0p-1, 0x1.ffffd2p-1, 0x1.ffffd4p-1, 0x1.ffffd6p-1, + 0x1.ffffd8p-1, 0x1.ffffdcp-1, 0x1.ffffdep-1, 0x1.ffffdep-1, + 0x1.ffffe0p-1, 0x1.ffffe2p-1, 0x1.ffffe4p-1, 0x1.ffffe6p-1, + 0x1.ffffe8p-1, 0x1.ffffe8p-1, 0x1.ffffeap-1, 0x1.ffffeap-1, + 0x1.ffffecp-1, 0x1.ffffeep-1, 0x1.ffffeep-1, 0x1.fffff0p-1, + 0x1.fffff0p-1, 0x1.fffff2p-1, 0x1.fffff2p-1, 0x1.fffff2p-1, + 0x1.fffff4p-1, 0x1.fffff4p-1, 0x1.fffff4p-1, 0x1.fffff6p-1, + 0x1.fffff6p-1, 0x1.fffff6p-1, 0x1.fffff8p-1, 0x1.fffff8p-1, + 0x1.fffff8p-1, 0x1.fffff8p-1, 0x1.fffffap-1, 0x1.fffffap-1, + 0x1.fffffap-1, 0x1.fffffap-1, 0x1.fffffap-1, 0x1.fffffap-1, + 0x1.fffffcp-1, 0x1.fffffcp-1, 0x1.fffffcp-1, 0x1.fffffcp-1, + 0x1.fffffcp-1, 0x1.fffffcp-1, 0x1.fffffcp-1, 0x1.fffffcp-1, + 0x1.fffffep-1, 0x1.fffffep-1, 0x1.fffffep-1, 0x1.fffffep-1, + 0x1.fffffep-1, 0x1.fffffep-1, 0x1.fffffep-1, 0x1.fffffep-1, + 0x1.fffffep-1, 0x1.fffffep-1, 0x1.fffffep-1, 0x1.fffffep-1, + 0x1.fffffep-1, 0x1.fffffep-1, 0x1.fffffep-1, 0x1.fffffep-1, + 0x1.fffffep-1, 0x1.fffffep-1, 0x1.000000p+0, 0x1.000000p+0, + 0x1.000000p+0, 0x1.000000p+0, 0x1.000000p+0, 0x1.000000p+0, + 0x1.000000p+0, 0x1.000000p+0, 0x1.000000p+0, 0x1.000000p+0, + 0x1.000000p+0, + }, + .scale = + { + 0x1.20dd76p+0, 0x1.20dd76p+0, 0x1.20cb68p+0, 0x1.20b4d8p+0, + 0x1.209546p+0, 0x1.206cb4p+0, 0x1.203b26p+0, 0x1.2000a0p+0, + 0x1.1fbd28p+0, 0x1.1f70c4p+0, 0x1.1f1b7ap+0, 0x1.1ebd56p+0, + 0x1.1e565cp+0, 0x1.1de698p+0, 0x1.1d6e14p+0, 0x1.1cecdcp+0, + 0x1.1c62fap+0, 0x1.1bd07cp+0, 0x1.1b3572p+0, 0x1.1a91e6p+0, + 0x1.19e5eap+0, 0x1.19318cp+0, 0x1.1874dep+0, 0x1.17aff0p+0, + 0x1.16e2d8p+0, 0x1.160da4p+0, 0x1.153068p+0, 0x1.144b3cp+0, + 0x1.135e30p+0, 0x1.12695ep+0, 0x1.116cd8p+0, 0x1.1068bap+0, + 0x1.0f5d16p+0, 0x1.0e4a08p+0, 0x1.0d2fa6p+0, 0x1.0c0e0ap+0, + 0x1.0ae550p+0, 0x1.09b590p+0, 0x1.087ee4p+0, 0x1.07416cp+0, + 0x1.05fd3ep+0, 0x1.04b27cp+0, 0x1.036140p+0, 0x1.0209a6p+0, + 0x1.00abd0p+0, 0x1.fe8fb0p-1, 0x1.fbbbbep-1, 0x1.f8dc0ap-1, + 0x1.f5f0cep-1, 0x1.f2fa4cp-1, 0x1.eff8c4p-1, 0x1.ecec78p-1, + 0x1.e9d5a8p-1, 0x1.e6b498p-1, 0x1.e38988p-1, 0x1.e054bep-1, + 0x1.dd167cp-1, 0x1.d9cf06p-1, 0x1.d67ea2p-1, 0x1.d32592p-1, + 0x1.cfc41ep-1, 0x1.cc5a8ap-1, 0x1.c8e91cp-1, 0x1.c5701ap-1, + 0x1.c1efcap-1, 0x1.be6872p-1, 0x1.bada5ap-1, 0x1.b745c6p-1, + 0x1.b3aafcp-1, 0x1.b00a46p-1, 0x1.ac63e8p-1, 0x1.a8b828p-1, + 0x1.a5074ep-1, 0x1.a1519ep-1, 0x1.9d9762p-1, 0x1.99d8dap-1, + 0x1.961650p-1, 0x1.925008p-1, 0x1.8e8646p-1, 0x1.8ab950p-1, + 0x1.86e96ap-1, 0x1.8316d6p-1, 0x1.7f41dcp-1, 0x1.7b6abcp-1, + 0x1.7791b8p-1, 0x1.73b714p-1, 0x1.6fdb12p-1, 0x1.6bfdf0p-1, + 0x1.681ff2p-1, 0x1.644156p-1, 0x1.60625cp-1, 0x1.5c8342p-1, + 0x1.58a446p-1, 0x1.54c5a6p-1, 0x1.50e79ep-1, 0x1.4d0a68p-1, + 0x1.492e42p-1, 0x1.455366p-1, 0x1.417a0cp-1, 0x1.3da26ep-1, + 0x1.39ccc2p-1, 0x1.35f940p-1, 0x1.32281ep-1, 0x1.2e5992p-1, + 0x1.2a8dcep-1, 0x1.26c508p-1, 0x1.22ff72p-1, 0x1.1f3d3cp-1, + 0x1.1b7e98p-1, 0x1.17c3b6p-1, 0x1.140cc4p-1, 0x1.1059eep-1, + 0x1.0cab62p-1, 0x1.09014cp-1, 0x1.055bd6p-1, 0x1.01bb2cp-1, + 0x1.fc3ee6p-2, 0x1.f511aap-2, 0x1.edeeeep-2, 0x1.e6d700p-2, + 
0x1.dfca26p-2, 0x1.d8c8aap-2, 0x1.d1d2d0p-2, 0x1.cae8dap-2, + 0x1.c40b08p-2, 0x1.bd3998p-2, 0x1.b674c8p-2, 0x1.afbcd4p-2, + 0x1.a911f0p-2, 0x1.a27456p-2, 0x1.9be438p-2, 0x1.9561c8p-2, + 0x1.8eed36p-2, 0x1.8886b2p-2, 0x1.822e66p-2, 0x1.7be47ap-2, + 0x1.75a91ap-2, 0x1.6f7c6ap-2, 0x1.695e8cp-2, 0x1.634fa6p-2, + 0x1.5d4fd4p-2, 0x1.575f34p-2, 0x1.517de6p-2, 0x1.4bac00p-2, + 0x1.45e99cp-2, 0x1.4036d0p-2, 0x1.3a93b2p-2, 0x1.350052p-2, + 0x1.2f7cc4p-2, 0x1.2a0916p-2, 0x1.24a554p-2, 0x1.1f518ap-2, + 0x1.1a0dc6p-2, 0x1.14da0ap-2, 0x1.0fb662p-2, 0x1.0aa2d0p-2, + 0x1.059f5ap-2, 0x1.00ac00p-2, 0x1.f79184p-3, 0x1.edeb40p-3, + 0x1.e46530p-3, 0x1.daff4ap-3, 0x1.d1b982p-3, 0x1.c893cep-3, + 0x1.bf8e1cp-3, 0x1.b6a856p-3, 0x1.ade26cp-3, 0x1.a53c42p-3, + 0x1.9cb5bep-3, 0x1.944ec2p-3, 0x1.8c0732p-3, 0x1.83deeap-3, + 0x1.7bd5c8p-3, 0x1.73eba4p-3, 0x1.6c2056p-3, 0x1.6473b6p-3, + 0x1.5ce596p-3, 0x1.5575c8p-3, 0x1.4e241ep-3, 0x1.46f066p-3, + 0x1.3fda6cp-3, 0x1.38e1fap-3, 0x1.3206dcp-3, 0x1.2b48dap-3, + 0x1.24a7b8p-3, 0x1.1e233ep-3, 0x1.17bb2cp-3, 0x1.116f48p-3, + 0x1.0b3f52p-3, 0x1.052b0cp-3, 0x1.fe6460p-4, 0x1.f2a902p-4, + 0x1.e72372p-4, 0x1.dbd32ap-4, 0x1.d0b7a0p-4, 0x1.c5d04ap-4, + 0x1.bb1c98p-4, 0x1.b09bfcp-4, 0x1.a64de6p-4, 0x1.9c31c6p-4, + 0x1.92470ap-4, 0x1.888d1ep-4, 0x1.7f036cp-4, 0x1.75a960p-4, + 0x1.6c7e64p-4, 0x1.6381e2p-4, 0x1.5ab342p-4, 0x1.5211ecp-4, + 0x1.499d48p-4, 0x1.4154bcp-4, 0x1.3937b2p-4, 0x1.31458ep-4, + 0x1.297dbap-4, 0x1.21df9ap-4, 0x1.1a6a96p-4, 0x1.131e14p-4, + 0x1.0bf97ep-4, 0x1.04fc3ap-4, 0x1.fc4b5ep-5, 0x1.eeea8cp-5, + 0x1.e1d4d0p-5, 0x1.d508fap-5, 0x1.c885e0p-5, 0x1.bc4a54p-5, + 0x1.b05530p-5, 0x1.a4a54ap-5, 0x1.99397ap-5, 0x1.8e109cp-5, + 0x1.83298ep-5, 0x1.78832cp-5, 0x1.6e1c58p-5, 0x1.63f3f6p-5, + 0x1.5a08e8p-5, 0x1.505a18p-5, 0x1.46e66cp-5, 0x1.3dacd2p-5, + 0x1.34ac36p-5, 0x1.2be38cp-5, 0x1.2351c2p-5, 0x1.1af5d2p-5, + 0x1.12ceb4p-5, 0x1.0adb60p-5, 0x1.031ad6p-5, 0x1.f7182ap-6, + 0x1.e85c44p-6, 0x1.da0006p-6, 0x1.cc0180p-6, 0x1.be5ecep-6, + 0x1.b1160ap-6, 0x1.a4255ap-6, 0x1.978ae8p-6, 0x1.8b44e6p-6, + 0x1.7f5188p-6, 0x1.73af0cp-6, 0x1.685bb6p-6, 0x1.5d55ccp-6, + 0x1.529b9ep-6, 0x1.482b84p-6, 0x1.3e03d8p-6, 0x1.3422fep-6, + 0x1.2a875cp-6, 0x1.212f62p-6, 0x1.181984p-6, 0x1.0f443ep-6, + 0x1.06ae14p-6, 0x1.fcab14p-7, 0x1.ec7262p-7, 0x1.dcaf36p-7, + 0x1.cd5ecap-7, 0x1.be7e5ap-7, 0x1.b00b38p-7, 0x1.a202bep-7, + 0x1.94624ep-7, 0x1.87275ep-7, 0x1.7a4f6ap-7, 0x1.6dd7fep-7, + 0x1.61beaep-7, 0x1.56011cp-7, 0x1.4a9cf6p-7, 0x1.3f8ff6p-7, + 0x1.34d7dcp-7, 0x1.2a727ap-7, 0x1.205dacp-7, 0x1.169756p-7, + 0x1.0d1d6ap-7, 0x1.03ede2p-7, 0x1.f60d8ap-8, 0x1.e4cc4ap-8, + 0x1.d4143ap-8, 0x1.c3e1a6p-8, 0x1.b430ecp-8, 0x1.a4fe84p-8, + 0x1.9646f4p-8, 0x1.8806d8p-8, 0x1.7a3adep-8, 0x1.6cdfccp-8, + 0x1.5ff276p-8, 0x1.536fc2p-8, 0x1.4754acp-8, 0x1.3b9e40p-8, + 0x1.30499cp-8, 0x1.2553eep-8, 0x1.1aba78p-8, 0x1.107a8cp-8, + 0x1.06918cp-8, 0x1.f9f9d0p-9, 0x1.e77448p-9, 0x1.d58da6p-9, + 0x1.c4412cp-9, 0x1.b38a3ap-9, 0x1.a36454p-9, 0x1.93cb12p-9, + 0x1.84ba30p-9, 0x1.762d84p-9, 0x1.682100p-9, 0x1.5a90b0p-9, + 0x1.4d78bcp-9, 0x1.40d564p-9, 0x1.34a306p-9, 0x1.28de12p-9, + 0x1.1d8318p-9, 0x1.128ebap-9, 0x1.07fdb4p-9, 0x1.fb99b8p-10, + 0x1.e7f232p-10, 0x1.d4fed8p-10, 0x1.c2b9d0p-10, 0x1.b11d70p-10, + 0x1.a02436p-10, 0x1.8fc8c8p-10, 0x1.8005f0p-10, 0x1.70d6a4p-10, + 0x1.6235fcp-10, 0x1.541f34p-10, 0x1.468daep-10, 0x1.397ceep-10, + 0x1.2ce898p-10, 0x1.20cc76p-10, 0x1.15246ep-10, 0x1.09ec86p-10, + 0x1.fe41cep-11, 0x1.e97ba4p-11, 0x1.d57f52p-11, 0x1.c245d4p-11, + 0x1.afc85ep-11, 0x1.9e0058p-11, 0x1.8ce75ep-11, 
0x1.7c7744p-11, + 0x1.6caa0ep-11, 0x1.5d79ecp-11, 0x1.4ee142p-11, 0x1.40daa4p-11, + 0x1.3360ccp-11, 0x1.266ea8p-11, 0x1.19ff46p-11, 0x1.0e0de8p-11, + 0x1.0295f0p-11, 0x1.ef25d4p-12, 0x1.da0110p-12, 0x1.c5b542p-12, + 0x1.b23a5ap-12, 0x1.9f8894p-12, 0x1.8d986ap-12, 0x1.7c629ap-12, + 0x1.6be022p-12, 0x1.5c0a38p-12, 0x1.4cda54p-12, 0x1.3e4a24p-12, + 0x1.305390p-12, 0x1.22f0b4p-12, 0x1.161be4p-12, 0x1.09cfa4p-12, + 0x1.fc0d56p-13, 0x1.e577bcp-13, 0x1.cfd4a6p-13, 0x1.bb1a96p-13, + 0x1.a74068p-13, 0x1.943d4ap-13, 0x1.8208bcp-13, 0x1.709a8ep-13, + 0x1.5feadap-13, 0x1.4ff208p-13, 0x1.40a8c2p-13, 0x1.3207fcp-13, + 0x1.2408eap-13, 0x1.16a502p-13, 0x1.09d5f8p-13, 0x1.fb2b7ap-14, + 0x1.e3bcf4p-14, 0x1.cd5528p-14, 0x1.b7e946p-14, 0x1.a36eecp-14, + 0x1.8fdc1cp-14, 0x1.7d2738p-14, 0x1.6b4702p-14, 0x1.5a329cp-14, + 0x1.49e178p-14, 0x1.3a4b60p-14, 0x1.2b6876p-14, 0x1.1d3120p-14, + 0x1.0f9e1cp-14, 0x1.02a868p-14, 0x1.ec929ap-15, 0x1.d4f4b4p-15, + 0x1.be6abcp-15, 0x1.a8e8ccp-15, 0x1.94637ep-15, 0x1.80cfdcp-15, + 0x1.6e2368p-15, 0x1.5c540cp-15, 0x1.4b581cp-15, 0x1.3b2652p-15, + 0x1.2bb5ccp-15, 0x1.1cfe02p-15, 0x1.0ef6c4p-15, 0x1.019842p-15, + 0x1.e9b5e8p-16, 0x1.d16f58p-16, 0x1.ba4f04p-16, 0x1.a447b8p-16, + 0x1.8f4cccp-16, 0x1.7b5224p-16, 0x1.684c22p-16, 0x1.562facp-16, + 0x1.44f21ep-16, 0x1.34894ap-16, 0x1.24eb72p-16, 0x1.160f44p-16, + 0x1.07ebd2p-16, 0x1.f4f12ep-17, 0x1.db5ad0p-17, 0x1.c304f0p-17, + 0x1.abe09ep-17, 0x1.95df98p-17, 0x1.80f43ap-17, 0x1.6d1178p-17, + 0x1.5a2ae0p-17, 0x1.483488p-17, 0x1.372310p-17, 0x1.26eb9ep-17, + 0x1.1783cep-17, 0x1.08e1bap-17, 0x1.f5f7d8p-18, 0x1.db92b6p-18, + 0x1.c282cep-18, 0x1.aab7acp-18, 0x1.94219cp-18, 0x1.7eb1a2p-18, + 0x1.6a5972p-18, 0x1.570b6ap-18, 0x1.44ba86p-18, 0x1.335a62p-18, + 0x1.22df2ap-18, 0x1.133d96p-18, 0x1.046aeap-18, 0x1.ecb9d0p-19, + 0x1.d21398p-19, 0x1.b8d094p-19, 0x1.a0df10p-19, 0x1.8a2e26p-19, + 0x1.74adc8p-19, 0x1.604ea8p-19, 0x1.4d0232p-19, 0x1.3aba86p-19, + 0x1.296a70p-19, 0x1.190562p-19, 0x1.097f62p-19, 0x1.f59a20p-20, + 0x1.d9c736p-20, 0x1.bf716cp-20, 0x1.a6852cp-20, 0x1.8eefd8p-20, + 0x1.789fb8p-20, 0x1.6383f8p-20, 0x1.4f8c96p-20, 0x1.3caa62p-20, + 0x1.2acee2p-20, 0x1.19ec60p-20, 0x1.09f5d0p-20, 0x1.f5bd96p-21, + 0x1.d9371ep-21, 0x1.be41dep-21, 0x1.a4c89ep-21, 0x1.8cb738p-21, + 0x1.75fa8ep-21, 0x1.608078p-21, 0x1.4c37c0p-21, 0x1.39100ep-21, + 0x1.26f9e0p-21, 0x1.15e682p-21, 0x1.05c804p-21, 0x1.ed2254p-22, + 0x1.d06ad6p-22, 0x1.b551c8p-22, 0x1.9bc0a0p-22, 0x1.83a200p-22, + 0x1.6ce1aap-22, 0x1.576c72p-22, 0x1.43302cp-22, 0x1.301ba2p-22, + 0x1.1e1e86p-22, 0x1.0d2966p-22, 0x1.fa5b50p-23, 0x1.dc3ae4p-23, + 0x1.bfd756p-23, 0x1.a517dap-23, 0x1.8be4f8p-23, 0x1.74287ep-23, + 0x1.5dcd66p-23, 0x1.48bfd4p-23, 0x1.34ecf8p-23, 0x1.224310p-23, + 0x1.10b148p-23, + }, +}; + +static inline __m256 erf_ps(__m256 x) { + __m256i zero = _mm256_setzero_si256(); + __m256 a = _mm256_andnot_ps(_mm256_set1_ps(-0.0f), x); + /* |x| > 1/64 - 1/512. 
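+       Above that threshold the code below appears to follow the table-driven
+       scheme of ARM optimized-routines' erff (note the __sv_erff_data name):
+       i = round(128 * |x|), clamped to 512; r = i / 128; d = |x| - r; then
+       erf(|x|) ~ erf(r) + scale(r) * d * (1 - r * d - d^2 / 3), where the
+       scale table appears to hold erf'(r) = 2/sqrt(pi) * exp(-r^2).
+       Worked example: x = 0.5 gives i = 64, r = 0.5, d = 0, so the result is
+       exactly the table entry erf(64/128) ~ 0.5205. Below the threshold the
+       index stays 0, so the same polynomial reduces to the Taylor form
+       2/sqrt(pi) * x * (1 - x^2 / 3); for |x| >= 3.9375 the result saturates
+       to 1.0f. The sign of x is reattached at the end.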
*/ + __m256 gt_min_mask = + _mm256_cmp_ps(a, _mm256_set1_ps(0x1.cp-7f), _CMP_GT_OS); + int gt_min_mask_as_int; + std::memcpy(&gt_min_mask_as_int, &gt_min_mask, 1); + __m256 tmp_i = _mm256_mul_ps(a, _mm256_set1_ps(128.f)); + + __m256i signed_i = _mm256_cvttps_epi32( + _mm256_round_ps(tmp_i, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); + const __m256i mask = _mm256_set1_epi32(0xFFFFFFFF); + __m256i i = _mm256_and_si256(signed_i, mask); + i = _mm256_blendv_epi8(zero, i, _mm256_castps_si256(gt_min_mask)); + i = _mm256_min_epu32(i, _mm256_set1_epi32(512)); + __m256 tmp_r = _mm256_cvtepi32_ps(i); + i = _mm256_mullo_epi32(i, _mm256_set1_epi32(1)); + __m256 r = _mm256_mul_ps(tmp_r, _mm256_set1_ps(1.f / 128)); + __m256 erfr = _mm256_i32gather_ps(__sv_erff_data.erf, i, sizeof(float)); + __m256 scale = _mm256_i32gather_ps(__sv_erff_data.scale, i, sizeof(float)); + __m256 ge_max_mask = _mm256_cmp_ps(a, _mm256_set1_ps(3.9375f), _CMP_GE_OS); + std::memcpy(&gt_min_mask_as_int, &ge_max_mask, 1); + /* erf(x) ~ erf(r) + scale * d * (1 - r * d - 1/3 * d^2). */ + __m256 d = _mm256_sub_ps(a, r); + __m256 d2 = _mm256_mul_ps(d, d); + __m256 y = _mm256_fmadd_ps(d, _mm256_set1_ps(0x1.555556p-2f), r); + y = _mm256_fnmadd_ps(y, d2, d); + y = _mm256_fmadd_ps(y, scale, erfr); + y = _mm256_blendv_ps(y, _mm256_set1_ps(1.f), ge_max_mask); + + __m256 sign_mask = _mm256_castsi256_ps(_mm256_set1_epi32(0x80000000)); + __m256 x_sign = _mm256_and_ps(x, sign_mask); + __m256 y_abs = _mm256_andnot_ps(sign_mask, y); + y = _mm256_or_ps(y_abs, x_sign); + + return y; +} + static inline __m256 atan2256_ps(__m256 y, __m256 x) { // Reference: https://mazzo.li/posts/vectorized-atan2.html diff --git a/src/Native/include/nncase/ntt/arch/x86_64/primitive_ops.h b/src/Native/include/nncase/ntt/arch/x86_64/primitive_ops.h index 26c553e989..8397a8b14f 100644 --- a/src/Native/include/nncase/ntt/arch/x86_64/primitive_ops.h +++ b/src/Native/include/nncase/ntt/arch/x86_64/primitive_ops.h @@ -27,7 +27,7 @@ namespace nncase::ntt::ops { template <> struct abs<ntt::vector<float, 8>> { ntt::vector<float, 8> operator()(const ntt::vector<float, 8> &v) const noexcept { - return abs256_ps(v); + return _mm256_andnot_ps(_mm256_set1_ps(-0.0f), v); } }; @@ -239,6 +239,14 @@ template <> struct cosh<ntt::vector<float, 8>> { } }; +// erf +template <> struct erf<ntt::vector<float, 8>> { + ntt::vector<float, 8> + operator()(const ntt::vector<float, 8> &v) const noexcept { + return erf_ps(v); + } +}; + // exp template <> struct exp<ntt::vector<float, 8>> { ntt::vector<float, 8> @@ -450,10 +458,121 @@ template <> struct swish<ntt::vector<float, 8>> { }; // tanh +#define LOG2_INV 0x1.71547652b82fep+0 +#define LOG2_HI 0x1.62e42fefa39efp-1 +#define LOG2_LO 0x1.abc9e3b39803fp-56 template <> struct tanh<ntt::vector<float, 8>> { ntt::vector<float, 8> operator()(const ntt::vector<float, 8> &v) const noexcept { +#if 0 return tanh256_ps(v); +#else + constexpr float fp_posZero = 0.0f; + constexpr float fp_posOne = 1.f; + + __m256 zero = _mm256_set1_ps(0.0f); + + // Create a vector with all elements set to 1.0f + __m256 one = _mm256_set1_ps(1.0f); + + __m256 vx = _mm256_and_ps( + v, _mm256_castsi256_ps(_mm256_set1_epi32(0x7FFFFFFF))); // vfsgnj_vf + vx = _mm256_min_ps(vx, _mm256_set1_ps(20.0f)); // vfmin_vf + vx = _mm256_mul_ps(vx, _mm256_set1_ps(-2.0f)); // vfmul_vf + __m256 n_flt = _mm256_mul_ps(vx, _mm256_set1_ps(LOG2_INV)); // vfmul_vf + __m256i n = _mm256_cvtps_epi32(n_flt); // vfcvt_x_f_v + n_flt = _mm256_cvtepi32_ps(n); // vfcvt_f_x_v + __m256i u = _mm256_add_epi32(n, _mm256_set1_epi32(127)); // vadd_vx + + __m256 r_delta = + _mm256_fnmadd_ps(_mm256_set1_ps(LOG2_HI), n_flt, vx); // vfnmsac_vf + u = _mm256_slli_epi32(u, 23); // vsll_vx + __m256 r = _mm256_fnmadd_ps(_mm256_set1_ps(LOG2_LO), n_flt, + r_delta); // vfnmsac_vf
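+       /* Sketch of the scheme the surrounding code implements (the trailing
+          comments preserve the RVV instruction names of the original): with
+          vx = -2 * min(|x|, 20), n = round(vx / ln2) and r = vx - n * ln2
+          split across LOG2_HI/LOG2_LO, we get e^vx = 2^n * e^r, where
+          s = 2^n is built by shifting the biased exponent. B plus the
+          even/odd Horner polynomials approximate e^r - 1, so the head/tail
+          sums below assemble D = 1 + s * e^r and Numer = 1 - s * e^r,
+          giving tanh(x) = sign(x) * Numer / D; E and e form a compensated
+          reciprocal of D with one correction step. */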
__m256 s = _mm256_castsi256_ps(u); // vreinterpret_v_i32m_f32m + __m256i s_is_small = + _mm256_cmpgt_epi32(_mm256_set1_epi32(-(23 + 1)), n); // vmsle_vx + + r_delta = _mm256_sub_ps(r_delta, r); // vfsub_vv + __m256 s_head = + _mm256_blendv_ps(s, _mm256_set1_ps(fp_posZero), + _mm256_castsi256_ps(s_is_small)); // vfmerge_vfm + r_delta = _mm256_fnmadd_ps(_mm256_set1_ps(LOG2_LO), n_flt, + r_delta); // vfnmsac_vf + + __m256 rsq = _mm256_mul_ps(r, r); // vfmul_vv + __m256 s_tail = _mm256_blendv_ps( + zero, s, _mm256_castsi256_ps(s_is_small)); // vmerge_vvm + + __m256 rcube = _mm256_mul_ps(rsq, r); // vfmul_vv + __m256 c0 = _mm256_set1_ps(0x1.71ddef82f4beep-19f); // vfmv_v_f + __m256 c1 = _mm256_set1_ps(0x1.a01a01b32b633p-13f); // vfmv_v_f + __m256 c2 = _mm256_set1_ps(0x1.111111110ef6ap-7f); // vfmv_v_f + __m256 c3 = _mm256_set1_ps(0x1.555555555555ap-3f); // vfmv_v_f + __m256 c4 = _mm256_set1_ps(0x1.a019b37a2b3dfp-16f); // vfmv_v_f + __m256 c5 = _mm256_set1_ps(0x1.6c16c17a09506p-10f); // vfmv_v_f + __m256 c6 = _mm256_set1_ps(0x1.5555555553aefp-5f); // vfmv_v_f + + __m256 p_even = rsq; // vmv_v_v (plain copy) + p_even = _mm256_fmadd_ps(p_even, _mm256_set1_ps(0x1.af6eacd796f0bp-26f), + c0); // vfmadd_vf + p_even = _mm256_fmadd_ps(p_even, rsq, c1); // vfmadd_vv + p_even = _mm256_fmadd_ps(p_even, rsq, c2); // vfmadd_vv + p_even = _mm256_fmadd_ps(p_even, rsq, c3); // vfmadd_vv + + __m256 p_odd = rsq; // vmv_v_v (plain copy) + p_odd = _mm256_fmadd_ps(p_odd, _mm256_set1_ps(0x1.289788d8bdadfp-22f), + c4); // vfmadd_vf + p_odd = _mm256_fmadd_ps(p_odd, rsq, c5); // vfmadd_vv + p_odd = _mm256_fmadd_ps(p_odd, rsq, c6); // vfmadd_vv + __m256 poly = _mm256_fmadd_ps(p_odd, r, p_even); // vfmadd_vv + + __m256 r_prime = _mm256_mul_ps(r, _mm256_set1_ps(0.5f)); // vfmul_vf + __m256 B = _mm256_fmadd_ps(r, r_prime, r); // vfmadd_vv + __m256 b = _mm256_sub_ps(r, B); // vfsub_vv + b = _mm256_fmadd_ps(r, r_prime, b); // vfmacc_vv (vfmadd_ps in AVX2) + __m256 c = _mm256_fmadd_ps(r, r_delta, r_delta); // vfmadd_vv + b = _mm256_add_ps(b, c); // vfadd_vv + poly = _mm256_fmadd_ps(poly, rcube, b); // vfmadd_vv + + __m256 Z = _mm256_add_ps(s_head, _mm256_set1_ps(fp_posOne)); // vfadd_vf + __m256 D_tmp = _mm256_fmadd_ps(B, s, Z); // vfmadd_vv + __m256 d_tmp = _mm256_sub_ps(Z, D_tmp); // vfsub_vv + d_tmp = _mm256_fmadd_ps(s, B, d_tmp); // vfmacc_vv (vfmadd_ps in AVX2) + d_tmp = _mm256_add_ps(d_tmp, s_tail); // vfadd_vv + d_tmp = + _mm256_fmadd_ps(s, poly, d_tmp); // vfmacc_vv (vfmadd_ps in AVX2) + + __m256 D = _mm256_add_ps(D_tmp, d_tmp); // vfadd_vv + __m256 d = _mm256_sub_ps(D_tmp, D); // vfsub_vv + d = _mm256_add_ps(d, d_tmp); // vfadd_vv + __m256 E = _mm256_div_ps(_mm256_set1_ps(fp_posOne), D); // vfrdiv_vf + __m256 e = _mm256_fnmadd_ps(E, D, one); // vfnmsub_vv + e = _mm256_fnmadd_ps(E, d, e); // vfnmsac_vv + e = _mm256_mul_ps(e, _mm256_rcp_ps(D)); // vfmul_vv with vfrec7_v + + Z = _mm256_sub_ps(_mm256_set1_ps(fp_posOne), s_head); // vfrsub_vf + __m256 Numer = _mm256_fnmadd_ps(B, s, Z); // vfnmsub_vv + __m256 numer = _mm256_sub_ps(Z, Numer); // vfsub_vv + numer = _mm256_fnmadd_ps(s, B, numer); // vfnmsac_vv + numer = _mm256_sub_ps(numer, s_tail); // vfsub_vv + numer = _mm256_fnmadd_ps(s, poly, numer); // vfnmsac_vv + + __m256 vy = _mm256_mul_ps(e, numer); // vfmul_vv + vy = _mm256_fmadd_ps(Numer, e, vy); // vfmacc_vv + vy = _mm256_fmadd_ps(numer, E, vy); // vfmacc_vv + vy = _mm256_fmadd_ps(Numer, E, vy); // vfmacc_vv + __m256 sign_v = +
_mm256_and_ps(v, _mm256_set1_ps(-0.0f)); // Extract sign from `v` + __m256 magnitude_vy = _mm256_andnot_ps( + _mm256_set1_ps(-0.0f), vy); // Extract magnitude from `vy` + __m256 result = _mm256_or_ps(magnitude_vy, sign_v); + return result; + +#endif } }; diff --git a/src/Native/include/nncase/ntt/arch/x86_64/ukernels.h b/src/Native/include/nncase/ntt/arch/x86_64/ukernels.h index ad04bb0606..f49ae7be9a 100644 --- a/src/Native/include/nncase/ntt/arch/x86_64/ukernels.h +++ b/src/Native/include/nncase/ntt/arch/x86_64/ukernels.h @@ -14,9 +14,7 @@ */ #pragma once #include "../../ukernels.h" -#include "arch_types.h" #include "nncase/ntt/vector.h" -#include namespace nncase::ntt::ukernels { template diff --git a/src/Native/include/nncase/ntt/kernels/matmul.h b/src/Native/include/nncase/ntt/kernels/matmul.h index 7da9908890..aadf3a8c81 100644 --- a/src/Native/include/nncase/ntt/kernels/matmul.h +++ b/src/Native/include/nncase/ntt/kernels/matmul.h @@ -166,234 +166,12 @@ class matmul_impl{}); - - // 1. pack M & N - if constexpr (pack_kind == ukernels::mamtul_pack_kind::pack_mn && - m0_subtile) { - using TSubOutElem = ntt::vector; - TSubOutElem c0_tmp[m0_subtile][N0Tile]; - - for (size_t sm1 = 0; sm1 < TOutElem::shape()[0]; - sm1 += m0_subtile) { - ntt::apply(fixed_shape{}, [&](auto index) { - c0_tmp[index[0]][index[1]] = - AccumulateC ? c0(0, index[1])(sm1 + index[0]) - : TSubOutElem{}; - }); - - for (size_t k1 = 0; k1 < K; k1++) { - outer_product(a, b, c0_tmp, m1, k1, n1, - sm1); - } - - ntt::apply(fixed_shape{}, [&](auto index) { - c0(0, index[1])(sm1 + index[0]) = - c0_tmp[index[0]][index[1]]; - }); - } - } - // 2. pack K & KN - else if constexpr (pack_kind == ukernels::mamtul_pack_kind::pack_kn) { - using TLhsElem = std::remove_const_t; - - TOutElem c0_tmp[M0Tile][N0Tile]; - ntt::apply(c0.shape(), [&](auto index) { - c0_tmp[index[0]][index[1]] = - AccumulateC ? c0(index) : TOutElem{}; - }); - - for (size_t k1 = 0; k1 < K; k1++) { - for (size_t sk1 = 0; sk1 < TLhsElem::shape()[0]; sk1++) { - outer_product(a, b, c0_tmp, m1, k1, n1, 0, - sk1); - } - } - - ntt::apply(c0.shape(), [&](auto index) { - c0(index) = c0_tmp[index[0]][index[1]]; - }); - } - // 3. pack MK & KN - else if constexpr (pack_kind == ukernels::mamtul_pack_kind::pack_mkn && - m0_subtile) { - using TLhsElem = std::remove_const_t; - using TSubOutElem = ntt::vector; - - TSubOutElem c0_tmp[m0_subtile][N0Tile]; - - for (size_t sm1 = 0; sm1 < TOutElem::shape()[0]; - sm1 += m0_subtile) { - ntt::apply(fixed_shape{}, [&](auto index) { - c0_tmp[index[0]][index[1]] = - AccumulateC ? c0(0, index[1])(sm1 + index[0]) - : TSubOutElem{}; - }); - - for (size_t k1 = 0; k1 < K; k1++) { - for (size_t sk1 = 0; sk1 < TLhsElem::shape()[0]; sk1++) { - outer_product(a, b, c0_tmp, m1, k1, n1, - sm1, sk1); - } - } - - ntt::apply(fixed_shape{}, [&](auto index) { - c0(0, index[1])(sm1 + index[0]) = - c0_tmp[index[0]][index[1]]; - }); - } - } - // Other packs - else { - TOutElem c0_tmp[M0Tile][N0Tile]; - ntt::apply(c0.shape(), [&](auto index) { - c0_tmp[index[0]][index[1]] = - AccumulateC ? 
c0(index) : TOutElem{}; - }); - - for (size_t k1 = 0; k1 < K; k1++) { - outer_product(a, b, c0_tmp, m1, k1, n1); - } - - ntt::apply(c0.shape(), [&](auto index) { - c0(index) = c0_tmp[index[0]][index[1]]; - }); - } - } - - template - void outer_product(const TA &a, const TB &b, TC &c0_tmp, size_t m1, - size_t k1, size_t n1, size_t sm1 = 0, size_t sk1 = 0) { auto a1 = - a.view(make_ranked_shape(m1, k1), make_ranked_shape(M0Tile, 1)); + a.view(make_ranked_shape(m1, 0), make_ranked_shape(M0Tile, K)); auto b1 = - b.view(make_ranked_shape(k1, n1), make_ranked_shape(1, N0Tile)); - - using TLhsElem = std::remove_const_t; - using TRhsElem = std::remove_const_t; - - // 1. pack M & N - if constexpr (pack_kind == ukernels::mamtul_pack_kind::pack_mn && - m0_subtile) { - using TSubLhsElem = typename TLhsElem::element_type; - TSubLhsElem a0_tmp[m0_subtile]; - TRhsElem b0_tmp[N0Tile]; + b.view(make_ranked_shape(0, n1), make_ranked_shape(K, N0Tile)); - ntt::apply(fixed_shape{}, [&](auto index) { - a0_tmp[index[0]] = a1(0, 0)(sm1 + index[0]); - }); - ntt::apply(fixed_shape{}, - [&](auto index) { b0_tmp[index[0]] = b1(0, index[0]); }); - - for (size_t n = 0; n < N0Tile; n++) { - for (size_t m = 0; m < m0_subtile; m++) { - mul_add(a0_tmp[m], b0_tmp[n], c0_tmp[m][n]); - } - } - } - // 2. pack K & KN - else if constexpr (pack_kind == ukernels::mamtul_pack_kind::pack_kn) { - using TSubLhsElem = typename TLhsElem::element_type; - using TSubRhsElem = ntt::vector; - TSubLhsElem a0_tmp[M0Tile]; - TSubRhsElem b0_tmp[N0Tile]; - - ntt::apply(fixed_shape{}, [&](auto index) { - a0_tmp[index[0]] = a1(index[0], 0)(sk1); - }); - ntt::apply(fixed_shape{}, [&](auto index) { - b0_tmp[index[0]] = b1(0, index[0])(sk1); - }); - - for (size_t n = 0; n < N0Tile; n++) { - for (size_t m = 0; m < M0Tile; m++) { - mul_add(a0_tmp[m], b0_tmp[n], c0_tmp[m][n]); - } - } - } - // 1. pack MK & KN - else if constexpr (pack_kind == ukernels::mamtul_pack_kind::pack_mkn && - m0_subtile) { - using TSubLhsElem = typename TLhsElem::element_type; - using TSubRhsElem = ntt::vector; - TSubLhsElem a0_tmp[m0_subtile]; - TSubRhsElem b0_tmp[N0Tile]; - - ntt::apply(fixed_shape{}, [&](auto index) { - a0_tmp[index[0]] = a1(0, 0)(sm1 + index[0], sk1); - }); - ntt::apply(fixed_shape{}, [&](auto index) { - b0_tmp[index[0]] = b1(0, index[0])(sk1); - }); - - for (size_t n = 0; n < N0Tile; n++) { - for (size_t m = 0; m < m0_subtile; m++) { - auto &output = c0_tmp[m][n]; - auto value = ntt::outer_product(a0_tmp[m], b0_tmp[n]); - output = output + value; - } - } - } - // Other packs - else { - TLhsElem a0_tmp[M0Tile]; - TRhsElem b0_tmp[N0Tile]; - - ntt::apply(fixed_shape{}, - [&](auto index) { a0_tmp[index[0]] = a1(index[0], 0); }); - ntt::apply(fixed_shape{}, - [&](auto index) { b0_tmp[index[0]] = b1(0, index[0]); }); - - for (size_t n = 0; n < N0Tile; n++) { - for (size_t m = 0; m < M0Tile; m++) { - mul_add(a0_tmp[m], b0_tmp[n], c0_tmp[m][n]); - } - } - } - } - - template - void mul_add(const TLhsElem &lhs, const TRhsElem &rhs, TOutElem &output) { - // 1. 0D-packing - if constexpr (pack_kind == ukernels::mamtul_pack_kind::no_pack) { - output = AccC ? ntt::mul_add(lhs, rhs, output) : ntt::mul(lhs, rhs); - } - // 2. 1D-packing - // 2.1. pack M - else if constexpr (pack_kind == ukernels::mamtul_pack_kind::pack_m) { - output = AccC ? ntt::mul_add(lhs, rhs, output) : ntt::mul(lhs, rhs); - } - // 2.2. pack K - else if constexpr (pack_kind == ukernels::mamtul_pack_kind::pack_k) { - auto value = ntt::inner_product(lhs, rhs); - output = AccC ? 
output + value : value; - } - // 2.3. pack N - else if constexpr (pack_kind == ukernels::mamtul_pack_kind::pack_n) { - output = AccC ? ntt::mul_add(lhs, rhs, output) : ntt::mul(lhs, rhs); - } - // 2.4. pack M & N - else if constexpr (pack_kind == ukernels::mamtul_pack_kind::pack_mn || - pack_kind == ukernels::mamtul_pack_kind::pack_kn) { - auto value = ntt::outer_product(lhs, rhs); - output = AccC ? output + value : value; - } - // 3.1. pack MK & K - else if constexpr (pack_kind == ukernels::mamtul_pack_kind::pack_mk) { - for (size_t m = 0; m < lhs.shape()[0]; m++) { - auto value = ntt::inner_product(lhs(m), rhs); - output(m) = AccC ? output(m) + value : value; - } - } - // 3.2. pack MK & KN - else if constexpr (pack_kind == ukernels::mamtul_pack_kind::pack_mkn) { - output = ntt::mma(lhs, rhs, output); - } else { - static_assert(sizeof(TLhsElem) == 0, "Unsupported packing."); - } + ntt::u_matmul<pack_kind, AccumulateC, M0Tile, N0Tile>(a1, b1, c0, K); } }; } // namespace detail diff --git a/src/Native/include/nncase/ntt/ntt.h b/src/Native/include/nncase/ntt/ntt.h index 7422fe888c..7c20869989 100644 --- a/src/Native/include/nncase/ntt/ntt.h +++ b/src/Native/include/nncase/ntt/ntt.h @@ -56,4 +56,5 @@ #include "arch/riscv64/arch_types.h" #include "arch/riscv64/primitive_ops.h" #include "arch/riscv64/tensor_ops.h" +#include "arch/riscv64/ukernels.h" #endif diff --git a/src/Native/include/nncase/ntt/primitive_ops.h b/src/Native/include/nncase/ntt/primitive_ops.h index 979bfa765e..45c68cc48b 100644 --- a/src/Native/include/nncase/ntt/primitive_ops.h +++ b/src/Native/include/nncase/ntt/primitive_ops.h @@ -31,6 +31,17 @@ enum class reduce_op { namespace ops { +/** + * @defgroup Load/Store operation functors + * @{ + */ + +template <class TDest, class TSource> struct store { + constexpr void operator()(TDest &dest, const TSource &v) const noexcept { + dest = v; + } +}; + /** * @defgroup Unary operation functors * @{ @@ -312,6 +323,12 @@ template <class T> struct clamp { constexpr auto name(const T &v, TResult init_value) noexcept { \ return ntt::reduce(v, init_value); \ } + +template <class TDest, class TSource> +constexpr void store(TDest &dest, const TSource &v) noexcept { + ops::store<std::decay_t<TDest>, std::decay_t<TSource>>()(dest, v); +} + #define NTT_DEFINE_COMPARE_FUNC_IMPL(op) \ template <class T> \ constexpr auto op(const T &v1, const T &v2) noexcept { \ diff --git a/src/Native/include/nncase/ntt/ukernels.h b/src/Native/include/nncase/ntt/ukernels.h index 2a79d325a9..f8e4d51f95 100644 --- a/src/Native/include/nncase/ntt/ukernels.h +++ b/src/Native/include/nncase/ntt/ukernels.h @@ -13,159 +13,7 @@ * limitations under the License.
*/ #pragma once -#include "apply.h" -#include "primitive_ops.h" -#include "tensor.h" -#include "tensor_traits.h" - -namespace nncase::ntt::ukernels { -template -class u_pack { - public: - constexpr void operator()(const TIn *input, TOut *output) noexcept { - for (size_t j = 0; j < N; j++) { - for (size_t i = 0; i < M; i++) { - output[j](i) = input[i * MStrides + j]; - } - } - - if constexpr (M < TOut::shape_type::length()) { - for (size_t j = 0; j < N; j++) { - for (size_t i = M; i < TOut::shape_type::length(); i++) { - output[j](i) = (TIn)0; - } - } - } - } -}; - -template struct reduce_to_binary_type; - -template <> struct reduce_to_binary_type { - template using type = ops::add; -}; - -template <> struct reduce_to_binary_type { - template using type = ops::min; -}; - -template <> struct reduce_to_binary_type { - template using type = ops::max; -}; - -template <> struct reduce_to_binary_type { - template using type = ops::add; -}; - -template <> struct reduce_to_binary_type { - template using type = ops::mul; -}; - -template struct u_reduce_policy { - static constexpr size_t unroll = 2; -}; - -template struct u_reduce { - public: - constexpr T operator()(const T *input, size_t input_stride, size_t count, - T init_value) noexcept { - using binary_op_t = - typename reduce_to_binary_type::template type; - using policy_t = u_reduce_policy; - constexpr auto unroll = policy_t::unroll; - - if (count / unroll) { - T temp[unroll]; -#if 1 - for (size_t i = 0; i < unroll; i++) { - temp[i] = *input; - input += input_stride; - count--; - } - - while (count / unroll) { - for (size_t i = 0; i < unroll; i++) { - temp[i] = binary_op_t()(temp[i], *input); - input += input_stride; - count--; - } - } - - init_value = binary_op_t()(init_value, tree_reduce(temp)); -#else - while (count / unroll) { - for (size_t i = 0; i < unroll; i++) { - temp[i] = *input; - input += input_stride; - count--; - } - init_value = - binary_op_t()(init_value, tree_reduce(temp)); - } -#endif - } - - for (size_t i = 0; i < count; i++) { - init_value = binary_op_t()(init_value, *input); - input += input_stride; - } - return init_value; - } - - template constexpr T tree_reduce(T *input) noexcept { - using binary_op_t = - typename reduce_to_binary_type::template type; - if constexpr (N == 2) { - return binary_op_t()(input[0], input[1]); - } else { - return binary_op_t()(tree_reduce(input), - tree_reduce(input + N / 2)); - } - } -}; - -enum class mamtul_pack_kind { - unknown, - no_pack, - pack_m, - pack_k, - pack_n, - pack_mn, - pack_mk, - pack_kn, - pack_mkn, -}; - -template -struct u_matmul_policy { - static constexpr size_t m0_tile = 1; - static constexpr size_t n0_tile = 1; - static constexpr size_t m0_subtile = 0; -}; -} // namespace nncase::ntt::ukernels - -namespace nncase::ntt { -template -constexpr void u_pack(const TIn *input, TOut *output) noexcept { - ukernels::u_pack, - std::decay_t> - impl; - impl(input, output); -} - -template -constexpr T u_reduce(const T *input, size_t input_stride, size_t count, - T init_value) noexcept { - ukernels::u_reduce impl; - return impl(input, input_stride, count, init_value); -} - -// template -// constexpr void u_matmul(const TLhsElem *&lhs, const TRhsElem *&rhs, -// TOutElem *output, size_t M, size_t N, size_t K, -// size_t lhs_stride, size_t rhs_stride, -// size_t out_stride) noexcept { - -// } -} // namespace nncase::ntt +#include "ukernels/u_matmul.h" +#include "ukernels/u_mul_add.h" +#include "ukernels/u_pack.h" +#include "ukernels/u_reduce.h" diff --git 
a/src/Native/include/nncase/ntt/ukernels/u_matmul.h b/src/Native/include/nncase/ntt/ukernels/u_matmul.h new file mode 100644 index 0000000000..dc5052ab8e --- /dev/null +++ b/src/Native/include/nncase/ntt/ukernels/u_matmul.h @@ -0,0 +1,273 @@ +/* Copyright 2019-2021 Canaan Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once +#include "../apply.h" +#include "nncase/ntt/primitive_ops.h" +#include "nncase/ntt/shape.h" +#include "u_mul_add.h" + +namespace nncase::ntt { +namespace ukernels { +template <mamtul_pack_kind PackKind, class TLhsElem, class TRhsElem, class TOutElem, bool Arch> +struct u_matmul_policy { + static constexpr size_t m0_tile = 1; + static constexpr size_t n0_tile = 1; + static constexpr size_t m0_subtile = 0; +}; + +template <mamtul_pack_kind PackKind, bool AccumulateC, class TLhsElem, class TRhsElem, class TOutElem, size_t M0Tile, size_t N0Tile> +struct u_matmul_generic { + template <class TA, class TB, class TC> + constexpr void operator()(const TA &a, const TB &b, TC &c0, + size_t K) noexcept { + TOutElem c0_tmp[M0Tile][N0Tile]; + ntt::apply(c0.shape(), [&](auto index) { + c0_tmp[index[0]][index[1]] = AccumulateC ? c0(index) : TOutElem{}; + }); + + for (size_t k1 = 0; k1 < K; k1++) { + auto a0 = + a.view(make_ranked_shape(0, k1), fixed_shape<M0Tile, 1>{}); + auto b0 = + b.view(make_ranked_shape(k1, 0), fixed_shape<1, N0Tile>{}); + TLhsElem a0_tmp[M0Tile]; + TRhsElem b0_tmp[N0Tile]; + + ntt::apply(fixed_shape<M0Tile>{}, + [&](auto index) { a0_tmp[index[0]] = a0(index[0], 0); }); + ntt::apply(fixed_shape<N0Tile>{}, + [&](auto index) { b0_tmp[index[0]] = b0(0, index[0]); }); + + for (size_t n = 0; n < N0Tile; n++) { + for (size_t m = 0; m < M0Tile; m++) { + u_mul_add<PackKind, AccumulateC>(a0_tmp[m], b0_tmp[n], + c0_tmp[m][n]); + } + } + } + + ntt::apply(c0.shape(), [&](auto index) { + ntt::store(c0(index), c0_tmp[index[0]][index[1]]); + }); + } +}; + +template <mamtul_pack_kind PackKind, bool AccumulateC, class TLhsElem, class TRhsElem, class TOutElem, size_t M0Tile, size_t N0Tile> +struct u_matmul : u_matmul_generic<PackKind, AccumulateC, TLhsElem, TRhsElem, TOutElem, M0Tile, N0Tile> {}; + +template <bool AccumulateC, class TLhsElem, class TRhsElem, class TOutElem, size_t M0Tile, size_t N0Tile> +struct u_matmul<mamtul_pack_kind::pack_mn, AccumulateC, TLhsElem, TRhsElem, TOutElem, M0Tile, N0Tile> { + template <class TA, class TB, class TC> + constexpr void operator()(const TA &a, const TB &b, TC &c0, + size_t K) noexcept { + using TSubOutElem = ntt::vector<typename TOutElem::element_type, TOutElem::shape()[1]>; + using policy_t = + ntt::ukernels::u_matmul_policy<mamtul_pack_kind::pack_mn, TLhsElem, TRhsElem, TOutElem, true>; + constexpr auto m0_subtile = policy_t::m0_subtile; + + if constexpr (m0_subtile) { + TSubOutElem c0_tmp[m0_subtile][N0Tile]; + + for (size_t sm1 = 0; sm1 < TOutElem::shape()[0]; + sm1 += m0_subtile) { + ntt::apply(fixed_shape<m0_subtile, N0Tile>{}, [&](auto index) { + c0_tmp[index[0]][index[1]] = + AccumulateC ?
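+                        /* seed the register tile from C when accumulating,
+                           else start from zero: */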
c0(0, index[1])(sm1 + index[0]) + : TSubOutElem{}; + }); + + for (size_t k1 = 0; k1 < K; k1++) { + using TSubLhsElem = typename TLhsElem::element_type; + TSubLhsElem a0_tmp[m0_subtile]; + TRhsElem b0_tmp[N0Tile]; + + auto a0 = a.view(make_ranked_shape(0, k1), + fixed_shape<M0Tile, 1>{}); + auto b0 = b.view(make_ranked_shape(k1, 0), + fixed_shape<1, N0Tile>{}); + + ntt::apply(fixed_shape<m0_subtile>{}, [&](auto index) { + a0_tmp[index[0]] = a0(0, 0)(sm1 + index[0]); + }); + ntt::apply(fixed_shape<N0Tile>{}, [&](auto index) { + b0_tmp[index[0]] = b0(0, index[0]); + }); + + for (size_t n = 0; n < N0Tile; n++) { + for (size_t m = 0; m < m0_subtile; m++) { + c0_tmp[m][n] = ntt::mul_add(a0_tmp[m], b0_tmp[n], + c0_tmp[m][n]); + } + } + } + + ntt::apply(fixed_shape<m0_subtile, N0Tile>{}, [&](auto index) { + ntt::store(c0(0, index[1])(sm1 + index[0]), + c0_tmp[index[0]][index[1]]); + }); + } + } else { + u_matmul_generic<mamtul_pack_kind::pack_mn, AccumulateC, TLhsElem, TRhsElem, TOutElem, M0Tile, N0Tile> + impl; + impl(a, b, c0, K); + } + } +}; + +template <bool AccumulateC, class TLhsElem, class TRhsElem, class TOutElem, size_t M0Tile, size_t N0Tile> +struct u_matmul<mamtul_pack_kind::pack_kn, AccumulateC, TLhsElem, TRhsElem, TOutElem, M0Tile, N0Tile> { + template <class TA, class TB, class TC> + constexpr void operator()(const TA &a, const TB &b, TC &c0, + size_t K) noexcept { + TOutElem c0_tmp[M0Tile][N0Tile]; + ntt::apply(c0.shape(), [&](auto index) { + c0_tmp[index[0]][index[1]] = AccumulateC ? c0(index) : TOutElem{}; + }); + + for (size_t k1 = 0; k1 < K; k1++) { + auto a0 = + a.view(make_ranked_shape(0, k1), fixed_shape<M0Tile, 1>{}); + auto b0 = + b.view(make_ranked_shape(k1, 0), fixed_shape<1, N0Tile>{}); + for (size_t sk1 = 0; sk1 < TLhsElem::shape()[1]; sk1++) { + using TSubLhsElem = typename TLhsElem::element_type; + using TSubRhsElem = ntt::vector<typename TRhsElem::element_type, TRhsElem::shape()[1]>; + + TSubLhsElem a0_tmp[M0Tile]; + TSubRhsElem b0_tmp[N0Tile]; + + ntt::apply(fixed_shape<M0Tile>{}, [&](auto index) { + a0_tmp[index[0]] = a0(index[0], 0)(sk1); + }); + ntt::apply(fixed_shape<N0Tile>{}, [&](auto index) { + b0_tmp[index[0]] = b0(0, index[0])(sk1); + }); + + for (size_t n = 0; n < N0Tile; n++) { + for (size_t m = 0; m < M0Tile; m++) { + c0_tmp[m][n] = + ntt::mul_add(a0_tmp[m], b0_tmp[n], c0_tmp[m][n]); + } + } + } + } + + ntt::apply(c0.shape(), [&](auto index) { + ntt::store(c0(index), c0_tmp[index[0]][index[1]]); + }); + } +}; + +template <bool AccumulateC, class TLhsElem, class TRhsElem, class TOutElem, size_t M0Tile, size_t N0Tile> +struct u_matmul<mamtul_pack_kind::pack_mkn, AccumulateC, TLhsElem, TRhsElem, TOutElem, M0Tile, N0Tile> { + template <class TA, class TB, class TC> + constexpr void operator()(const TA &a, const TB &b, TC &c0, + size_t K) noexcept { + using TSubOutElem = ntt::vector<typename TOutElem::element_type, TOutElem::shape()[1]>; + using policy_t = + ntt::ukernels::u_matmul_policy<mamtul_pack_kind::pack_mkn, TLhsElem, TRhsElem, TOutElem, true>; + constexpr auto m0_subtile = policy_t::m0_subtile; + + if constexpr (m0_subtile) { + TSubOutElem c0_tmp[m0_subtile][N0Tile]; + + for (size_t sm1 = 0; sm1 < TOutElem::shape()[0]; + sm1 += m0_subtile) { + ntt::apply(fixed_shape<m0_subtile, N0Tile>{}, [&](auto index) { + c0_tmp[index[0]][index[1]] = + AccumulateC ?
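+                        /* pack_mkn reuses the m0_subtile trick from pack_mn
+                           so the m0_subtile x N0Tile accumulators can stay
+                           resident in registers: */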
c0(0, index[1])(sm1 + index[0]) + : TSubOutElem{}; + }); + + for (size_t k1 = 0; k1 < K; k1++) { + // Force the compiler not to unroll this loop + size_t sk1_max = TLhsElem::shape()[1]; +#pragma GCC unroll 1 + for (size_t sk1 = 0; sk1 < sk1_max; sk1++) { + using TSubLhsElem = typename TLhsElem::element_type; + using TSubRhsElem = + ntt::vector<typename TRhsElem::element_type, TRhsElem::shape()[1]>; + + auto a0 = a.view(make_ranked_shape(0, k1), + fixed_shape<M0Tile, 1>{}); + auto b0 = b.view(make_ranked_shape(k1, 0), + fixed_shape<1, N0Tile>{}); + + TSubLhsElem a0_tmp[m0_subtile]; + TSubRhsElem b0_tmp[N0Tile]; + + ntt::apply(fixed_shape<m0_subtile>{}, [&](auto index) { + a0_tmp[index[0]] = a0(0, 0)(sm1 + index[0], sk1); + }); + ntt::apply(fixed_shape<N0Tile>{}, [&](auto index) { + b0_tmp[index[0]] = b0(0, index[0])(sk1); + }); + + for (size_t n = 0; n < N0Tile; n++) { + for (size_t m = 0; m < m0_subtile; m++) { + auto &output = c0_tmp[m][n]; + output = + ntt::mul_add(a0_tmp[m], b0_tmp[n], output); + } + } + } + } + + ntt::apply(fixed_shape<m0_subtile, N0Tile>{}, [&](auto index) { + ntt::store(c0(0, index[1])(sm1 + index[0]), + c0_tmp[index[0]][index[1]]); + }); + } + } else { + u_matmul_generic<mamtul_pack_kind::pack_mkn, AccumulateC, TLhsElem, TRhsElem, TOutElem, M0Tile, N0Tile> + impl; + impl(a, b, c0, K); + } + } +}; +} // namespace ukernels + +template <ukernels::mamtul_pack_kind PackKind, bool AccumulateC, size_t M0Tile, size_t N0Tile, class TA, class TB, class TC> +constexpr void u_matmul(const TA &a, const TB &b, TC &c, size_t K) noexcept { + using TLhsElem = std::decay_t<typename TA::element_type>; + using TRhsElem = std::decay_t<typename TB::element_type>; + using TOutElem = std::decay_t<typename TC::element_type>; + ukernels::u_matmul<PackKind, AccumulateC, TLhsElem, TRhsElem, TOutElem, M0Tile, N0Tile> + impl; + impl(a, b, c, K); +} +} // namespace nncase::ntt diff --git a/src/Native/include/nncase/ntt/ukernels/u_mul_add.h b/src/Native/include/nncase/ntt/ukernels/u_mul_add.h new file mode 100644 index 0000000000..c2b2b74635 --- /dev/null +++ b/src/Native/include/nncase/ntt/ukernels/u_mul_add.h @@ -0,0 +1,74 @@ +/* Copyright 2019-2021 Canaan Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once +#include "../primitive_ops.h" + +namespace nncase::ntt { +namespace ukernels { +enum class mamtul_pack_kind { + unknown, + no_pack, + pack_m, + pack_k, + pack_n, + pack_mn, + pack_mk, + pack_kn, + pack_mkn, +}; +} // namespace ukernels + +template <ukernels::mamtul_pack_kind PackKind, bool AccC, class TLhsElem, class TRhsElem, class TOutElem> +void u_mul_add(const TLhsElem &lhs, const TRhsElem &rhs, TOutElem &output) { + // 1. 0D-packing + if constexpr (PackKind == ukernels::mamtul_pack_kind::no_pack) { + output = AccC ? ntt::mul_add(lhs, rhs, output) : ntt::mul(lhs, rhs); + } + // 2. 1D-packing + // 2.1. pack M + else if constexpr (PackKind == ukernels::mamtul_pack_kind::pack_m) { + output = AccC ? ntt::mul_add(lhs, rhs, output) : ntt::mul(lhs, rhs); + } + // 2.2. pack K + else if constexpr (PackKind == ukernels::mamtul_pack_kind::pack_k) { + auto value = ntt::inner_product(lhs, rhs); + output = AccC ? output + value : value; + } + // 2.3. pack N + else if constexpr (PackKind == ukernels::mamtul_pack_kind::pack_n) { + output = AccC ? ntt::mul_add(lhs, rhs, output) : ntt::mul(lhs, rhs); + } + // 2.4. pack M & N + else if constexpr (PackKind == ukernels::mamtul_pack_kind::pack_mn || + PackKind == ukernels::mamtul_pack_kind::pack_kn) { + auto value = ntt::outer_product(lhs, rhs); + output = AccC ?
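+               /* rank-1 update: fold the outer product of the two
+                  vectors into the matrix-shaped accumulator */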
output + value : value; + } + // 3.1. pack MK & K + else if constexpr (PackKind == ukernels::mamtul_pack_kind::pack_mk) { + for (size_t m = 0; m < lhs.shape()[0]; m++) { + auto value = ntt::inner_product(lhs(m), rhs); + output(m) = AccC ? output(m) + value : value; + } + } + // 3.2. pack MK & KN + else if constexpr (PackKind == ukernels::mamtul_pack_kind::pack_mkn) { + output = ntt::mma(lhs, rhs, output); + } else { + static_assert(sizeof(TLhsElem) == 0, "Unsupported packing."); + } +} +} // namespace nncase::ntt diff --git a/src/Native/include/nncase/ntt/ukernels/u_pack.h b/src/Native/include/nncase/ntt/ukernels/u_pack.h new file mode 100644 index 0000000000..899dcb9db9 --- /dev/null +++ b/src/Native/include/nncase/ntt/ukernels/u_pack.h @@ -0,0 +1,49 @@ +/* Copyright 2019-2021 Canaan Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once +#include +#include + +namespace nncase::ntt { +namespace ukernels { +template +class u_pack { + public: + constexpr void operator()(const TIn *input, TOut *output) noexcept { + for (size_t j = 0; j < N; j++) { + for (size_t i = 0; i < M; i++) { + output[j](i) = input[i * MStrides + j]; + } + } + + if constexpr (M < TOut::shape_type::length()) { + for (size_t j = 0; j < N; j++) { + for (size_t i = M; i < TOut::shape_type::length(); i++) { + output[j](i) = (TIn)0; + } + } + } + } +}; +} // namespace ukernels + +template +constexpr void u_pack(const TIn *input, TOut *output) noexcept { + ukernels::u_pack, + std::decay_t> + impl; + impl(input, output); +} +} // namespace nncase::ntt diff --git a/src/Native/include/nncase/ntt/ukernels/u_reduce.h b/src/Native/include/nncase/ntt/ukernels/u_reduce.h new file mode 100644 index 0000000000..612d7520cc --- /dev/null +++ b/src/Native/include/nncase/ntt/ukernels/u_reduce.h @@ -0,0 +1,112 @@ +/* Copyright 2019-2021 Canaan Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#pragma once +#include "../primitive_ops.h" + +namespace nncase::ntt { +namespace ukernels { +template struct reduce_to_binary_type; + +template <> struct reduce_to_binary_type { + template using type = ops::add; +}; + +template <> struct reduce_to_binary_type { + template using type = ops::min; +}; + +template <> struct reduce_to_binary_type { + template using type = ops::max; +}; + +template <> struct reduce_to_binary_type { + template using type = ops::add; +}; + +template <> struct reduce_to_binary_type { + template using type = ops::mul; +}; + +template struct u_reduce_policy { + static constexpr size_t unroll = 2; +}; + +template struct u_reduce { + public: + constexpr T operator()(const T *input, size_t input_stride, size_t count, + T init_value) noexcept { + using binary_op_t = + typename reduce_to_binary_type::template type; + using policy_t = u_reduce_policy; + constexpr auto unroll = policy_t::unroll; + + if (count / unroll) { + T temp[unroll]; +#if 1 + for (size_t i = 0; i < unroll; i++) { + temp[i] = *input; + input += input_stride; + count--; + } + + while (count / unroll) { + for (size_t i = 0; i < unroll; i++) { + temp[i] = binary_op_t()(temp[i], *input); + input += input_stride; + count--; + } + } + + init_value = binary_op_t()(init_value, tree_reduce(temp)); +#else + while (count / unroll) { + for (size_t i = 0; i < unroll; i++) { + temp[i] = *input; + input += input_stride; + count--; + } + init_value = + binary_op_t()(init_value, tree_reduce(temp)); + } +#endif + } + + for (size_t i = 0; i < count; i++) { + init_value = binary_op_t()(init_value, *input); + input += input_stride; + } + return init_value; + } + + template constexpr T tree_reduce(T *input) noexcept { + using binary_op_t = + typename reduce_to_binary_type::template type; + if constexpr (N == 2) { + return binary_op_t()(input[0], input[1]); + } else { + return binary_op_t()(tree_reduce(input), + tree_reduce(input + N / 2)); + } + } +}; +} // namespace ukernels + +template +constexpr T u_reduce(const T *input, size_t input_stride, size_t count, + T init_value) noexcept { + ukernels::u_reduce impl; + return impl(input, input_stride, count, init_value); +} +} // namespace nncase::ntt diff --git a/src/Native/include/nncase/ntt/unrool.h b/src/Native/include/nncase/ntt/unrool.h index be3c322879..7129a80a2a 100644 --- a/src/Native/include/nncase/ntt/unrool.h +++ b/src/Native/include/nncase/ntt/unrool.h @@ -41,27 +41,26 @@ template