From bfe4e7a335ad63e17d69b3094abfef3aae74d848 Mon Sep 17 00:00:00 2001 From: zhangyang2057 Date: Fri, 22 Nov 2024 10:10:28 +0800 Subject: [PATCH] Fix unpack 2D bug and add ctest case. --- ntt/include/nncase/ntt/kernels/unpack.h | 10 +++---- ntt/test/ctest/test_ntt_unpack.cpp | 37 +++++++++++++++++++++++++ 2 files changed, 41 insertions(+), 6 deletions(-) diff --git a/ntt/include/nncase/ntt/kernels/unpack.h b/ntt/include/nncase/ntt/kernels/unpack.h index d06fcdebd..29d13dcb4 100644 --- a/ntt/include/nncase/ntt/kernels/unpack.h +++ b/ntt/include/nncase/ntt/kernels/unpack.h @@ -78,18 +78,16 @@ class unpack_impl, fixed_shape, OutShape, constexpr auto rank = TIn::shape_type::rank(); constexpr auto in_conti_dims = contiguous_dims( fixed_shape{}, fixed_strides{}); - constexpr auto low_axis = Axis1 < Axis2 ? Axis1 : Axis2; - constexpr auto high_axis = Axis1 < Axis2 ? Axis2 : Axis1; - if constexpr ((in_conti_dims == rank) && (high_axis == low_axis + 1) && - (high_axis != (rank - 1))) { + if constexpr ((in_conti_dims == rank) && (Axis2 == Axis1 + 1) && + (Axis2 != (rank - 1))) { auto pout = output.buffer().data(); auto count = input.shape().length(); constexpr auto in_strides = std::array{InStrides...}; constexpr auto v_shape = std::array{InElemDims...}; - ntt::u_unpack_2d_fixed( input, 1, pout, count); } else { diff --git a/ntt/test/ctest/test_ntt_unpack.cpp b/ntt/test/ctest/test_ntt_unpack.cpp index 3f7ef3f04..0251cee2e 100644 --- a/ntt/test/ctest/test_ntt_unpack.cpp +++ b/ntt/test/ctest/test_ntt_unpack.cpp @@ -389,6 +389,43 @@ TEST(UnpackTestFloat, fixed_shape_dim_N_W) { EXPECT_TRUE(NttTest::compare_tensor(ntt_output1, ntt_output2)); } +TEST(UnpackTestFloat, fixed_shape_dim_C_N) { + constexpr size_t P = NTT_VLEN / (sizeof(float) * 8); + constexpr size_t N = P * 2; + constexpr size_t C = P * 2; + constexpr size_t H = P; + constexpr size_t W = P; + float min_input = -10.0f; + float max_input = 10.0f; + + // init + using tensor_type1 = ntt::tensor, + ntt::fixed_shape>; + alignas(32) tensor_type1 ntt_input; + NttTest::init_tensor(ntt_input, min_input, max_input); + + // ntt + using tensor_type2 = ntt::tensor>; + alignas(32) tensor_type2 ntt_output1; + ntt::unpack<1, 0>(ntt_input, ntt_output1); + + // ort + auto ort_input = NttTest::ntt2ort(ntt_input); + int64_t perms[] = {0, 5, 1, 4, 2, 3}; + auto tmp = ortki_Transpose(ort_input, perms, std::size(perms)); + int64_t data[] = {N, C, H, W}; + int64_t data_shape[] = {std::size(data)}; + auto ort_type = NttTest::primitive_type2ort_type(); + auto shape = make_tensor(reinterpret_cast(data), ort_type, + data_shape, std::size(data_shape)); + auto ort_output = ortki_Reshape(tmp, shape, 0); + + // compare + alignas(32) tensor_type2 ntt_output2; + NttTest::ort2ntt(ort_output, ntt_output2); + EXPECT_TRUE(NttTest::compare_tensor(ntt_output1, ntt_output2)); +} + TEST(UnpackTestFloat, ranked_shape_dim_N) { constexpr size_t P = NTT_VLEN / (sizeof(float) * 8); constexpr size_t N = P * 2;