add some opt for 2d pack
guodongliang committed Nov 18, 2024
1 parent 94112e4 commit fee0136
Showing 4 changed files with 153 additions and 17 deletions.
134 changes: 130 additions & 4 deletions ntt/include/nncase/ntt/arch/x86_64/ukernels.h
@@ -146,12 +146,13 @@ class u_pack<M, N, MStrides, true, float, vector<float, 8>> {
}
};

-template <class TIn, class TOut, size_t... Axes>
-class u_pack2d<TIn, TOut, float, vector<float, 8, 8>, Axes...> {
+template <class TIn, class TOut>
+class u_pack2d<true, TIn, TOut, float, vector<float, 8, 8>, (TIn::rank() - 2),
+               (TIn::rank() - 1)> {
public:
constexpr void operator()(const TIn &input, TOut &output) noexcept {
using TVec = vector<float, 8, 8>;
-        constexpr auto axes = std::array<size_t, sizeof...(Axes)>{Axes...};
+        constexpr size_t axes[2] = {TIn::rank() - 2, TIn::rank() - 1};
constexpr auto in_rank = TIn::rank();
constexpr auto out_rank = TOut::rank();
constexpr auto lanes = TVec::shape();
@@ -160,7 +161,7 @@ class u_pack2d<TIn, TOut, float, vector<float, 8, 8>, Axes...> {
apply(out_shape, [&](auto index) {
auto out_index = slice_index<out_rank>(index);
auto in_index = slice_index<in_rank>(index);
-            loop<axes.size()>([&](auto i) {
+            loop<2>([&](auto i) {
in_index[axes[i]] = in_index[axes[i]] * lanes[i];
});
auto in_ptr =
@@ -174,6 +175,131 @@ class u_pack2d<TIn, TOut, float, vector<float, 8, 8>, Axes...> {
}
};

template <class TIn, class TOut, size_t... Axes>
class u_pack2d<true, TIn, TOut, float, vector<float, 8, 8>, Axes...> {
public:
constexpr void operator()(const TIn &input, TOut &output) noexcept {
using TVec = vector<float, 8, 8>;
constexpr auto axes = std::array<size_t, sizeof...(Axes)>{Axes...};
constexpr auto in_rank = TIn::rank();
constexpr auto out_rank = TOut::rank();
constexpr auto lanes = TVec::shape();
auto out_shape = output.shape();

ranked_shape<out_rank> domain{};
for (size_t i = 0; i < out_rank; i++) {
domain[i] = out_shape[i];
}
ranked_shape<in_rank> inner_domain{};
ranked_shape<in_rank> outer_domain{};

auto outer_index = slice_index<axes[0]>(domain);
auto packed_index = slice_index<sizeof...(Axes)>(domain, axes[0]);
auto inner_index =
slice_index<out_rank - (axes[1] + 1)>(domain, axes[1] + 1);
auto inner_size = inner_index.length();

ntt::apply(outer_index, [&](auto index) {
for (size_t i = 0; i < axes[0]; i++) {
inner_domain[i] = index[i];
outer_domain[i] = index[i];
}
for (size_t i = 0; i < packed_index[0]; i++) {
outer_domain[axes[0]] = i;
auto outer_ptr_keep =
reinterpret_cast<float *>(&output(outer_domain));
for (size_t j = 0; j < lanes[0]; j++) {
inner_domain[axes[0]] = i * lanes[0] + j;
auto outer_ptr = outer_ptr_keep + j * lanes[0];

for (size_t k = 0; k < packed_index[1]; k++) {
inner_domain[axes[1]] = k * lanes[1];
auto input_ptr = reinterpret_cast<const float *>(
&input(inner_domain));

for (size_t l = 0; l < inner_size / lanes[1]; l++) {
__m256 row0 = _mm256_loadu_ps(
&input_ptr[0 * inner_size + l * lanes[1]]);
__m256 row1 = _mm256_loadu_ps(
&input_ptr[1 * inner_size + l * lanes[1]]);
__m256 row2 = _mm256_loadu_ps(
&input_ptr[2 * inner_size + l * lanes[1]]);
__m256 row3 = _mm256_loadu_ps(
&input_ptr[3 * inner_size + l * lanes[1]]);
__m256 row4 = _mm256_loadu_ps(
&input_ptr[4 * inner_size + l * lanes[1]]);
__m256 row5 = _mm256_loadu_ps(
&input_ptr[5 * inner_size + l * lanes[1]]);
__m256 row6 = _mm256_loadu_ps(
&input_ptr[6 * inner_size + l * lanes[1]]);
__m256 row7 = _mm256_loadu_ps(
&input_ptr[7 * inner_size + l * lanes[1]]);

__m256 t0 = _mm256_unpacklo_ps(row0, row1);
__m256 t1 = _mm256_unpackhi_ps(row0, row1);
__m256 t2 = _mm256_unpacklo_ps(row2, row3);
__m256 t3 = _mm256_unpackhi_ps(row2, row3);
__m256 t4 = _mm256_unpacklo_ps(row4, row5);
__m256 t5 = _mm256_unpackhi_ps(row4, row5);
__m256 t6 = _mm256_unpacklo_ps(row6, row7);
__m256 t7 = _mm256_unpackhi_ps(row6, row7);

__m256 u0 = _mm256_shuffle_ps(
t0, t2, 0x44); // 0x44 -> 01000100
__m256 u1 = _mm256_shuffle_ps(
t0, t2, 0xEE); // 0xEE -> 11101110
__m256 u2 = _mm256_shuffle_ps(t1, t3, 0x44);
__m256 u3 = _mm256_shuffle_ps(t1, t3, 0xEE);
__m256 u4 = _mm256_shuffle_ps(t4, t6, 0x44);
__m256 u5 = _mm256_shuffle_ps(t4, t6, 0xEE);
__m256 u6 = _mm256_shuffle_ps(t5, t7, 0x44);
__m256 u7 = _mm256_shuffle_ps(t5, t7, 0xEE);

row0 = _mm256_permute2f128_ps(
u0, u4, 0x20); // 0x20 -> 00100000
row1 = _mm256_permute2f128_ps(u1, u5, 0x20);
row2 = _mm256_permute2f128_ps(u2, u6, 0x20);
row3 = _mm256_permute2f128_ps(u3, u7, 0x20);
row4 = _mm256_permute2f128_ps(
u0, u4, 0x31); // 0x31 -> 00110001
row5 = _mm256_permute2f128_ps(u1, u5, 0x31);
row6 = _mm256_permute2f128_ps(u2, u6, 0x31);
row7 = _mm256_permute2f128_ps(u3, u7, 0x31);

_mm256_storeu_ps(
&outer_ptr[(l * lanes[0] + 0) * lanes.length()],
row0);
_mm256_storeu_ps(
&outer_ptr[(l * lanes[0] + 1) * lanes.length()],
row1);
_mm256_storeu_ps(
&outer_ptr[(l * lanes[0] + 2) * lanes.length()],
row2);
_mm256_storeu_ps(
&outer_ptr[(l * lanes[0] + 3) * lanes.length()],
row3);
_mm256_storeu_ps(
&outer_ptr[(l * lanes[0] + 4) * lanes.length()],
row4);
_mm256_storeu_ps(
&outer_ptr[(l * lanes[0] + 5) * lanes.length()],
row5);
_mm256_storeu_ps(
&outer_ptr[(l * lanes[0] + 6) * lanes.length()],
row6);
_mm256_storeu_ps(
&outer_ptr[(l * lanes[0] + 7) * lanes.length()],
row7);
}

outer_ptr += (inner_size * lanes.length());
}
}
}
});
}
};

// reduce
template <reduce_op Op, class T> struct u_reduce_policy<Op, T, true> {
static constexpr size_t unroll = 8;
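
The bulk of the ukernels.h addition is the classic three-stage AVX 8x8 float transpose: _mm256_unpacklo_ps/_mm256_unpackhi_ps interleave adjacent rows at 32-bit granularity, _mm256_shuffle_ps gathers 64-bit pairs within each 128-bit lane, and _mm256_permute2f128_ps swaps the lanes so each register ends up holding one column. A minimal standalone sketch of the same pattern (the helper name is illustrative, not part of the commit):

    #include <immintrin.h>

    // Sketch: transpose an 8x8 float block held in eight __m256 rows, using
    // the same unpack / shuffle / permute2f128 sequence as the kernel above.
    static inline void transpose_8x8_ps(__m256 r[8]) {
        __m256 t[8], u[8];
        // Stage 1: interleave row pairs at 32-bit granularity.
        for (int i = 0; i < 4; i++) {
            t[2 * i + 0] = _mm256_unpacklo_ps(r[2 * i], r[2 * i + 1]);
            t[2 * i + 1] = _mm256_unpackhi_ps(r[2 * i], r[2 * i + 1]);
        }
        // Stage 2: gather 64-bit pairs; 0x44 keeps low halves, 0xEE high.
        for (int i = 0; i < 2; i++) {
            u[4 * i + 0] = _mm256_shuffle_ps(t[4 * i + 0], t[4 * i + 2], 0x44);
            u[4 * i + 1] = _mm256_shuffle_ps(t[4 * i + 0], t[4 * i + 2], 0xEE);
            u[4 * i + 2] = _mm256_shuffle_ps(t[4 * i + 1], t[4 * i + 3], 0x44);
            u[4 * i + 3] = _mm256_shuffle_ps(t[4 * i + 1], t[4 * i + 3], 0xEE);
        }
        // Stage 3: 0x20 pairs the low 128-bit lanes, 0x31 the high ones.
        for (int i = 0; i < 4; i++) {
            r[i + 0] = _mm256_permute2f128_ps(u[i], u[i + 4], 0x20);
            r[i + 4] = _mm256_permute2f128_ps(u[i], u[i + 4], 0x31);
        }
    }

Eight loads, eight stores, and 24 shuffle-class operations move a whole 8x8 tile, which is why this path pays off over gathering elements one at a time.
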
7 changes: 2 additions & 5 deletions ntt/include/nncase/ntt/kernels/pack.h
@@ -38,11 +38,8 @@ template <class TIn, class TOut, size_t... Axes> class pack_impl {
auto conti_dims_output =
contiguous_dims(output.shape(), output.strides());

-        if (sizeof...(Axes) == 2 && axes[0] == in_rank - 2 &&
-            axes[1] == in_rank - 1 && conti_dims_input >= 2 &&
-            conti_dims_output >= 2 &&
-            input.shape()[in_rank - 2] % lanes[0] == 0 &&
-            input.shape()[in_rank - 1] % lanes[1] == 0) {
+        if (sizeof...(Axes) == 2 && conti_dims_input == in_rank &&
+            conti_dims_output == out_rank) {
ntt::u_pack2d<TIn, TOut, Axes...>(input, output);

} else {
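
With this change pack_impl no longer pre-checks the axis positions or lane divisibility; any two-axis pack over fully contiguous tensors is handed to ntt::u_pack2d, which now does its own fast-path selection (see u_pack.h below). A sketch of the contiguity test being relied on, assuming contiguous_dims counts how many trailing dimensions are dense (this helper is an assumption, not the library's implementation):

    #include <array>
    #include <cstddef>

    // Presumed semantics: count trailing dims whose stride equals the
    // product of the extents to their right; count == rank means the
    // whole tensor is dense, which is what the new condition requires.
    template <std::size_t Rank>
    constexpr std::size_t
    contiguous_dims_sketch(const std::array<std::size_t, Rank> &shape,
                           const std::array<std::size_t, Rank> &strides) {
        std::size_t expected = 1, count = 0;
        for (std::size_t i = Rank; i-- > 0;) {
            if (strides[i] != expected)
                break;
            expected *= shape[i];
            count++;
        }
        return count;
    }

    static_assert(contiguous_dims_sketch<3>({2, 3, 4}, {12, 4, 1}) == 3); // dense
    static_assert(contiguous_dims_sketch<3>({2, 3, 4}, {24, 4, 1}) == 2); // padded
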
19 changes: 16 additions & 3 deletions ntt/include/nncase/ntt/ukernels/u_pack.h
@@ -39,7 +39,8 @@ class u_pack {
}
};

-template <class TIn, class TOut, class TElem, class TVec, size_t... Axes>
+template <bool Arch, class TIn, class TOut, class TElem, class TVec,
+          size_t... Axes>
class u_pack2d {
public:
constexpr void operator()(const TIn &input, TOut &output) noexcept {
@@ -88,8 +89,20 @@ template <class TIn, class TOut, size_t... Axes>
constexpr void u_pack2d(const TIn &input, TOut &output) noexcept {
using TElem = typename TIn::element_type;
using TVec = typename std::decay_t<TOut>::element_type;
-    ukernels::u_pack2d<TIn, TOut, TElem, TVec, Axes...> impl;
-    impl(input, output);
+    constexpr auto axes = std::array<size_t, sizeof...(Axes)>{Axes...};
+    constexpr auto in_rank = TIn::rank();
+
+    auto inner_size = 1;
+    for (size_t i = axes[1] + 1; i < in_rank; i++) {
+        inner_size *= input.shape()[i];
+    }
+    if (inner_size != TVec::shape()[1]) {
+        ukernels::u_pack2d<false, TIn, TOut, TElem, TVec, Axes...> impl;
+        impl(input, output);
+    } else {
+        ukernels::u_pack2d<true, TIn, TOut, TElem, TVec, Axes...> impl;
+        impl(input, output);
+    }
}

} // namespace nncase::ntt
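
The new dispatch selects the Arch=true kernel only when the elements to the right of the second packed axis fill exactly one SIMD vector of TVec::shape()[1] floats (8 for vector<float, 8, 8>), so each 8-float AVX load in the fast path reads one whole step along that axis. A worked check against the benchmark shapes below, with P = 8 (hypothetical helper, not part of the commit):

    #include <array>
    #include <cstddef>

    // Sketch of the inner_size rule above: fast path iff everything right
    // of the second packed axis spans exactly one vector of lanes1 floats.
    template <std::size_t Rank>
    constexpr bool fast_pack2d_ok(const std::array<std::size_t, Rank> &shape,
                                  std::size_t axis1, std::size_t lanes1) {
        std::size_t inner_size = 1;
        for (std::size_t i = axis1 + 1; i < Rank; i++)
            inner_size *= shape[i];
        return inner_size == lanes1;
    }

    // "NC" benchmark below: shape <8*P, 8*P, 2, 4>, axes {0, 1}: 2 * 4 == 8.
    static_assert(fast_pack2d_ok<4>({64, 64, 2, 4}, 1, 8));
    // "CH" benchmark below: shape <2, 8*P, 8*P, 8>, axes {1, 2}: 8 == 8.
    static_assert(fast_pack2d_ok<4>({2, 64, 64, 8}, 2, 8));
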
10 changes: 5 additions & 5 deletions ntt/test/benchmark_test/benchmark_ntt_pack.cpp
@@ -51,8 +51,8 @@ void benchmark_ntt_pack(const std::string &mode, const size_t run_size) {
ntt::tensor<ElementType,
ntt::fixed_shape<N / P0, C / P1, H / P2, W / P3>>;

-    tensor_type1 ntt_input;
-    tensor_type2 ntt_output;
+    alignas(32) tensor_type1 ntt_input;
+    alignas(32) tensor_type2 ntt_output;
NttTest::init_tensor(ntt_input, -10.f, 10.f);

// warm up
@@ -97,9 +97,9 @@ int main(int argc, char *argv[]) {
benchmark_ntt_pack<ntt::vector<float, P>, 2, 8 * P, 2, 4, 1>("C", 2000);
benchmark_ntt_pack<ntt::vector<float, P>, 2, 2, 8 * P, 8, 2>("H", 2000);
benchmark_ntt_pack<ntt::vector<float, P>, 2, 2, 2, 8 * P, 3>("W", 2000);
-    benchmark_ntt_pack<ntt::vector<float, P, P>, 8 * P, 8 * P, 2, 2, 0, 1>(
-        "NC", 2000);
-    benchmark_ntt_pack<ntt::vector<float, P, P>, 2, 8 * P, 8 * P, 2, 1, 2>(
+    benchmark_ntt_pack<ntt::vector<float, P, P>, 8 * P, 8 * P, 2, 4, 0, 1>("NC",
+                                                                           1);
+    benchmark_ntt_pack<ntt::vector<float, P, P>, 2, 8 * P, 8 * P, 8, 1, 2>(
"CH", 2000);
benchmark_ntt_pack<ntt::vector<float, P, P>, 4, 4, 8 * P, 8 * P, 2, 3>(
"HW", 2000);
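
The benchmark now also aligns its tensors: alignas(32) places the stack buffers on a 32-byte boundary, so every 8-float AVX access in the kernel stays within a single cache line. _mm256_loadu_ps accepts unaligned pointers, but split loads cost extra cycles and would skew the measurement. A minimal illustration (hypothetical buffer, not from the commit):

    #include <immintrin.h>

    // alignas(32) guarantees a 32-byte boundary, so an 8-float load never
    // straddles two cache lines even though loadu itself would permit it.
    alignas(32) static float buf[64];

    __m256 first_row() { return _mm256_loadu_ps(&buf[0]); }
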
