From fee0136ea85009c64eeaac123ad659c19e62c44d Mon Sep 17 00:00:00 2001 From: guodongliang Date: Mon, 18 Nov 2024 11:58:06 +0800 Subject: [PATCH] add some opt for 2d pack --- ntt/include/nncase/ntt/arch/x86_64/ukernels.h | 134 +++++++++++++++++- ntt/include/nncase/ntt/kernels/pack.h | 7 +- ntt/include/nncase/ntt/ukernels/u_pack.h | 19 ++- .../benchmark_test/benchmark_ntt_pack.cpp | 10 +- 4 files changed, 153 insertions(+), 17 deletions(-) diff --git a/ntt/include/nncase/ntt/arch/x86_64/ukernels.h b/ntt/include/nncase/ntt/arch/x86_64/ukernels.h index c63146b6b5..189da7dbe2 100644 --- a/ntt/include/nncase/ntt/arch/x86_64/ukernels.h +++ b/ntt/include/nncase/ntt/arch/x86_64/ukernels.h @@ -146,12 +146,13 @@ class u_pack> { } }; -template -class u_pack2d, Axes...> { +template +class u_pack2d, (TIn::rank() - 2), + (TIn::rank() - 1)> { public: constexpr void operator()(const TIn &input, TOut &output) noexcept { using TVec = vector; - constexpr auto axes = std::array{Axes...}; + constexpr size_t axes[2] = {TIn::rank() - 2, TIn::rank() - 1}; constexpr auto in_rank = TIn::rank(); constexpr auto out_rank = TOut::rank(); constexpr auto lanes = TVec::shape(); @@ -160,7 +161,7 @@ class u_pack2d, Axes...> { apply(out_shape, [&](auto index) { auto out_index = slice_index(index); auto in_index = slice_index(index); - loop([&](auto i) { + loop<2>([&](auto i) { in_index[axes[i]] = in_index[axes[i]] * lanes[i]; }); auto in_ptr = @@ -174,6 +175,131 @@ class u_pack2d, Axes...> { } }; +template +class u_pack2d, Axes...> { + public: + constexpr void operator()(const TIn &input, TOut &output) noexcept { + using TVec = vector; + constexpr auto axes = std::array{Axes...}; + constexpr auto in_rank = TIn::rank(); + constexpr auto out_rank = TOut::rank(); + constexpr auto lanes = TVec::shape(); + auto out_shape = output.shape(); + + ranked_shape domain{}; + for (size_t i = 0; i < out_rank; i++) { + domain[i] = out_shape[i]; + } + ranked_shape inner_domain{}; + ranked_shape outer_domain{}; + + auto outer_index = slice_index(domain); + auto packed_index = slice_index(domain, axes[0]); + auto inner_index = + slice_index(domain, axes[1] + 1); + auto inner_size = inner_index.length(); + + ntt::apply(outer_index, [&](auto index) { + for (size_t i = 0; i < axes[0]; i++) { + inner_domain[i] = index[i]; + outer_domain[i] = index[i]; + } + for (size_t i = 0; i < packed_index[0]; i++) { + outer_domain[axes[0]] = i; + auto outer_ptr_keep = + reinterpret_cast(&output(outer_domain)); + for (size_t j = 0; j < lanes[0]; j++) { + inner_domain[axes[0]] = i * lanes[0] + j; + auto outer_ptr = outer_ptr_keep + j * lanes[0]; + + for (size_t k = 0; k < packed_index[1]; k++) { + inner_domain[axes[1]] = k * lanes[1]; + auto input_ptr = reinterpret_cast( + &input(inner_domain)); + + for (size_t l = 0; l < inner_size / lanes[1]; l++) { + __m256 row0 = _mm256_loadu_ps( + &input_ptr[0 * inner_size + l * lanes[1]]); + __m256 row1 = _mm256_loadu_ps( + &input_ptr[1 * inner_size + l * lanes[1]]); + __m256 row2 = _mm256_loadu_ps( + &input_ptr[2 * inner_size + l * lanes[1]]); + __m256 row3 = _mm256_loadu_ps( + &input_ptr[3 * inner_size + l * lanes[1]]); + __m256 row4 = _mm256_loadu_ps( + &input_ptr[4 * inner_size + l * lanes[1]]); + __m256 row5 = _mm256_loadu_ps( + &input_ptr[5 * inner_size + l * lanes[1]]); + __m256 row6 = _mm256_loadu_ps( + &input_ptr[6 * inner_size + l * lanes[1]]); + __m256 row7 = _mm256_loadu_ps( + &input_ptr[7 * inner_size + l * lanes[1]]); + + __m256 t0 = _mm256_unpacklo_ps(row0, row1); + __m256 t1 = _mm256_unpackhi_ps(row0, row1); + __m256 t2 = _mm256_unpacklo_ps(row2, row3); + __m256 t3 = _mm256_unpackhi_ps(row2, row3); + __m256 t4 = _mm256_unpacklo_ps(row4, row5); + __m256 t5 = _mm256_unpackhi_ps(row4, row5); + __m256 t6 = _mm256_unpacklo_ps(row6, row7); + __m256 t7 = _mm256_unpackhi_ps(row6, row7); + + __m256 u0 = _mm256_shuffle_ps( + t0, t2, 0x44); // 0x44 -> 01000100 + __m256 u1 = _mm256_shuffle_ps( + t0, t2, 0xEE); // 0xEE -> 11101110 + __m256 u2 = _mm256_shuffle_ps(t1, t3, 0x44); + __m256 u3 = _mm256_shuffle_ps(t1, t3, 0xEE); + __m256 u4 = _mm256_shuffle_ps(t4, t6, 0x44); + __m256 u5 = _mm256_shuffle_ps(t4, t6, 0xEE); + __m256 u6 = _mm256_shuffle_ps(t5, t7, 0x44); + __m256 u7 = _mm256_shuffle_ps(t5, t7, 0xEE); + + row0 = _mm256_permute2f128_ps( + u0, u4, 0x20); // 0x20 -> 00100000 + row1 = _mm256_permute2f128_ps(u1, u5, 0x20); + row2 = _mm256_permute2f128_ps(u2, u6, 0x20); + row3 = _mm256_permute2f128_ps(u3, u7, 0x20); + row4 = _mm256_permute2f128_ps( + u0, u4, 0x31); // 0x31 -> 00110001 + row5 = _mm256_permute2f128_ps(u1, u5, 0x31); + row6 = _mm256_permute2f128_ps(u2, u6, 0x31); + row7 = _mm256_permute2f128_ps(u3, u7, 0x31); + + _mm256_storeu_ps( + &outer_ptr[(l * lanes[0] + 0) * lanes.length()], + row0); + _mm256_storeu_ps( + &outer_ptr[(l * lanes[0] + 1) * lanes.length()], + row1); + _mm256_storeu_ps( + &outer_ptr[(l * lanes[0] + 2) * lanes.length()], + row2); + _mm256_storeu_ps( + &outer_ptr[(l * lanes[0] + 3) * lanes.length()], + row3); + _mm256_storeu_ps( + &outer_ptr[(l * lanes[0] + 4) * lanes.length()], + row4); + _mm256_storeu_ps( + &outer_ptr[(l * lanes[0] + 5) * lanes.length()], + row5); + _mm256_storeu_ps( + &outer_ptr[(l * lanes[0] + 6) * lanes.length()], + row6); + _mm256_storeu_ps( + &outer_ptr[(l * lanes[0] + 7) * lanes.length()], + row7); + } + + outer_ptr += (inner_size * lanes.length()); + } + } + } + }); + } +}; + // reduce template struct u_reduce_policy { static constexpr size_t unroll = 8; diff --git a/ntt/include/nncase/ntt/kernels/pack.h b/ntt/include/nncase/ntt/kernels/pack.h index 4f11f8df95..f742a4286e 100644 --- a/ntt/include/nncase/ntt/kernels/pack.h +++ b/ntt/include/nncase/ntt/kernels/pack.h @@ -38,11 +38,8 @@ template class pack_impl { auto conti_dims_output = contiguous_dims(output.shape(), output.strides()); - if (sizeof...(Axes) == 2 && axes[0] == in_rank - 2 && - axes[1] == in_rank - 1 && conti_dims_input >= 2 && - conti_dims_output >= 2 && - input.shape()[in_rank - 2] % lanes[0] == 0 && - input.shape()[in_rank - 1] % lanes[1] == 0) { + if (sizeof...(Axes) == 2 && conti_dims_input == in_rank && + conti_dims_output == out_rank) { ntt::u_pack2d(input, output); } else { diff --git a/ntt/include/nncase/ntt/ukernels/u_pack.h b/ntt/include/nncase/ntt/ukernels/u_pack.h index 5a8482318c..82186ac26a 100644 --- a/ntt/include/nncase/ntt/ukernels/u_pack.h +++ b/ntt/include/nncase/ntt/ukernels/u_pack.h @@ -39,7 +39,8 @@ class u_pack { } }; -template +template class u_pack2d { public: constexpr void operator()(const TIn &input, TOut &output) noexcept { @@ -88,8 +89,20 @@ template constexpr void u_pack2d(const TIn &input, TOut &output) noexcept { using TElem = typename TIn::element_type; using TVec = typename std::decay_t::element_type; - ukernels::u_pack2d impl; - impl(input, output); + constexpr auto axes = std::array{Axes...}; + constexpr auto in_rank = TIn::rank(); + + auto inner_size = 1; + for (size_t i = axes[1] + 1; i < in_rank; i++) { + inner_size *= input.shape()[i]; + } + if (inner_size != TVec::shape()[1]) { + ukernels::u_pack2d impl; + impl(input, output); + } else { + ukernels::u_pack2d impl; + impl(input, output); + } } } // namespace nncase::ntt diff --git a/ntt/test/benchmark_test/benchmark_ntt_pack.cpp b/ntt/test/benchmark_test/benchmark_ntt_pack.cpp index 3cfa7eeba1..95ce09accd 100644 --- a/ntt/test/benchmark_test/benchmark_ntt_pack.cpp +++ b/ntt/test/benchmark_test/benchmark_ntt_pack.cpp @@ -51,8 +51,8 @@ void benchmark_ntt_pack(const std::string &mode, const size_t run_size) { ntt::tensor>; - tensor_type1 ntt_input; - tensor_type2 ntt_output; + alignas(32) tensor_type1 ntt_input; + alignas(32) tensor_type2 ntt_output; NttTest::init_tensor(ntt_input, -10.f, 10.f); // warm up @@ -97,9 +97,9 @@ int main(int argc, char *argv[]) { benchmark_ntt_pack, 2, 8 * P, 2, 4, 1>("C", 2000); benchmark_ntt_pack, 2, 2, 8 * P, 8, 2>("H", 2000); benchmark_ntt_pack, 2, 2, 2, 8 * P, 3>("W", 2000); - benchmark_ntt_pack, 8 * P, 8 * P, 2, 2, 0, 1>( - "NC", 2000); - benchmark_ntt_pack, 2, 8 * P, 8 * P, 2, 1, 2>( + benchmark_ntt_pack, 8 * P, 8 * P, 2, 4, 0, 1>("NC", + 1); + benchmark_ntt_pack, 2, 8 * P, 8 * P, 8, 1, 2>( "CH", 2000); benchmark_ntt_pack, 4, 4, 8 * P, 8 * P, 2, 3>( "HW", 2000);