diff --git a/ntt/include/nncase/ntt/arch/riscv64/ukernels.h b/ntt/include/nncase/ntt/arch/riscv64/ukernels.h index 58b2d2eb0..845707f79 100644 --- a/ntt/include/nncase/ntt/arch/riscv64/ukernels.h +++ b/ntt/include/nncase/ntt/arch/riscv64/ukernels.h @@ -77,9 +77,7 @@ SPECIALIZE_U_BINARY(floor_mod, 8) #undef SPECIALIZE_U_BINARY // clamp -template <> struct u_clamp_policy { - static constexpr size_t unroll = 8; -}; +template <> struct u_clamp_policy { static constexpr size_t unroll = 8; }; // reduce template struct u_reduce_policy { @@ -87,9 +85,7 @@ template struct u_reduce_policy { }; // cast -template <> struct u_cast_policy { - static constexpr size_t unroll = 8; -}; +template <> struct u_cast_policy { static constexpr size_t unroll = 8; }; // matmul template <> @@ -670,10 +666,10 @@ struct u_unpack_1d_fixed -class u_unpack_2d_fixed { +template +class u_unpack_2d_fixed { public: void operator()(const T1 &input, size_t in_stride, float *output, size_t count) noexcept { @@ -686,19 +682,19 @@ class u_unpack_2d_fixed); in_stride = in_stride + 1; - auto out_strides = high_axis_stride * sizeof(float); + auto out_strides = high_stride * sizeof(float); - while (count / high_axis_stride) { + while (count / high_stride) { auto out_ptr = output + in_offset + low_idx * low_extra + high_idx * high_extra; - auto out_end = out_ptr + high_axis_stride; + auto out_end = out_ptr + high_stride; while (out_ptr < out_end) { auto tmp = vl; size_t i_idx = 0; @@ -748,14 +744,14 @@ class u_unpack_2d_fixed> { }; template - requires(sizeof...(Axes) > 0 && - (std::get(std::array{ - Axes...}) == (TIn::rank() - 1))) -class u_pack2d, Axes...> { +requires(sizeof...(Axes) > 0 && + (std::get(std::array{ + Axes...}) == + (TIn::rank() - 1))) class u_pack2d, Axes...> { public: constexpr void operator()(const TIn &input, TOut &output) noexcept { using TVec = vector; diff --git a/ntt/include/nncase/ntt/kernels/unpack.h b/ntt/include/nncase/ntt/kernels/unpack.h index e6393418b..b687dc38f 100644 --- a/ntt/include/nncase/ntt/kernels/unpack.h +++ b/ntt/include/nncase/ntt/kernels/unpack.h @@ -94,15 +94,15 @@ class unpack_impl, fixed_shape, OutShape, } else { constexpr auto elem_rank = TVec::shape_type::rank(); constexpr fixed_shape domain{}; + constexpr auto axes = std::array{Axis1, Axis2}; apply(domain, [&](auto index) { auto in_index = slice_index(index); auto elem_index = slice_index(index, rank); auto out_index = slice_index(index); - out_index[low_axis] = - out_index[low_axis] * TVec::shape()[low_axis] + index[rank]; - out_index[high_axis] = - out_index[high_axis] * TVec::shape()[high_axis] + - index[rank]; + loop([&](auto i) { + out_index[axes[i]] = + out_index[axes[i]] * TVec::shape()[i] + index[rank + i]; + }); output(out_index) = input(in_index)(elem_index); }); } diff --git a/ntt/test/benchmark_test/benchmark_ntt.py b/ntt/test/benchmark_test/benchmark_ntt.py index 6287bdad5..c3e8de772 100644 --- a/ntt/test/benchmark_test/benchmark_ntt.py +++ b/ntt/test/benchmark_test/benchmark_ntt.py @@ -395,7 +395,7 @@ def __init__(self, target: str, bin_path: str): 'W': '4.3', 'NC': '6', 'CH': '6', - 'HW': '6', + 'HW': '4.3', }, } diff --git a/ntt/test/ctest/test_ntt_unpack.cpp b/ntt/test/ctest/test_ntt_unpack.cpp index f6986558f..3f7ef3f04 100644 --- a/ntt/test/ctest/test_ntt_unpack.cpp +++ b/ntt/test/ctest/test_ntt_unpack.cpp @@ -352,6 +352,43 @@ TEST(UnpackTestFloat, fixed_shape_dim_H_W) { EXPECT_TRUE(NttTest::compare_tensor(ntt_output1, ntt_output2)); } +TEST(UnpackTestFloat, fixed_shape_dim_N_W) { + constexpr size_t P = NTT_VLEN / (sizeof(float) * 8); + constexpr size_t N = P * 2; + constexpr size_t C = P; + constexpr size_t H = P; + constexpr size_t W = P * 2; + float min_input = -10.0f; + float max_input = 10.0f; + + // init + using tensor_type1 = ntt::tensor, + ntt::fixed_shape>; + alignas(32) tensor_type1 ntt_input; + NttTest::init_tensor(ntt_input, min_input, max_input); + + // ntt + using tensor_type2 = ntt::tensor>; + alignas(32) tensor_type2 ntt_output1; + ntt::unpack<0, 3>(ntt_input, ntt_output1); + + // ort + auto ort_input = NttTest::ntt2ort(ntt_input); + int64_t perms[] = {0, 4, 1, 2, 3, 5}; + auto tmp = ortki_Transpose(ort_input, perms, std::size(perms)); + int64_t data[] = {N, C, H, W}; + int64_t data_shape[] = {std::size(data)}; + auto ort_type = NttTest::primitive_type2ort_type(); + auto shape = make_tensor(reinterpret_cast(data), ort_type, + data_shape, std::size(data_shape)); + auto ort_output = ortki_Reshape(tmp, shape, 0); + + // compare + alignas(32) tensor_type2 ntt_output2; + NttTest::ort2ntt(ort_output, ntt_output2); + EXPECT_TRUE(NttTest::compare_tensor(ntt_output1, ntt_output2)); +} + TEST(UnpackTestFloat, ranked_shape_dim_N) { constexpr size_t P = NTT_VLEN / (sizeof(float) * 8); constexpr size_t N = P * 2; diff --git a/tools/clang-format.sh b/tools/clang-format.sh index fa5724807..8c99b5b51 100755 --- a/tools/clang-format.sh +++ b/tools/clang-format.sh @@ -32,6 +32,5 @@ find "${ROOT_DIR}/tests" \ "${ROOT_DIR}/modules" \ "${ROOT_DIR}/python" \ "${ROOT_DIR}/targets" \ - "${ROOT_DIR}/ntt" \ \( -name "*.h" -o -name "*.c" -o -name "*.cc" -o -name "*.cxx" -o -name "*.cpp" -o -name "*.hpp" -o -name "*.cppm" \) -and -not -wholename "*/.*" | \ xargs ${CLANG_FORMAT_LLVM_INSTALL_DIR}/bin/clang-format -i -style=file