Skip to content

Commit

Permalink
Optimize unpack 2D.
Browse files Browse the repository at this point in the history
  • Loading branch information
zhangyang2057 committed Nov 22, 2024
1 parent d399281 commit 0b886aa
Show file tree
Hide file tree
Showing 3 changed files with 131 additions and 4 deletions.
126 changes: 126 additions & 0 deletions ntt/include/nncase/ntt/arch/riscv64/ukernels.h
Original file line number Diff line number Diff line change
Expand Up @@ -666,6 +666,7 @@ struct u_unpack_1d_fixed<axis_stride, NTT_VLEN / 32, T1, float, true,
}
};

#if 0
template <size_t low_stride, size_t high_stride, class T1, size_t PackAxis1,
size_t PackAxis2>
class u_unpack_2d_fixed<low_stride, NTT_VLEN / 32, high_stride, NTT_VLEN / 32,
Expand Down Expand Up @@ -760,6 +761,131 @@ class u_unpack_2d_fixed<low_stride, NTT_VLEN / 32, high_stride, NTT_VLEN / 32,
}
}
};
#else
template <size_t low_stride, size_t high_stride, class T1, size_t PackAxis1,
size_t PackAxis2>
class u_unpack_2d_fixed<low_stride, NTT_VLEN / 32, high_stride, NTT_VLEN / 32,
T1, float, true, PackAxis1, PackAxis2> {
public:
void operator()(const T1 &input, size_t in_stride, float *output,
size_t count) noexcept {
constexpr size_t vl = NTT_VLEN / 32;
auto in_ptr = input.buffer().data();
// using policy_t =
// u_unpack_policy<vector<float, 4, 4>, float,
// true>;
// constexpr auto unroll = policy_t::unroll;
constexpr auto unroll1 = 2;
constexpr auto unroll2 = 4;
size_t in_offset = 0;
constexpr auto out_low_strides = low_stride * vl;
constexpr auto low_extra = low_stride * (vl * vl - 1);
constexpr auto high_extra = high_stride * (vl - 1);
auto in_strides = sizeof(vector<float, vl>);
auto out_strides = high_stride * sizeof(float);
constexpr auto high_dim = low_stride / high_stride;
asm("vsetvli zero, %[vl], e32, m1\n" ::[vl] "r"(vl));
while (count / unroll1) {
auto low_idx1 = in_offset / low_stride;
auto high_idx1 = in_offset / high_stride % high_dim;
auto out_ptr1 = output + in_offset + low_idx1 * low_extra +
high_idx1 * high_extra;
auto out_ptr2 = out_ptr1 + 1;
auto tmp = vl;
size_t i_idx = 0;
auto input1 = in_ptr;
auto input2 = in_ptr + in_stride;
while (tmp / unroll2) {
auto output1_1 = out_ptr1 + i_idx * out_low_strides;
asm volatile("vl1re32.v v1, (%[input1])\n"
"add %[input1], %[input1], %[in_strides]\n"
: [input1] "+r"(input1)
: [in_strides] "r"(in_strides));
auto output1_2 = output1_1 + out_low_strides;

asm volatile("vl1re32.v v2, (%[input1])\n"
"add %[input1], %[input1], %[in_strides]\n"
: [input1] "+r"(input1)
: [in_strides] "r"(in_strides));
auto output1_3 = output1_2 + out_low_strides;

asm volatile("vl1re32.v v3, (%[input1])\n"
"add %[input1], %[input1], %[in_strides]\n"
: [input1] "+r"(input1)
: [in_strides] "r"(in_strides));
auto output1_4 = output1_3 + out_low_strides;

asm volatile("vl1re32.v v4, (%[input1])\n"
"add %[input1], %[input1], %[in_strides]\n"
: [input1] "+r"(input1)
: [in_strides] "r"(in_strides));
auto output2_1 = out_ptr2 + i_idx * out_low_strides;

asm volatile("vl1re32.v v5, (%[input2])\n"
"add %[input2], %[input2], %[in_strides]\n"
: [input2] "+r"(input2)
: [in_strides] "r"(in_strides));
auto output2_2 = output2_1 + out_low_strides;

asm volatile("vl1re32.v v6, (%[input2])\n"
"add %[input2], %[input2], %[in_strides]\n"
: [input2] "+r"(input2)
: [in_strides] "r"(in_strides));
auto output2_3 = output2_2 + out_low_strides;

asm volatile("vl1re32.v v7, (%[input2])\n"
"add %[input2], %[input2], %[in_strides]\n"
: [input2] "+r"(input2)
: [in_strides] "r"(in_strides));
auto output2_4 = output2_3 + out_low_strides;

asm volatile("vl1re32.v v8, (%[input2])\n"
"add %[input2], %[input2], %[in_strides]\n"
: [input2] "+r"(input2)
: [in_strides] "r"(in_strides));
tmp -= unroll2;

asm volatile("vsse32.v v1, (%[output1_1]), %[out_strides]\n"
: [output1_1] "+r"(output1_1)
: [out_strides] "r"(out_strides));
i_idx += unroll2;

asm volatile("vsse32.v v2, (%[output1_2]), %[out_strides]\n"
: [output1_2] "+r"(output1_2)
: [out_strides] "r"(out_strides));

asm volatile("vsse32.v v3, (%[output1_3]), %[out_strides]\n"
: [output1_3] "+r"(output1_3)
: [out_strides] "r"(out_strides));

asm volatile("vsse32.v v4, (%[output1_4]), %[out_strides]\n"
: [output1_4] "+r"(output1_4)
: [out_strides] "r"(out_strides));

asm volatile("vsse32.v v5, (%[output2_1]), %[out_strides]\n"
: [output2_1] "+r"(output2_1)
: [out_strides] "r"(out_strides));

asm volatile("vsse32.v v6, (%[output2_2]), %[out_strides]\n"
: [output2_2] "+r"(output2_2)
: [out_strides] "r"(out_strides));

asm volatile("vsse32.v v7, (%[output2_3]), %[out_strides]\n"
: [output2_3] "+r"(output2_3)
: [out_strides] "r"(out_strides));

asm volatile("vsse32.v v8, (%[output2_4]), %[out_strides]\n"
: [output2_4] "+r"(output2_4)
: [out_strides] "r"(out_strides));
}
count -= unroll1;
in_ptr += in_stride * unroll1;
in_offset += unroll1;
}
}
};

#endif

template <>
struct u_unpack_1d_ranked<NTT_VLEN / 32, vector<float, NTT_VLEN / 32>, float,
Expand Down
7 changes: 4 additions & 3 deletions ntt/include/nncase/ntt/kernels/unpack.h
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,8 @@ class unpack_impl<fixed_shape<InDims...>, fixed_shape<InElemDims...>, OutShape,
fixed_shape<InDims...>{}, fixed_strides<InStrides...>{});
constexpr auto low_axis = Axis1 < Axis2 ? Axis1 : Axis2;
constexpr auto high_axis = Axis1 < Axis2 ? Axis2 : Axis1;
if constexpr ((in_conti_dims == rank) && (high_axis == low_axis + 1)) {
if constexpr ((in_conti_dims == rank) && (high_axis == low_axis + 1) &&
(high_axis != (rank - 1))) {
auto pout = output.buffer().data();
auto count = input.shape().length();
constexpr auto in_strides =
Expand All @@ -100,8 +101,8 @@ class unpack_impl<fixed_shape<InDims...>, fixed_shape<InElemDims...>, OutShape,
auto elem_index = slice_index<elem_rank>(index, rank);
auto out_index = slice_index<rank>(index);
loop<axes.size()>([&](auto i) {
out_index[axes[i]] =
out_index[axes[i]] * TVec::shape()[i] + index[rank + i];
out_index[axes[i]] =
out_index[axes[i]] * TVec::shape()[i] + index[rank + i];
});
output(out_index) = input(in_index)(elem_index);
});
Expand Down
2 changes: 1 addition & 1 deletion ntt/test/benchmark_test/benchmark_ntt.py
Original file line number Diff line number Diff line change
Expand Up @@ -395,7 +395,7 @@ def __init__(self, target: str, bin_path: str):
'W': '4.3',
'NC': '6',
'CH': '6',
'HW': '4.3',
'HW': '6',
},
}

Expand Down

0 comments on commit 0b886aa

Please sign in to comment.