From 0b886aa892a0c59f2d36e6bb61706f6fdcb65753 Mon Sep 17 00:00:00 2001 From: zhangyang2057 Date: Fri, 22 Nov 2024 09:11:26 +0800 Subject: [PATCH] Optimize unpack 2D. --- .../nncase/ntt/arch/riscv64/ukernels.h | 126 ++++++++++++++++++ ntt/include/nncase/ntt/kernels/unpack.h | 7 +- ntt/test/benchmark_test/benchmark_ntt.py | 2 +- 3 files changed, 131 insertions(+), 4 deletions(-) diff --git a/ntt/include/nncase/ntt/arch/riscv64/ukernels.h b/ntt/include/nncase/ntt/arch/riscv64/ukernels.h index 845707f79..02e9c3d00 100644 --- a/ntt/include/nncase/ntt/arch/riscv64/ukernels.h +++ b/ntt/include/nncase/ntt/arch/riscv64/ukernels.h @@ -666,6 +666,7 @@ struct u_unpack_1d_fixed class u_unpack_2d_fixed +class u_unpack_2d_fixed { + public: + void operator()(const T1 &input, size_t in_stride, float *output, + size_t count) noexcept { + constexpr size_t vl = NTT_VLEN / 32; + auto in_ptr = input.buffer().data(); + // using policy_t = + // u_unpack_policy, float, + // true>; + // constexpr auto unroll = policy_t::unroll; + constexpr auto unroll1 = 2; + constexpr auto unroll2 = 4; + size_t in_offset = 0; + constexpr auto out_low_strides = low_stride * vl; + constexpr auto low_extra = low_stride * (vl * vl - 1); + constexpr auto high_extra = high_stride * (vl - 1); + auto in_strides = sizeof(vector); + auto out_strides = high_stride * sizeof(float); + constexpr auto high_dim = low_stride / high_stride; + asm("vsetvli zero, %[vl], e32, m1\n" ::[vl] "r"(vl)); + while (count / unroll1) { + auto low_idx1 = in_offset / low_stride; + auto high_idx1 = in_offset / high_stride % high_dim; + auto out_ptr1 = output + in_offset + low_idx1 * low_extra + + high_idx1 * high_extra; + auto out_ptr2 = out_ptr1 + 1; + auto tmp = vl; + size_t i_idx = 0; + auto input1 = in_ptr; + auto input2 = in_ptr + in_stride; + while (tmp / unroll2) { + auto output1_1 = out_ptr1 + i_idx * out_low_strides; + asm volatile("vl1re32.v v1, (%[input1])\n" + "add %[input1], %[input1], %[in_strides]\n" + : [input1] "+r"(input1) + : [in_strides] "r"(in_strides)); + auto output1_2 = output1_1 + out_low_strides; + + asm volatile("vl1re32.v v2, (%[input1])\n" + "add %[input1], %[input1], %[in_strides]\n" + : [input1] "+r"(input1) + : [in_strides] "r"(in_strides)); + auto output1_3 = output1_2 + out_low_strides; + + asm volatile("vl1re32.v v3, (%[input1])\n" + "add %[input1], %[input1], %[in_strides]\n" + : [input1] "+r"(input1) + : [in_strides] "r"(in_strides)); + auto output1_4 = output1_3 + out_low_strides; + + asm volatile("vl1re32.v v4, (%[input1])\n" + "add %[input1], %[input1], %[in_strides]\n" + : [input1] "+r"(input1) + : [in_strides] "r"(in_strides)); + auto output2_1 = out_ptr2 + i_idx * out_low_strides; + + asm volatile("vl1re32.v v5, (%[input2])\n" + "add %[input2], %[input2], %[in_strides]\n" + : [input2] "+r"(input2) + : [in_strides] "r"(in_strides)); + auto output2_2 = output2_1 + out_low_strides; + + asm volatile("vl1re32.v v6, (%[input2])\n" + "add %[input2], %[input2], %[in_strides]\n" + : [input2] "+r"(input2) + : [in_strides] "r"(in_strides)); + auto output2_3 = output2_2 + out_low_strides; + + asm volatile("vl1re32.v v7, (%[input2])\n" + "add %[input2], %[input2], %[in_strides]\n" + : [input2] "+r"(input2) + : [in_strides] "r"(in_strides)); + auto output2_4 = output2_3 + out_low_strides; + + asm volatile("vl1re32.v v8, (%[input2])\n" + "add %[input2], %[input2], %[in_strides]\n" + : [input2] "+r"(input2) + : [in_strides] "r"(in_strides)); + tmp -= unroll2; + + asm volatile("vsse32.v v1, (%[output1_1]), %[out_strides]\n" + : [output1_1] "+r"(output1_1) + : [out_strides] "r"(out_strides)); + i_idx += unroll2; + + asm volatile("vsse32.v v2, (%[output1_2]), %[out_strides]\n" + : [output1_2] "+r"(output1_2) + : [out_strides] "r"(out_strides)); + + asm volatile("vsse32.v v3, (%[output1_3]), %[out_strides]\n" + : [output1_3] "+r"(output1_3) + : [out_strides] "r"(out_strides)); + + asm volatile("vsse32.v v4, (%[output1_4]), %[out_strides]\n" + : [output1_4] "+r"(output1_4) + : [out_strides] "r"(out_strides)); + + asm volatile("vsse32.v v5, (%[output2_1]), %[out_strides]\n" + : [output2_1] "+r"(output2_1) + : [out_strides] "r"(out_strides)); + + asm volatile("vsse32.v v6, (%[output2_2]), %[out_strides]\n" + : [output2_2] "+r"(output2_2) + : [out_strides] "r"(out_strides)); + + asm volatile("vsse32.v v7, (%[output2_3]), %[out_strides]\n" + : [output2_3] "+r"(output2_3) + : [out_strides] "r"(out_strides)); + + asm volatile("vsse32.v v8, (%[output2_4]), %[out_strides]\n" + : [output2_4] "+r"(output2_4) + : [out_strides] "r"(out_strides)); + } + count -= unroll1; + in_ptr += in_stride * unroll1; + in_offset += unroll1; + } + } +}; + +#endif template <> struct u_unpack_1d_ranked, float, diff --git a/ntt/include/nncase/ntt/kernels/unpack.h b/ntt/include/nncase/ntt/kernels/unpack.h index b687dc38f..d06fcdebd 100644 --- a/ntt/include/nncase/ntt/kernels/unpack.h +++ b/ntt/include/nncase/ntt/kernels/unpack.h @@ -80,7 +80,8 @@ class unpack_impl, fixed_shape, OutShape, fixed_shape{}, fixed_strides{}); constexpr auto low_axis = Axis1 < Axis2 ? Axis1 : Axis2; constexpr auto high_axis = Axis1 < Axis2 ? Axis2 : Axis1; - if constexpr ((in_conti_dims == rank) && (high_axis == low_axis + 1)) { + if constexpr ((in_conti_dims == rank) && (high_axis == low_axis + 1) && + (high_axis != (rank - 1))) { auto pout = output.buffer().data(); auto count = input.shape().length(); constexpr auto in_strides = @@ -100,8 +101,8 @@ class unpack_impl, fixed_shape, OutShape, auto elem_index = slice_index(index, rank); auto out_index = slice_index(index); loop([&](auto i) { - out_index[axes[i]] = - out_index[axes[i]] * TVec::shape()[i] + index[rank + i]; + out_index[axes[i]] = + out_index[axes[i]] * TVec::shape()[i] + index[rank + i]; }); output(out_index) = input(in_index)(elem_index); }); diff --git a/ntt/test/benchmark_test/benchmark_ntt.py b/ntt/test/benchmark_test/benchmark_ntt.py index c3e8de772..6287bdad5 100644 --- a/ntt/test/benchmark_test/benchmark_ntt.py +++ b/ntt/test/benchmark_test/benchmark_ntt.py @@ -395,7 +395,7 @@ def __init__(self, target: str, bin_path: str): 'W': '4.3', 'NC': '6', 'CH': '6', - 'HW': '4.3', + 'HW': '6', }, }