From dee30a4bd6c1b1db0d5d287cfbf321d0e3aeaa59 Mon Sep 17 00:00:00 2001 From: sunnycase Date: Thu, 12 Sep 2024 08:49:24 +0000 Subject: [PATCH] [ntt.x86] Optimize u_reduce --- .../include/nncase/ntt/arch/x86_64/ukernels.h | 81 ++++++++++++++++--- src/Native/test/benchmark_test/CMakeLists.txt | 1 - 2 files changed, 70 insertions(+), 12 deletions(-) diff --git a/src/Native/include/nncase/ntt/arch/x86_64/ukernels.h b/src/Native/include/nncase/ntt/arch/x86_64/ukernels.h index a8a74e74df..0a2becffd5 100644 --- a/src/Native/include/nncase/ntt/arch/x86_64/ukernels.h +++ b/src/Native/include/nncase/ntt/arch/x86_64/ukernels.h @@ -49,21 +49,80 @@ template struct u_reduce, true> { typename reduce_to_binary_type::template type, vector>; binary_op_t op; - if (count / 8) { + if (count / 4) { vector tmp[4]; - while (count / 8) { - for (size_t j = 0; j < 4; j++) { - tmp[j] = op(input[(j * 2) * input_stride], - input[(j * 2 + 1) * input_stride]); + for (size_t i = 0; i < 4; i++) { + tmp[i] = input[i * input_stride]; + } + input += input_stride * 4; + count -= 4; + while (count / 4) { + for (size_t i = 0; i < 4; i++) { + tmp[i] = op(tmp[i], input[i * input_stride]); } - input += input_stride * 8; - count -= 8; + input += input_stride * 4; + count -= 4; + } - tmp[0] = op(tmp[0], tmp[1]); - tmp[2] = op(tmp[2], tmp[3]); - tmp[0] = op(tmp[0], tmp[2]); - init_value = op(init_value, tmp[0]); + tmp[0] = op(tmp[0], tmp[1]); + tmp[2] = op(tmp[2], tmp[3]); + tmp[0] = op(tmp[0], tmp[2]); + init_value = op(init_value, tmp[0]); + } + + if (count / 2) { + vector tmp[2]; + for (size_t i = 0; i < 2; i++) { + tmp[i] = input[i * input_stride]; } + input += input_stride * 2; + count -= 2; + while (count / 2) { + for (size_t i = 0; i < 2; i++) { + tmp[i] = op(tmp[i], input[i * input_stride]); + } + input += input_stride * 2; + count -= 2; + } + + tmp[0] = op(tmp[0], tmp[1]); + init_value = op(init_value, tmp[0]); + } + + for (size_t i = 0; i < count; i++) { + init_value = op(init_value, *input); + input += input_stride; + } + return init_value; + } +}; + +template struct u_reduce { + public: + constexpr float operator()(const float *input, size_t input_stride, + size_t count, float init_value) noexcept { + using binary_op_t = + typename reduce_to_binary_type::template type; + binary_op_t op; + if (count / 4) { + float tmp[4]; + for (size_t i = 0; i < 4; i++) { + tmp[i] = input[i * input_stride]; + } + input += input_stride * 4; + count -= 4; + while (count / 4) { + for (size_t i = 0; i < 4; i++) { + tmp[i] = op(tmp[i], input[i * input_stride]); + } + input += input_stride * 4; + count -= 4; + } + + tmp[0] = op(tmp[0], tmp[1]); + tmp[2] = op(tmp[2], tmp[3]); + tmp[0] = op(tmp[0], tmp[2]); + init_value = op(init_value, tmp[0]); } for (size_t i = 0; i < count; i++) { diff --git a/src/Native/test/benchmark_test/CMakeLists.txt b/src/Native/test/benchmark_test/CMakeLists.txt index 1bbae71414..df509c8410 100644 --- a/src/Native/test/benchmark_test/CMakeLists.txt +++ b/src/Native/test/benchmark_test/CMakeLists.txt @@ -17,5 +17,4 @@ foreach(test_name ${TEST_NAMES}) endif() add_executable(${tname} ${tname}.cpp) target_link_libraries(${tname} PRIVATE nncaseruntime) - target_compile_options(${tname} PRIVATE -ffast-math) endforeach() \ No newline at end of file