Skip to content

Commit

Permalink
add 2d unpack opt
Browse files Browse the repository at this point in the history
  • Loading branch information
guodongliang committed Nov 22, 2024
1 parent dd98ff4 commit 806a1d4
Showing 1 changed file with 123 additions and 5 deletions.
128 changes: 123 additions & 5 deletions ntt/include/nncase/ntt/arch/x86_64/ukernels.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
#pragma once
#include "../../ukernels.h"
#include "nncase/ntt/vector.h"
#include <iostream>

namespace nncase::ntt::ukernels {

Expand Down Expand Up @@ -137,11 +138,10 @@ class u_pack<M, N, MStrides, true, float, vector<float, 8>> {
};

template <class TIn, class TOut, size_t... Axes>
requires(sizeof...(Axes) > 0 &&
(std::get<sizeof...(Axes) - 1>(std::array<size_t, sizeof...(Axes)>{
Axes...}) ==
(TIn::rank() - 1))) class u_pack2d<true, TIn, TOut, float,
vector<float, 8, 8>, Axes...> {
requires(sizeof...(Axes) > 0 &&
(std::get<sizeof...(Axes) - 1>(std::array<size_t, sizeof...(Axes)>{
Axes...}) == (TIn::rank() - 1)))
class u_pack2d<true, TIn, TOut, float, vector<float, 8, 8>, Axes...> {
public:
constexpr void operator()(const TIn &input, TOut &output) noexcept {
using TVec = vector<float, 8, 8>;
Expand Down Expand Up @@ -426,6 +426,124 @@ class u_unpack_1d_fixed<axis_stride, 8, T1, float, true, PackAxis> {
}
};

template <size_t low_axis_stride, size_t high_axis_stride, class TIn,
size_t Axis1, size_t Axis2>
class u_unpack_2d_fixed<low_axis_stride, 8, high_axis_stride, 8, TIn, float,
true, Axis1, Axis2> {
public:
void operator()(const TIn &input, size_t input_stride, float *output,
size_t count) noexcept {
using TVec = vector<float, 8, 8>;
constexpr auto axes = std::array<size_t, 2>{Axis1, Axis2};
constexpr auto in_rank = TIn::rank();
auto in_shape = input.shape();

ranked_shape<in_rank> domain{};
for (size_t i = 0; i < in_rank; i++) {
domain[i] = in_shape[i];
}
ranked_shape<in_rank> inner_domain{};

auto packed_index = slice_index<2>(domain, axes[0]);
auto inner_index =
slice_index<in_rank - (axes[1] + 1)>(domain, axes[1] + 1);
auto inner_size = inner_index.length();

ranked_shape<Axis2> tile_domain{};
for (size_t i = 0; i < Axis2; i++) {
tile_domain[i] = in_shape[i];
}

auto dst = output;
if (inner_size % TVec::shape()[1] != 0) {
ukernels::u_unpack_2d_fixed<low_axis_stride, 8, high_axis_stride, 8,
TIn, float, false, Axis1, Axis2>
impl;
impl(input, input_stride, output, count);
} else {
ntt::apply(tile_domain, [&](auto index) {
for (size_t i = 0; i < Axis2; i++) {
inner_domain[i] = index[i];
}
auto src =
reinterpret_cast<const float *>(&input(inner_domain));
dst =
output + linear_offset(inner_domain, input.strides()) * 64;
for (size_t i = 0; i < 8; i++) {
auto st_offset_i = i * packed_index[1] * 8 * inner_size;
for (size_t j = 0; j < packed_index[1]; j++) {
auto st_offset_j = j * inner_size * 8;
auto ld_offset_j = src + j * inner_size * 64;
for (size_t k = 0; k < inner_size / 8; k++) {
auto st_offset = st_offset_i + st_offset_j + k * 8;
auto ld_offset = ld_offset_j + k * 512;
__m256 row0 = _mm256_load_ps(ld_offset + 0 * 64);
__m256 row1 = _mm256_load_ps(ld_offset + 1 * 64);
__m256 row2 = _mm256_load_ps(ld_offset + 2 * 64);
__m256 row3 = _mm256_load_ps(ld_offset + 3 * 64);
__m256 row4 = _mm256_load_ps(ld_offset + 4 * 64);
__m256 row5 = _mm256_load_ps(ld_offset + 5 * 64);
__m256 row6 = _mm256_load_ps(ld_offset + 6 * 64);
__m256 row7 = _mm256_load_ps(ld_offset + 7 * 64);

__m256 t0 = _mm256_unpacklo_ps(row0, row1);
__m256 t1 = _mm256_unpackhi_ps(row0, row1);
__m256 t2 = _mm256_unpacklo_ps(row2, row3);
__m256 t3 = _mm256_unpackhi_ps(row2, row3);
__m256 t4 = _mm256_unpacklo_ps(row4, row5);
__m256 t5 = _mm256_unpackhi_ps(row4, row5);
__m256 t6 = _mm256_unpacklo_ps(row6, row7);
__m256 t7 = _mm256_unpackhi_ps(row6, row7);

__m256 u0 = _mm256_shuffle_ps(
t0, t2, 0x44); // 0x44 -> 01000100
__m256 u1 = _mm256_shuffle_ps(
t0, t2, 0xEE); // 0xEE -> 11101110
__m256 u2 = _mm256_shuffle_ps(t1, t3, 0x44);
__m256 u3 = _mm256_shuffle_ps(t1, t3, 0xEE);
__m256 u4 = _mm256_shuffle_ps(t4, t6, 0x44);
__m256 u5 = _mm256_shuffle_ps(t4, t6, 0xEE);
__m256 u6 = _mm256_shuffle_ps(t5, t7, 0x44);
__m256 u7 = _mm256_shuffle_ps(t5, t7, 0xEE);

row0 = _mm256_permute2f128_ps(
u0, u4,
0x20); // 0x20 -> 00100000
row1 = _mm256_permute2f128_ps(u1, u5, 0x20);
row2 = _mm256_permute2f128_ps(u2, u6, 0x20);
row3 = _mm256_permute2f128_ps(u3, u7, 0x20);
row4 = _mm256_permute2f128_ps(
u0, u4,
0x31); // 0x31 -> 00110001
row5 = _mm256_permute2f128_ps(u1, u5, 0x31);
row6 = _mm256_permute2f128_ps(u2, u6, 0x31);
row7 = _mm256_permute2f128_ps(u3, u7, 0x31);

_mm256_store_ps(&dst[0 * inner_size + st_offset],
row0);
_mm256_store_ps(&dst[1 * inner_size + st_offset],
row1);
_mm256_store_ps(&dst[2 * inner_size + st_offset],
row2);
_mm256_store_ps(&dst[3 * inner_size + st_offset],
row3);
_mm256_store_ps(&dst[4 * inner_size + st_offset],
row4);
_mm256_store_ps(&dst[5 * inner_size + st_offset],
row5);
_mm256_store_ps(&dst[6 * inner_size + st_offset],
row6);
_mm256_store_ps(&dst[7 * inner_size + st_offset],
row7);
}
}
src = src + 8;
}
});
}
}
};

// reduce
template <reduce_op Op, class T> struct u_reduce_policy<Op, T, true> {
static constexpr size_t unroll = 8;
Expand Down

0 comments on commit 806a1d4

Please sign in to comment.