Feature/ntt benchmark roofline 3 #1241

Merged: 79 commits into dev/3.0 from feature/ntt_benchmark_roofline_3 on Oct 8, 2024

Commits (79)
0d8e4f0
update for pack M&N x86
Sep 9, 2024
145bb31
revert for some performance fallback
Sep 10, 2024
2e01437
add funroll-loops for gcc
Sep 10, 2024
d67cc5e
shutdown for ci
Sep 10, 2024
1933c5d
revise bug for unroll
Sep 10, 2024
f743675
add softmax benchmark
Sep 10, 2024
2ee1e9f
add ctest for softmax
Sep 10, 2024
1a767ec
Fix ctest failure for softmax.
zhangyang2057 Sep 11, 2024
fc90313
opt for x86 softmax
Sep 11, 2024
10896af
revise benchmark for softmax
Sep 11, 2024
5481454
Add rvv optimization of tanh with max_ulp_error = 2.
zhangyang2057 Sep 12, 2024
80d2e23
add tanh for x86 ulp version
Sep 12, 2024
643485c
remove usless headfile
Sep 12, 2024
7f3bbff
Apply code-format changes
uranus0515 Sep 12, 2024
3371dd2
Optimize mul_add for rvv(performance boost 15% ~ 32%)
zhangyang2057 Sep 13, 2024
5483778
Merge branch 'dev/3.0' into feature/ntt_benchmark_roofline_3
zhangyang2057 Sep 14, 2024
6425359
Remove ntt softmax and fix reduce conflict of x86_64.
zhangyang2057 Sep 14, 2024
52c1843
change roofline for reduce x86
Sep 18, 2024
5b75d2e
Optimize matmul for rvv and update roofline.
zhangyang2057 Sep 18, 2024
e3d0d59
Update reduce roofline for rvv.
zhangyang2057 Sep 20, 2024
84bd918
Update Max_reduceMN_PackN roofline.
zhangyang2057 Sep 20, 2024
d5e4eaf
Add ratio for roofline / actual.
zhangyang2057 Sep 20, 2024
844a625
update tanh Roofline
Sep 23, 2024
c7e15f5
Specialize max/min for float and update roofline for reduce no_pack.
zhangyang2057 Sep 23, 2024
2fbbaad
problem about x86 roofline
Sep 25, 2024
96499f3
[NTT] Add ukernel for matmul
sunnycase Sep 25, 2024
e7ac09e
Apply code-format changes
sunnycase Sep 25, 2024
61e6bf6
Fix build
sunnycase Sep 25, 2024
724f355
Merge branch 'feature/optimize_matmul' into feature/ntt_benchmark_roo…
Sep 25, 2024
60f4d3a
change some reality for x86
Sep 25, 2024
2ff5d65
fallback roofline
Sep 25, 2024
677106c
change sequence for test
Sep 25, 2024
21c46ce
add warmup for unary
Sep 25, 2024
cad1f36
add primitive size auto test
Sep 26, 2024
15d136e
revise bug in daily test
Sep 26, 2024
29d065b
revise bug in daily test
Sep 26, 2024
1c5c986
avoid bug for temp
Sep 26, 2024
bfe66d2
total fallback
Sep 26, 2024
404ff91
add info for primitive size
Sep 26, 2024
6e4ee24
remove tile infor
Sep 26, 2024
e24d785
change the way
Sep 26, 2024
7c39215
Add tensor.squeeze
sunnycase Sep 26, 2024
981a2e1
remove Primitive infor
Sep 26, 2024
7e301fd
Apply code-format changes
sunnycase Sep 26, 2024
ade390b
temp change test
Sep 26, 2024
62709c4
add table name
Sep 26, 2024
f957613
merge two table
Sep 26, 2024
ea4568e
remove typo
Sep 26, 2024
2e7107f
typo test
Sep 26, 2024
9172aac
change back for daily test
Sep 26, 2024
19a84dc
test for table
Sep 27, 2024
cc08589
change back tor test over
Sep 27, 2024
52b0e8d
change for primitive size
Sep 27, 2024
c3fa5ec
Support odd matmul
sunnycase Sep 27, 2024
32f6eda
Merge branch 'feature/optimize_matmul' into feature/ntt_benchmark_roo…
Sep 27, 2024
f5a814e
Fix build
sunnycase Sep 27, 2024
ec704d1
Apply code-format changes
sunnycase Sep 27, 2024
e250d13
Optimze erf for rvv.
zhangyang2057 Sep 27, 2024
fcc581b
Merge branch 'feature/ntt_benchmark_roofline_3' of https://github.com…
Sep 27, 2024
587beae
Fix build
sunnycase Sep 27, 2024
206b5c5
Add markdown for ntt mamtul.
zhangyang2057 Sep 27, 2024
985207d
Merge branch 'feature/optimize_matmul' into feature/ntt_benchmark_roo…
Sep 27, 2024
57bc8bb
Merge branch 'feature/ntt_benchmark_roofline_3' of https://github.com…
Sep 27, 2024
f750353
Add u_matmul policy for rvv
sunnycase Sep 27, 2024
0c25fc2
add erf ulp version
Sep 27, 2024
6ad71b5
fix typo
Sep 27, 2024
decbc75
Refactor benchmark ntt py to support both ntt and ntt_matmul.
zhangyang2057 Sep 29, 2024
81f3318
Apply code-format changes
zhangyang2057 Sep 29, 2024
4f4a95c
Add ntt.store, optimize u_matmul for RVV
sunnycase Sep 29, 2024
e0e1890
Fix pack MKN for RVV
sunnycase Sep 29, 2024
6556d01
Fix macos build and show gflops with floating point.
zhangyang2057 Sep 29, 2024
a392f69
revise typo
Sep 30, 2024
b97e090
Force compiler do not unroll k loops
sunnycase Sep 30, 2024
0faa9c3
Use pragma unroll 1 instead of volatile
sunnycase Sep 30, 2024
a553d9a
set performance for cpu0
Sep 30, 2024
bfcbeb1
Merge branch 'feature/ntt_benchmark_roofline_3' of https://github.com…
Sep 30, 2024
6eb378e
Merge branch 'dev/3.0' into feature/ntt_benchmark_roofline_3
Sep 30, 2024
2c5dfaa
Apply code-format changes
uranus0515 Sep 30, 2024
7c0bf99
temp fallback for ci test
Oct 8, 2024
278 changes: 248 additions & 30 deletions src/Native/include/nncase/ntt/arch/riscv64/primitive_ops.h
@@ -14,6 +14,8 @@
*/
#pragma once
#include "../../primitive_ops.h"
#include "nncase/ntt/arch/riscv64/arch_types.h"
#include "nncase/ntt/vector.h"
#include "rvv_mathfun.h"

#ifdef __riscv_vector
@@ -29,6 +31,15 @@ namespace nncase::ntt::ops {
kernel(1, 32) kernel(2, 16) kernel(4, 8) kernel(8, 4)
#endif

template <>
struct store<ntt::vector<float, NTT_VLEN / 32>,
             ntt::vector<float, NTT_VLEN / 32>> {
    void operator()(ntt::vector<float, NTT_VLEN / 32> &dest,
                    const ntt::vector<float, NTT_VLEN / 32> &v) const noexcept {
        __riscv_vse32_v_f32m1((float *)&dest, v, NTT_VLEN / 32);
    }
};
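
The store specialization above writes a full vector register back to its tile with a single unit-stride vse32 rather than an element-by-element copy. A standalone analogue using only the raw intrinsic (the helper name is illustrative, not from this PR):

#include <riscv_vector.h>
// Store one full f32m1 vector register (vl lanes) to dst with one
// unit-stride vse32, mirroring the ops::store specialization above.
inline void store_f32m1(float *dst, vfloat32m1_t v, size_t vl) {
    __riscv_vse32_v_f32m1(dst, v, vl);
}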

#define RVV_UNARY_OP(op, dtype, vl, kernel)                                    \
    template <> struct op<ntt::vector<dtype, vl>> {                            \
        ntt::vector<dtype, vl>                                                 \
@@ -610,6 +621,16 @@ REGISTER_RVV_UNARY_OP(square, float, square_float32)
REGISTER_RVV_KERNEL(TANH_FLOAT32)
REGISTER_RVV_UNARY_OP(tanh, float, tanh_float32)

// erf
#define ERF_FLOAT32(lmul, mlen)                                                \
    inline vfloat32m##lmul##_t erf_float32(const vfloat32m##lmul##_t &v,       \
                                           const size_t vl) {                  \
        return erf_ps(v, vl);                                                  \
    }

REGISTER_RVV_KERNEL(ERF_FLOAT32)
REGISTER_RVV_UNARY_OP(erf, float, erf_float32)
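
REGISTER_RVV_KERNEL stamps the kernel out once per supported LMUL (the kernel(1, 32) through kernel(8, 4) list near the top of the file), so the new erf op reduces entirely to the erf_ps routine from rvv_mathfun.h. For example, ERF_FLOAT32(1, 32) expands to:

inline vfloat32m1_t erf_float32(const vfloat32m1_t &v, const size_t vl) {
    return erf_ps(v, vl); // vectorized erf from rvv_mathfun.h
}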

// binary
#define RVV_BINARY_OP(op, dtype, vl, kernel)                                   \
    template <> struct op<ntt::vector<dtype, vl>, ntt::vector<dtype, vl>> {    \
@@ -761,6 +782,16 @@ REGISTER_RVV_KERNEL(MOD_FLOAT32)
REGISTER_RVV_BINARY_OP(mod, float, mod_float32)

// min
template <> struct min<float, float> {
    auto operator()(const float &s1, const float &s2) const noexcept {
        float ret;
        __asm("fmin.s %[ret], %[s1], %[s2];"
              : [ret] "=f"(ret)
              : [s1] "f"(s1), [s2] "f"(s2));
        return ret;
    }
};

#define MIN_FLOAT32(lmul, mlen)                                                \
    inline vfloat32m##lmul##_t min_float32(const vfloat32m##lmul##_t &v1,      \
                                           const vfloat32m##lmul##_t &v2,      \
@@ -782,6 +813,16 @@ REGISTER_RVV_KERNEL(MIN_FLOAT32)
REGISTER_RVV_BINARY_OP(min, float, min_float32)

// max
template <> struct max<float, float> {
    auto operator()(const float &s1, const float &s2) const noexcept {
        float ret;
        __asm("fmax.s %[ret], %[s1], %[s2];"
              : [ret] "=f"(ret)
              : [s1] "f"(s1), [s2] "f"(s2));
        return ret;
    }
};
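
These scalar specializations (from the "Specialize max/min for float" commit) pin min/max of two floats to single fmin.s/fmax.s instructions, which also return the non-NaN operand when the other input is NaN; a plain ternary can compile to a compare-and-branch sequence with different NaN behavior. A hedged sketch of the portable baseline they replace, assuming std::fmin/std::fmax as the comparison point:

#include <cmath>
// Portable baseline: std::fmin/std::fmax share the quiet-NaN handling
// of fmin.s/fmax.s, but the inline asm above guarantees one instruction
// regardless of compiler and optimization flags.
inline float min_portable(float a, float b) { return std::fmin(a, b); }
inline float max_portable(float a, float b) { return std::fmax(a, b); }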

#define MAX_FLOAT32(lmul, mlen)                                                \
    inline vfloat32m##lmul##_t max_float32(const vfloat32m##lmul##_t &v1,      \
                                           const vfloat32m##lmul##_t &v2,      \
@@ -969,6 +1010,7 @@ REGISTER_RVV_KERNEL(INNER_PRODUCT_FLOAT32)
REGISTER_RVV_INNER_PRODUCT_OP(float, inner_product_float32)

// register mul_add kernel
#if 0
#define MUL_ADD_FLOAT32(lmul, mlen)                                            \
    inline vfloat32m##lmul##_t mul_add_float32(                                \
        const vfloat32m##lmul##_t &v1, const vfloat32m##lmul##_t &v2,          \
@@ -987,6 +1029,26 @@ REGISTER_RVV_INNER_PRODUCT_OP(float, inner_product_float32)
        const vfloat32m##lmul##_t &v3, const size_t vl) {                      \
        return __riscv_vfmadd_vf_f32m##lmul(v2, s1, v3, vl);                   \
    }
#else
#define MUL_ADD_FLOAT32(lmul, mlen)                                            \
    inline vfloat32m##lmul##_t mul_add_float32(                                \
        const vfloat32m##lmul##_t &v1, const vfloat32m##lmul##_t &v2,          \
        const vfloat32m##lmul##_t &v3, const size_t vl) {                      \
        return __riscv_vfmacc_vv_f32m##lmul(v3, v1, v2, vl);                   \
    }                                                                          \
                                                                               \
    inline vfloat32m##lmul##_t mul_add_float32(                                \
        const vfloat32m##lmul##_t &v1, const float &s2,                        \
        const vfloat32m##lmul##_t &v3, const size_t vl) {                      \
        return __riscv_vfmacc_vf_f32m##lmul(v3, s2, v1, vl);                   \
    }                                                                          \
                                                                               \
    inline vfloat32m##lmul##_t mul_add_float32(                                \
        const float &s1, const vfloat32m##lmul##_t &v2,                        \
        const vfloat32m##lmul##_t &v3, const size_t vl) {                      \
        return __riscv_vfmacc_vf_f32m##lmul(v3, s1, v2, vl);                   \
    }
#endif

REGISTER_RVV_KERNEL(MUL_ADD_FLOAT32)
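
This is the rewrite behind the "Optimize mul_add for rvv(performance boost 15% ~ 32%)" commit. Both intrinsics compute a fused multiply-add, but vfmacc accumulates into its first operand (vd = vs1 * vs2 + vd) while vfmadd overwrites the multiplicand (vd = vd * vs1 + vs2), so passing v3 as the accumulator matches reduction loops and avoids register moves to preserve the inputs. A minimal standalone illustration (helper name assumed):

#include <riscv_vector.h>
// One accumulation step acc += a * b; vfmacc updates only the
// accumulator operand and leaves a and b intact.
inline vfloat32m1_t fma_step(vfloat32m1_t acc, vfloat32m1_t a,
                             vfloat32m1_t b, size_t vl) {
    return __riscv_vfmacc_vv_f32m1(acc, a, b, vl);
}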

@@ -1029,7 +1091,6 @@ REGISTER_RVV_KERNEL(MUL_ADD_FLOAT32)

REGISTER_RVV_MUL_ADD_OP(float, mul_add_float32)

#if 1
template <bool AccC>
struct mma<AccC, ntt::vector<float, 1, 4>, ntt::vector<float, 4, 4>,
           ntt::vector<float, 1, 4>> {
@@ -1038,11 +1099,67 @@ struct mma<AccC, ntt::vector<float, 1, 4>, ntt::vector<float, 4, 4>,
               const ntt::vector<float, 4, 4> &rhs,
               const ntt::vector<float, 1, 4> &v3) const noexcept {
        auto output = v3;
        for (size_t k = 0; k < 4; k++) {
            output(0) = (k != 0 || AccC)
                            ? ntt::mul_add(lhs(0, k), rhs(k), output(0))
                            : ntt::mul(lhs(0, k), rhs(k));
        }
        auto t0 = AccC ? ntt::mul_add(lhs(0, 0), rhs(0), output(0))
                       : ntt::mul(lhs(0, 0), rhs(0));
        auto t1 = ntt::mul(lhs(0, 1), rhs(1));
        t0 = ntt::mul_add(lhs(0, 2), rhs(2), t0);
        t1 = ntt::mul_add(lhs(0, 3), rhs(3), t1);
        output(0) = ntt::add(t0, t1);
        return output;
    }
};
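
The removed loop serializes four FMAs through output(0); the replacement splits them into two independent chains t0 and t1, so consecutive FMAs no longer wait on each other's results, and combines once at the end. The 1x32 specialization below extends the same two-chain pattern across all 32 K steps. The idea in plain scalar C++:

// Scalar analogue of the t0/t1 trick: two independent FMA chains halve
// the loop-carried latency of a length-4 dot product.
inline float dot4_two_chains(const float a[4], const float b[4]) {
    float t0 = a[0] * b[0];
    float t1 = a[1] * b[1];
    t0 += a[2] * b[2]; // chain 0
    t1 += a[3] * b[3]; // chain 1
    return t0 + t1;    // single combine at the end
}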

template <bool AccC>
struct mma<AccC, ntt::vector<float, 1, 32>, ntt::vector<float, 32, 32>,
           ntt::vector<float, 1, 32>> {
    ntt::vector<float, 1, 32>
    operator()(const ntt::vector<float, 1, 32> &lhs,
               const ntt::vector<float, 32, 32> &rhs,
               const ntt::vector<float, 1, 32> &v3) const noexcept {
        auto output = v3;

        auto t0 = AccC ? ntt::mul_add(lhs(0, 0), rhs(0), output(0))
                       : ntt::mul(lhs(0, 0), rhs(0));
        auto t1 = ntt::mul(lhs(0, 1), rhs(1));
        t0 = ntt::mul_add(lhs(0, 2), rhs(2), t0);
        t1 = ntt::mul_add(lhs(0, 3), rhs(3), t1);

        t0 = ntt::mul_add(lhs(0, 4), rhs(4), t0);
        t1 = ntt::mul_add(lhs(0, 5), rhs(5), t1);
        t0 = ntt::mul_add(lhs(0, 6), rhs(6), t0);
        t1 = ntt::mul_add(lhs(0, 7), rhs(7), t1);

        t0 = ntt::mul_add(lhs(0, 8), rhs(8), t0);
        t1 = ntt::mul_add(lhs(0, 9), rhs(9), t1);
        t0 = ntt::mul_add(lhs(0, 10), rhs(10), t0);
        t1 = ntt::mul_add(lhs(0, 11), rhs(11), t1);

        t0 = ntt::mul_add(lhs(0, 12), rhs(12), t0);
        t1 = ntt::mul_add(lhs(0, 13), rhs(13), t1);
        t0 = ntt::mul_add(lhs(0, 14), rhs(14), t0);
        t1 = ntt::mul_add(lhs(0, 15), rhs(15), t1);

        t0 = ntt::mul_add(lhs(0, 16), rhs(16), t0);
        t1 = ntt::mul_add(lhs(0, 17), rhs(17), t1);
        t0 = ntt::mul_add(lhs(0, 18), rhs(18), t0);
        t1 = ntt::mul_add(lhs(0, 19), rhs(19), t1);

        t0 = ntt::mul_add(lhs(0, 20), rhs(20), t0);
        t1 = ntt::mul_add(lhs(0, 21), rhs(21), t1);
        t0 = ntt::mul_add(lhs(0, 22), rhs(22), t0);
        t1 = ntt::mul_add(lhs(0, 23), rhs(23), t1);

        t0 = ntt::mul_add(lhs(0, 24), rhs(24), t0);
        t1 = ntt::mul_add(lhs(0, 25), rhs(25), t1);
        t0 = ntt::mul_add(lhs(0, 26), rhs(26), t0);
        t1 = ntt::mul_add(lhs(0, 27), rhs(27), t1);

        t0 = ntt::mul_add(lhs(0, 28), rhs(28), t0);
        t1 = ntt::mul_add(lhs(0, 29), rhs(29), t1);
        t0 = ntt::mul_add(lhs(0, 30), rhs(30), t0);
        t1 = ntt::mul_add(lhs(0, 31), rhs(31), t1);

        output(0) = ntt::add(t0, t1);
        return output;
    }
};
@@ -1055,33 +1172,134 @@ struct mma<AccC, ntt::vector<float, 4, 4>, ntt::vector<float, 4, 4>,
               const ntt::vector<float, 4, 4> &rhs,
               const ntt::vector<float, 4, 4> &v3) const noexcept {
        auto output = v3;
        for (size_t k = 0; k < 4; k++) {
            output(0) = (k != 0 || AccC)
                            ? ntt::mul_add(lhs(0, k), rhs(k), output(0))
                            : ntt::mul(lhs(0, k), rhs(k));
        }

        for (size_t k = 0; k < 4; k++) {
            output(1) = (k != 0 || AccC)
                            ? ntt::mul_add(lhs(1, k), rhs(k), output(1))
                            : ntt::mul(lhs(1, k), rhs(k));
        }

        for (size_t k = 0; k < 4; k++) {
            output(2) = (k != 0 || AccC)
                            ? ntt::mul_add(lhs(2, k), rhs(k), output(2))
                            : ntt::mul(lhs(2, k), rhs(k));
        }

        for (size_t k = 0; k < 4; k++) {
            output(3) = (k != 0 || AccC)
                            ? ntt::mul_add(lhs(3, k), rhs(k), output(3))
                            : ntt::mul(lhs(3, k), rhs(k));
        }
        ntt::fixed_tensor_alike_t<ntt::vector<float, 4>, 1, 4> lhs_2d[4]{
            {{lhs(0)}},
            {{lhs(1)}},
            {{lhs(2)}},
            {{lhs(3)}},
        };
        ntt::fixed_tensor_alike_t<ntt::vector<float, 4>, 1, 4> output_2d[4]{
            {{v3(0)}},
            {{v3(1)}},
            {{v3(2)}},
            {{v3(3)}},
        };

        output_2d[0] = ntt::mma<AccC>(lhs_2d[0], rhs, output_2d[0]);
        output_2d[1] = ntt::mma<AccC>(lhs_2d[1], rhs, output_2d[1]);
        output_2d[2] = ntt::mma<AccC>(lhs_2d[2], rhs, output_2d[2]);
        output_2d[3] = ntt::mma<AccC>(lhs_2d[3], rhs, output_2d[3]);

        output(0) = output_2d[0](0);
        output(1) = output_2d[1](0);
        output(2) = output_2d[2](0);
        output(3) = output_2d[3](0);

        return output;
    }
};
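
Rather than keep four hand-written K loops per row, the 4x4 specialization (and the 32x32 one below) re-expresses each lhs row as a 1-row tile via fixed_tensor_alike_t and forwards it to the optimized 1xN mma above; rows are independent, so the calls can issue back to back. A generic sketch of the delegation pattern, with placeholder types rather than the actual nncase templates:

// Illustrative pattern only: run an MxK mma as M independent 1xK mma
// calls so every row reuses the tuned single-row kernel.
template <class Mma1Row, class Row, class Rhs, int M>
void mma_by_rows(Mma1Row mma1, Row (&lhs)[M], const Rhs &rhs, Row (&out)[M]) {
    for (int m = 0; m < M; ++m)
        out[m] = mma1(lhs[m], rhs, out[m]); // no cross-row dependency
}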

template <bool AccC>
struct mma<AccC, ntt::vector<float, 32, 32>, ntt::vector<float, 32, 32>,
           ntt::vector<float, 32, 32>> {
    ntt::vector<float, 32, 32>
    operator()(const ntt::vector<float, 32, 32> &lhs,
               const ntt::vector<float, 32, 32> &rhs,
               const ntt::vector<float, 32, 32> &v3) const noexcept {
        auto output = v3;
        ntt::fixed_tensor_alike_t<ntt::vector<float, 32>, 1, 32> lhs_2d[]{
            {{lhs(0)}},  {{lhs(1)}},  {{lhs(2)}},  {{lhs(3)}},  {{lhs(4)}},
            {{lhs(5)}},  {{lhs(6)}},  {{lhs(7)}},  {{lhs(8)}},  {{lhs(9)}},
            {{lhs(10)}}, {{lhs(11)}}, {{lhs(12)}}, {{lhs(13)}}, {{lhs(14)}},
            {{lhs(15)}}, {{lhs(16)}}, {{lhs(17)}}, {{lhs(18)}}, {{lhs(19)}},
            {{lhs(20)}}, {{lhs(21)}}, {{lhs(22)}}, {{lhs(23)}}, {{lhs(24)}},
            {{lhs(25)}}, {{lhs(26)}}, {{lhs(27)}}, {{lhs(28)}}, {{lhs(29)}},
            {{lhs(30)}}, {{lhs(31)}}};

        ntt::fixed_tensor_alike_t<ntt::vector<float, 32>, 1, 32> output_2d[]{
            {{v3(0)}},  {{v3(1)}},  {{v3(2)}},  {{v3(3)}},  {{v3(4)}},
            {{v3(5)}},  {{v3(6)}},  {{v3(7)}},  {{v3(8)}},  {{v3(9)}},
            {{v3(10)}}, {{v3(11)}}, {{v3(12)}}, {{v3(13)}}, {{v3(14)}},
            {{v3(15)}}, {{v3(16)}}, {{v3(17)}}, {{v3(18)}}, {{v3(19)}},
            {{v3(20)}}, {{v3(21)}}, {{v3(22)}}, {{v3(23)}}, {{v3(24)}},
            {{v3(25)}}, {{v3(26)}}, {{v3(27)}}, {{v3(28)}}, {{v3(29)}},
            {{v3(30)}}, {{v3(31)}}};

        output_2d[0] = ntt::mma<AccC>(lhs_2d[0], rhs, output_2d[0]);
        output_2d[1] = ntt::mma<AccC>(lhs_2d[1], rhs, output_2d[1]);
        output_2d[2] = ntt::mma<AccC>(lhs_2d[2], rhs, output_2d[2]);
        output_2d[3] = ntt::mma<AccC>(lhs_2d[3], rhs, output_2d[3]);
        output_2d[4] = ntt::mma<AccC>(lhs_2d[4], rhs, output_2d[4]);
        output_2d[5] = ntt::mma<AccC>(lhs_2d[5], rhs, output_2d[5]);
        output_2d[6] = ntt::mma<AccC>(lhs_2d[6], rhs, output_2d[6]);
        output_2d[7] = ntt::mma<AccC>(lhs_2d[7], rhs, output_2d[7]);

        output_2d[8] = ntt::mma<AccC>(lhs_2d[8], rhs, output_2d[8]);
        output_2d[9] = ntt::mma<AccC>(lhs_2d[9], rhs, output_2d[9]);
        output_2d[10] = ntt::mma<AccC>(lhs_2d[10], rhs, output_2d[10]);
        output_2d[11] = ntt::mma<AccC>(lhs_2d[11], rhs, output_2d[11]);
        output_2d[12] = ntt::mma<AccC>(lhs_2d[12], rhs, output_2d[12]);
        output_2d[13] = ntt::mma<AccC>(lhs_2d[13], rhs, output_2d[13]);
        output_2d[14] = ntt::mma<AccC>(lhs_2d[14], rhs, output_2d[14]);
        output_2d[15] = ntt::mma<AccC>(lhs_2d[15], rhs, output_2d[15]);

        output_2d[16] = ntt::mma<AccC>(lhs_2d[16], rhs, output_2d[16]);
        output_2d[17] = ntt::mma<AccC>(lhs_2d[17], rhs, output_2d[17]);
        output_2d[18] = ntt::mma<AccC>(lhs_2d[18], rhs, output_2d[18]);
        output_2d[19] = ntt::mma<AccC>(lhs_2d[19], rhs, output_2d[19]);
        output_2d[20] = ntt::mma<AccC>(lhs_2d[20], rhs, output_2d[20]);
        output_2d[21] = ntt::mma<AccC>(lhs_2d[21], rhs, output_2d[21]);
        output_2d[22] = ntt::mma<AccC>(lhs_2d[22], rhs, output_2d[22]);
        output_2d[23] = ntt::mma<AccC>(lhs_2d[23], rhs, output_2d[23]);

        output_2d[24] = ntt::mma<AccC>(lhs_2d[24], rhs, output_2d[24]);
        output_2d[25] = ntt::mma<AccC>(lhs_2d[25], rhs, output_2d[25]);
        output_2d[26] = ntt::mma<AccC>(lhs_2d[26], rhs, output_2d[26]);
        output_2d[27] = ntt::mma<AccC>(lhs_2d[27], rhs, output_2d[27]);
        output_2d[28] = ntt::mma<AccC>(lhs_2d[28], rhs, output_2d[28]);
        output_2d[29] = ntt::mma<AccC>(lhs_2d[29], rhs, output_2d[29]);
        output_2d[30] = ntt::mma<AccC>(lhs_2d[30], rhs, output_2d[30]);
        output_2d[31] = ntt::mma<AccC>(lhs_2d[31], rhs, output_2d[31]);

        output(0) = output_2d[0](0);
        output(1) = output_2d[1](0);
        output(2) = output_2d[2](0);
        output(3) = output_2d[3](0);
        output(4) = output_2d[4](0);
        output(5) = output_2d[5](0);
        output(6) = output_2d[6](0);
        output(7) = output_2d[7](0);

        output(8) = output_2d[8](0);
        output(9) = output_2d[9](0);
        output(10) = output_2d[10](0);
        output(11) = output_2d[11](0);
        output(12) = output_2d[12](0);
        output(13) = output_2d[13](0);
        output(14) = output_2d[14](0);
        output(15) = output_2d[15](0);

        output(16) = output_2d[16](0);
        output(17) = output_2d[17](0);
        output(18) = output_2d[18](0);
        output(19) = output_2d[19](0);
        output(20) = output_2d[20](0);
        output(21) = output_2d[21](0);
        output(22) = output_2d[22](0);
        output(23) = output_2d[23](0);

        output(24) = output_2d[24](0);
        output(25) = output_2d[25](0);
        output(26) = output_2d[26](0);
        output(27) = output_2d[27](0);
        output(28) = output_2d[28](0);
        output(29) = output_2d[29](0);
        output(30) = output_2d[30](0);
        output(31) = output_2d[31](0);

        return output;
    }
};
#endif

// register reduce_sum kernel
#define REDUCE_ADD_FLOAT32(lmul, mlen) \