Skip to content

Commit

Permalink
[ntt.x86] Optimize u_reduce
Browse files: browse the repository at this point in the history
  • Loading branch information
sunnycase committed Sep 12, 2024
1 parent 90632e1 commit dee30a4
Show file tree
Hide file tree
Showing 2 changed files with 70 additions and 12 deletions.
81 changes: 70 additions & 11 deletions src/Native/include/nncase/ntt/arch/x86_64/ukernels.h
Original file line number Diff line number Diff line change
Expand Up @@ -49,21 +49,80 @@ template <reduce_op Op> struct u_reduce<Op, vector<float, 8>, true> {
typename reduce_to_binary_type<Op>::template type<vector<float, 8>,
vector<float, 8>>;
binary_op_t op;
if (count / 8) {
if (count / 4) {
vector<float, 8> tmp[4];
while (count / 8) {
for (size_t j = 0; j < 4; j++) {
tmp[j] = op(input[(j * 2) * input_stride],
input[(j * 2 + 1) * input_stride]);
for (size_t i = 0; i < 4; i++) {
tmp[i] = input[i * input_stride];
}
input += input_stride * 4;
count -= 4;
while (count / 4) {
for (size_t i = 0; i < 4; i++) {
tmp[i] = op(tmp[i], input[i * input_stride]);
}
input += input_stride * 8;
count -= 8;
input += input_stride * 4;
count -= 4;
}

tmp[0] = op(tmp[0], tmp[1]);
tmp[2] = op(tmp[2], tmp[3]);
tmp[0] = op(tmp[0], tmp[2]);
init_value = op(init_value, tmp[0]);
tmp[0] = op(tmp[0], tmp[1]);
tmp[2] = op(tmp[2], tmp[3]);
tmp[0] = op(tmp[0], tmp[2]);
init_value = op(init_value, tmp[0]);
}

if (count / 2) {
vector<float, 8> tmp[2];
for (size_t i = 0; i < 2; i++) {
tmp[i] = input[i * input_stride];
}
input += input_stride * 2;
count -= 2;
while (count / 2) {
for (size_t i = 0; i < 2; i++) {
tmp[i] = op(tmp[i], input[i * input_stride]);
}
input += input_stride * 2;
count -= 2;
}

tmp[0] = op(tmp[0], tmp[1]);
init_value = op(init_value, tmp[0]);
}

for (size_t i = 0; i < count; i++) {
init_value = op(init_value, *input);
input += input_stride;
}
return init_value;
}
};

template <reduce_op Op> struct u_reduce<Op, float, true> {
public:
constexpr float operator()(const float *input, size_t input_stride,
size_t count, float init_value) noexcept {
using binary_op_t =
typename reduce_to_binary_type<Op>::template type<float, float>;
binary_op_t op;
if (count / 4) {
float tmp[4];
for (size_t i = 0; i < 4; i++) {
tmp[i] = input[i * input_stride];
}
input += input_stride * 4;
count -= 4;
while (count / 4) {
for (size_t i = 0; i < 4; i++) {
tmp[i] = op(tmp[i], input[i * input_stride]);
}
input += input_stride * 4;
count -= 4;
}

tmp[0] = op(tmp[0], tmp[1]);
tmp[2] = op(tmp[2], tmp[3]);
tmp[0] = op(tmp[0], tmp[2]);
init_value = op(init_value, tmp[0]);
}

for (size_t i = 0; i < count; i++) {
Expand Down
1 change: 0 additions & 1 deletion src/Native/test/benchmark_test/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -17,5 +17,4 @@ foreach(test_name ${TEST_NAMES})
endif()
add_executable(${tname} ${tname}.cpp)
target_link_libraries(${tname} PRIVATE nncaseruntime)
target_compile_options(${tname} PRIVATE -ffast-math)
endforeach()

0 comments on commit dee30a4

Please sign in to comment.