Skip to content

Commit

Permalink
Fix unpack<C, N> 2D bug and add ctest case.
Browse files Browse the repository at this point in the history
  • Loading branch information
zhangyang2057 committed Nov 22, 2024
1 parent 0b886aa commit bfe4e7a
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 6 deletions.
10 changes: 4 additions & 6 deletions ntt/include/nncase/ntt/kernels/unpack.h
Original file line number Diff line number Diff line change
Expand Up @@ -78,18 +78,16 @@ class unpack_impl<fixed_shape<InDims...>, fixed_shape<InElemDims...>, OutShape,
constexpr auto rank = TIn::shape_type::rank();
constexpr auto in_conti_dims = contiguous_dims(
fixed_shape<InDims...>{}, fixed_strides<InStrides...>{});
constexpr auto low_axis = Axis1 < Axis2 ? Axis1 : Axis2;
constexpr auto high_axis = Axis1 < Axis2 ? Axis2 : Axis1;
if constexpr ((in_conti_dims == rank) && (high_axis == low_axis + 1) &&
(high_axis != (rank - 1))) {
if constexpr ((in_conti_dims == rank) && (Axis2 == Axis1 + 1) &&
(Axis2 != (rank - 1))) {
auto pout = output.buffer().data();
auto count = input.shape().length();
constexpr auto in_strides =
std::array<size_t, sizeof...(InStrides)>{InStrides...};
constexpr auto v_shape =
std::array<size_t, sizeof...(InElemDims)>{InElemDims...};
ntt::u_unpack_2d_fixed<in_strides[low_axis], v_shape[0],
in_strides[high_axis], v_shape[1], TIn,
ntt::u_unpack_2d_fixed<in_strides[Axis1], v_shape[0],
in_strides[Axis2], v_shape[1], TIn,
typename TOut::element_type, Axis1, Axis2>(
input, 1, pout, count);
} else {
Expand Down
37 changes: 37 additions & 0 deletions ntt/test/ctest/test_ntt_unpack.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -389,6 +389,43 @@ TEST(UnpackTestFloat, fixed_shape_dim_N_W) {
EXPECT_TRUE(NttTest::compare_tensor(ntt_output1, ntt_output2));
}

TEST(UnpackTestFloat, fixed_shape_dim_C_N) {
constexpr size_t P = NTT_VLEN / (sizeof(float) * 8);
constexpr size_t N = P * 2;
constexpr size_t C = P * 2;
constexpr size_t H = P;
constexpr size_t W = P;
float min_input = -10.0f;
float max_input = 10.0f;

// init
using tensor_type1 = ntt::tensor<ntt::vector<float, P, P>,
ntt::fixed_shape<N / P, C / P, H, W>>;
alignas(32) tensor_type1 ntt_input;
NttTest::init_tensor(ntt_input, min_input, max_input);

// ntt
using tensor_type2 = ntt::tensor<float, ntt::fixed_shape<N, C, H, W>>;
alignas(32) tensor_type2 ntt_output1;
ntt::unpack<1, 0>(ntt_input, ntt_output1);

// ort
auto ort_input = NttTest::ntt2ort(ntt_input);
int64_t perms[] = {0, 5, 1, 4, 2, 3};
auto tmp = ortki_Transpose(ort_input, perms, std::size(perms));
int64_t data[] = {N, C, H, W};
int64_t data_shape[] = {std::size(data)};
auto ort_type = NttTest::primitive_type2ort_type<int64_t>();
auto shape = make_tensor(reinterpret_cast<void *>(data), ort_type,
data_shape, std::size(data_shape));
auto ort_output = ortki_Reshape(tmp, shape, 0);

// compare
alignas(32) tensor_type2 ntt_output2;
NttTest::ort2ntt(ort_output, ntt_output2);
EXPECT_TRUE(NttTest::compare_tensor(ntt_output1, ntt_output2));
}

TEST(UnpackTestFloat, ranked_shape_dim_N) {
constexpr size_t P = NTT_VLEN / (sizeof(float) * 8);
constexpr size_t N = P * 2;
Expand Down

0 comments on commit bfe4e7a

Please sign in to comment.