Commit 46c948c: support matmul transpose b

zhen8838 committed Sep 20, 2024
1 parent 98ca0ef commit 46c948c
Showing 11 changed files with 176 additions and 113 deletions.
2 changes: 1 addition & 1 deletion modules/Nncase.Modules.CPU/Evaluator/CPU/PackedMatMul.cs
@@ -25,7 +25,7 @@ public IValue Visit(IEvaluateContext context, PackedMatMul target)
var outShape = Array.Empty<int>();
var axes = Array.Empty<int>();
var (lm, lk) = target.TransposeA ? (lhs.Rank - target.RhsPackedAxes.Count - 1, lhs.Rank - target.RhsPackedAxes.Count - 2) : (lhs.Rank - target.LhsPackedAxes.Count - 2, lhs.Rank - target.LhsPackedAxes.Count - 1);
- var (rk, rn) = target.TransposeB ? (rhs.Rank - target.RhsPackedAxes.Count - 1, rhs.Rank - target.RhsPackedAxes.Count - 2) : (rhs.Rank - target.LhsPackedAxes.Count - 2, rhs.Rank - target.LhsPackedAxes.Count - 1);
+ var (rk, rn) = target.TransposeB ? (rhs.Rank - target.RhsPackedAxes.Count - 1, rhs.Rank - target.RhsPackedAxes.Count - 2) : (rhs.Rank - target.RhsPackedAxes.Count - 2, rhs.Rank - target.RhsPackedAxes.Count - 1);
if (target.LhsPackedAxes.Count == 0 && target.RhsPackedAxes.Count == 1)
{
outLanes = new[] { (int)rhs.Shape[^1] };
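For intuition, here is the corrected right-hand-side index math in standalone form: with P packed axes, B's last two unpacked dims sit at rank - P - 2 and rank - P - 1, and TransposeB only swaps which of them is K and which is N. A minimal sketch (hypothetical helper, not part of the commit; the bug was offsetting the non-transposed case by LhsPackedAxes.Count instead of RhsPackedAxes.Count):

```cpp
#include <cstddef>
#include <utility>

// Hypothetical helper mirroring the fixed line above: returns (rk, rn),
// the dims of B that hold K and N.
std::pair<size_t, size_t> rhs_kn_dims(size_t rank, size_t packedAxes,
                                      bool transposeB) {
    size_t prev = rank - packedAxes - 2; // next-to-innermost unpacked dim
    size_t last = rank - packedAxes - 1; // innermost unpacked dim
    return transposeB ? std::make_pair(last, prev)  // B is [N, K]
                      : std::make_pair(prev, last); // B is [K, N]
}
```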
10 changes: 5 additions & 5 deletions modules/Nncase.Modules.CPU/Passes/Rules/CPU/PackRule.cs
@@ -406,23 +406,23 @@ void AddCandidate(PackKind lhsPack, PackKind rhsPack, bool transA = false, bool
AddCandidate(PackKind.K, PackKind.K);

// only pack A's m
// AddCandidate(new[] { lhsShape.Length - 2 }, Array.Empty<int>(), new[] { Lane }, Array.Empty<int>());
AddCandidate(PackKind.M, PackKind.None);

// only pack B's n
AddCandidate(PackKind.None, PackKind.N, transB: rhs is Const);
if (Rank > 1)
{
// pack A's m and B's n, when B is const, force transpose
// AddCandidate(new[] { lhsShape.Length - 2 }, new[] { rhsShape.Length - 1 }, new[] { Lane }, new[] { Lane }, transB: rhs is Const);
AddCandidate(PackKind.M, PackKind.N, transB: rhs is Const);

// pack A's m,k and B's k,n
// AddCandidate(new[] { lhsShape.Length - 2, lhsShape.Length - 1 }, new[] { rhsShape.Length - 2, rhsShape.Length - 1 }, new[] { Lane, Lane }, new[] { Lane, Lane }, transB: rhs is Const);
AddCandidate(PackKind.M | PackKind.K, PackKind.K | PackKind.N, transB: rhs is Const);

// pack A's m,k and B's k
// AddCandidate(new[] { lhsShape.Length - 2, lhsShape.Length - 1 }, new[] { rhsShape.Length - 2 }, new[] { Lane, Lane }, new[] { Lane });
AddCandidate(PackKind.M | PackKind.K, PackKind.K);

// pack A's k and B's k,n
// AddCandidate(new[] { lhsShape.Length - 1 }, new[] { rhsShape.Length - 2, rhsShape.Length - 1 }, new[] { Lane }, new[] { Lane, Lane });
AddCandidate(PackKind.K, PackKind.K | PackKind.N);
}

return rets;
43 changes: 43 additions & 0 deletions src/Native/include/nncase/ntt/arch/aarch64/primitive_ops.h
@@ -125,4 +125,47 @@ template <> struct outer_product<ntt::vector<float, 4>, ntt::vector<float, 4>> {
return result;
}
};

template <bool Acc>
struct mma<Acc, true, ntt::vector<float, 4, 4>, ntt::vector<float, 4, 4>,
ntt::vector<float, 4, 4>> {
ntt::vector<float, 4, 4>
operator()(const ntt::vector<float, 4, 4> &lhs,
const ntt::vector<float, 4, 4> &rhs,
const ntt::vector<float, 4, 4> &out) const noexcept {
ntt::vector<float, 4, 4> ret;

// c,n,m,lane => c = c + (m[lane] * n)
if (Acc){
ret(0) = vfmaq_laneq_f32(out(0), rhs(0), lhs(0), 0); // k = 0
ret(1) = vfmaq_laneq_f32(out(1), rhs(0), lhs(0), 1);
ret(2) = vfmaq_laneq_f32(out(2), rhs(0), lhs(0), 2);
ret(3) = vfmaq_laneq_f32(out(3), rhs(0), lhs(0), 3);
} else {
vector<float, 4> zero = vdupq_n_f32(0.f);
ret(0) = vfmaq_laneq_f32(zero, rhs(0), lhs(0), 0); // k = 0
ret(1) = vfmaq_laneq_f32(zero, rhs(0), lhs(0), 1);
ret(2) = vfmaq_laneq_f32(zero, rhs(0), lhs(0), 2);
ret(3) = vfmaq_laneq_f32(zero, rhs(0), lhs(0), 3);
}

ret(0) = vfmaq_laneq_f32(ret(0), rhs(1), lhs(1), 0); // k = 1
ret(1) = vfmaq_laneq_f32(ret(1), rhs(1), lhs(1), 1);
ret(2) = vfmaq_laneq_f32(ret(2), rhs(1), lhs(1), 2);
ret(3) = vfmaq_laneq_f32(ret(3), rhs(1), lhs(1), 3);

ret(0) = vfmaq_laneq_f32(ret(0), rhs(2), lhs(2), 0); // k = 2
ret(1) = vfmaq_laneq_f32(ret(1), rhs(2), lhs(2), 1);
ret(2) = vfmaq_laneq_f32(ret(2), rhs(2), lhs(2), 2);
ret(3) = vfmaq_laneq_f32(ret(3), rhs(2), lhs(2), 3);

ret(0) = vfmaq_laneq_f32(ret(0), rhs(3), lhs(3), 0); // k = 3
ret(1) = vfmaq_laneq_f32(ret(1), rhs(3), lhs(3), 1);
ret(2) = vfmaq_laneq_f32(ret(2), rhs(3), lhs(3), 2);
ret(3) = vfmaq_laneq_f32(ret(3), rhs(3), lhs(3), 3);

return ret;
}
};

} // namespace nncase::ntt::ops
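For reference, vfmaq_laneq_f32(a, b, v, lane) computes a + b * v[lane], so with TransA the lhs tile is stored K-major (<k,m>) and lane m of lhs(k) scales the whole row rhs(k). A scalar model of what the specialization above computes (plain C++ sketch, not the committed kernel):

```cpp
#include <cstddef>

// Scalar model of mma<Acc, true> above: lhs is <k, m>, rhs is <k, n>,
// out/ret are <m, n>, i.e. ret = (Acc ? out : 0) + lhs^T * rhs.
// Each inner += mirrors one vfmaq_laneq_f32(ret(m), rhs(k), lhs(k), m).
template <bool Acc>
void mma_4x4_transA_ref(const float lhs[4][4], const float rhs[4][4],
                        const float out[4][4], float ret[4][4]) {
    for (size_t m = 0; m < 4; ++m)
        for (size_t n = 0; n < 4; ++n) {
            float acc = Acc ? out[m][n] : 0.0f;
            for (size_t k = 0; k < 4; ++k)
                acc += lhs[k][m] * rhs[k][n];
            ret[m][n] = acc;
        }
}
```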
2 changes: 1 addition & 1 deletion src/Native/include/nncase/ntt/arch/x86_64/primitive_ops.h
@@ -377,7 +377,7 @@ template <> struct max<ntt::vector<float, 8>, ntt::vector<float, 8>> {
};

template <bool AccC>
- struct mma<AccC, ntt::vector<float, 8, 8>, ntt::vector<float, 8, 8>,
+ struct mma<AccC, false, ntt::vector<float, 8, 8>, ntt::vector<float, 8, 8>,
ntt::vector<float, 8, 8>> {
ntt::vector<float, 8, 8>
operator()(const ntt::vector<float, 8, 8> &lhs,
23 changes: 18 additions & 5 deletions src/Native/include/nncase/ntt/kernels/matmul.h
@@ -89,16 +89,29 @@ class matmul_impl<false, true, AccumulateC, TLhs, TRhs, TOut, LhsPackedAxes,
LhsPackedAxes::at(0) == TLhs::rank() - 2 &&
RhsPackedAxes::rank() == 1 &&
RhsPackedAxes::at(0) == TRhs::rank() - 2) {
static_assert(LhsPackedAxes::rank() != 1, "not support!");
auto value = ntt::plane_outer_product(lhs, rhs, K);
// for (size_t k = 1; k < K; k++) {
// auto value = ntt::outer_product(lhs, rhs);
// output = AccC ? output + value : value;
// }
}
- // 3.3. pack MK & KN
+ // 3.3. pack [M,K]<m,k> & [N,K]<k,n>
else if constexpr (LhsPackedAxes::rank() == 2 &&
LhsPackedAxes::at(0) == TLhs::rank() - 2 &&
LhsPackedAxes::at(1) == TLhs::rank() - 1 &&
RhsPackedAxes::rank() == 2 &&
RhsPackedAxes::at(0) == TRhs::rank() - 1 &&
RhsPackedAxes::at(1) == TRhs::rank() - 2) {
static_assert(LhsPackedAxes::rank() != 2, "not support!");
output = ntt::mma<AccC, false>(*lhs++, *rhs++, output);
}
// 3.3. pack [M,K]<k,m> & [N,K]<k,n>
else if constexpr (LhsPackedAxes::rank() == 2 &&
LhsPackedAxes::at(0) == TLhs::rank() - 1 &&
LhsPackedAxes::at(1) == TLhs::rank() - 2 &&
RhsPackedAxes::rank() == 2 &&
RhsPackedAxes::at(0) == TRhs::rank() - 1 &&
RhsPackedAxes::at(1) == TRhs::rank() - 2) {
output = ntt::mma<AccC, true>(*lhs++, *rhs++, output);
}
// fall back
else {
@@ -246,7 +259,7 @@ class matmul_impl<false, false, AccumulateC, TLhs, TRhs, TOut, LhsPackedAxes,
{lhs}};
fixed_tensor_alike_t<TOutElem, 1, TOutElem::shape().at(0)>
output_2d{{output}};
- output_2d = ntt::mma<AccC>(lhs_2d, rhs, output_2d);
+ output_2d = ntt::mma<AccC, false>(lhs_2d, rhs, output_2d);
output = output_2d(0);
}
// 3.3. pack MK & KN
@@ -256,7 +269,7 @@
RhsPackedAxes::rank() == 2 &&
RhsPackedAxes::at(0) == TRhs::rank() - 2 &&
RhsPackedAxes::at(1) == TRhs::rank() - 1) {
- output = ntt::mma<AccC>(lhs, rhs, output);
+ output = ntt::mma<AccC, false>(lhs, rhs, output);
} else {
static_assert(sizeof(TLhsElem) == 0, "Unsupported packing.");
}
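The two new branches differ only in A's inner (lane) layout: <m,k> tiles dispatch to mma<AccC, false>, <k,m> tiles to mma<AccC, true>, while B's tile is <k,n> in both. A small self-contained check of that equivalence, with plain arrays standing in for the packed tiles (hypothetical illustration, not the committed code):

```cpp
#include <cassert>
#include <cstddef>

// Hypothetical scalar check: multiplying with A stored row-major (<m,k>,
// the mma<AccC, false> pattern) or K-major (<k,m>, the mma<AccC, true>
// pattern) yields the same C = A * B.
int main() {
    constexpr size_t M = 2, K = 3, N = 2;
    float a_mk[M][K] = {{1, 2, 3}, {4, 5, 6}};
    float b_kn[K][N] = {{7, 8}, {9, 10}, {11, 12}};
    float a_km[K][M];
    for (size_t k = 0; k < K; ++k)
        for (size_t m = 0; m < M; ++m)
            a_km[k][m] = a_mk[m][k];

    float c1[M][N] = {}, c2[M][N] = {};
    for (size_t m = 0; m < M; ++m) // row-major A: dot-product walk
        for (size_t k = 0; k < K; ++k)
            for (size_t n = 0; n < N; ++n)
                c1[m][n] += a_mk[m][k] * b_kn[k][n];
    for (size_t k = 0; k < K; ++k) // K-major A: outer-product walk
        for (size_t m = 0; m < M; ++m)
            for (size_t n = 0; n < N; ++n)
                c2[m][n] += a_km[k][m] * b_kn[k][n];

    for (size_t m = 0; m < M; ++m)
        for (size_t n = 0; n < N; ++n)
            assert(c1[m][n] == c2[m][n]);
    return 0;
}
```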
49 changes: 37 additions & 12 deletions src/Native/include/nncase/ntt/primitive_ops.h
@@ -183,6 +183,11 @@ template <class T1, class T2> struct outer_product {
}
};

template <class T1, class T2> struct plane_outer_product {
constexpr auto operator()(const T1 &v1, const T2 &v2,
const size_t length) const noexcept;
};

/**
* @remarks mod is equivalent to fmod() function in C/C++/Python.
*/
@@ -232,7 +237,8 @@ template <class T1, class T2, class TResult> struct mul_add {
const TResult &v3) const noexcept;
};

- template <bool AccC, IsFixedTensor T1, IsFixedTensor T2, IsFixedTensor TResult>
+ template <bool AccC, bool TransA, IsFixedTensor T1, IsFixedTensor T2,
+ IsFixedTensor TResult>
struct mma {
constexpr TResult operator()(const T1 &v1, const T2 &v2,
const TResult &v3) const noexcept;
@@ -288,6 +294,11 @@ NTT_DEFINE_BINARY_FUNC_IMPL(div)
NTT_DEFINE_BINARY_FUNC_IMPL(floor_mod)
NTT_DEFINE_BINARY_FUNC_IMPL(inner_product)
NTT_DEFINE_BINARY_FUNC_IMPL(outer_product)
template <IsTensorOrScalar T1, IsTensorOrScalar T2>
constexpr auto plane_outer_product(const T1 &v1, const T2 &v2,
const size_t length) noexcept {
return ops::plane_outer_product<T1, T2>()(v1, v2, length);
}
NTT_DEFINE_BINARY_FUNC_IMPL(mod)
NTT_DEFINE_BINARY_FUNC_IMPL(min)
NTT_DEFINE_BINARY_FUNC_IMPL(max)
@@ -304,9 +315,10 @@ constexpr TResult mul_add(const T1 &v1, const T2 &v2,
return ops::mul_add<T1, T2, TResult>()(v1, v2, v3);
}

- template <bool AccC, IsFixedTensor T1, IsFixedTensor T2, IsFixedTensor TResult>
+ template <bool AccC, bool TransA, IsFixedTensor T1, IsFixedTensor T2,
+ IsFixedTensor TResult>
constexpr TResult mma(const T1 &v1, const T2 &v2, const TResult &v3) noexcept {
- return ops::mma<AccC, T1, T2, TResult>()(v1, v2, v3);
+ return ops::mma<AccC, TransA, T1, T2, TResult>()(v1, v2, v3);
}

/**
@@ -416,19 +428,32 @@ mul_add<T1, T2, TResult>::operator()(const T1 &v1, const T2 &v2,
return v1 * v2 + v3;
}

- template <bool AccC, IsFixedTensor T1, IsFixedTensor T2, IsFixedTensor TResult>
- constexpr TResult
- mma<AccC, T1, T2, TResult>::operator()(const T1 &lhs, const T2 &rhs,
- const TResult &v3) const noexcept {
+ template <bool AccC, bool TransA, IsFixedTensor T1, IsFixedTensor T2,
+ IsFixedTensor TResult>
+ constexpr TResult mma<AccC, TransA, T1, T2, TResult>::operator()(
+ const T1 &lhs, const T2 &rhs, const TResult &v3) const noexcept {
static_assert(T1::rank() == T2::rank() && T2::rank() == TResult::rank() &&
TResult::rank() == 2,
"only support 2d mma");
TResult output = v3;
- for (size_t m = 0; m < T1::shape().at(0); m++) {
- for (size_t k = 0; k < T2::shape().at(0); k++) {
- output(m) = (k != 0 || AccC)
- ? ntt::mul_add(lhs(m, k), rhs(k), output(m))
- : ntt::mul(lhs(m, k), rhs(k));
+ if constexpr (TransA) {
+ // <k,m> @ <k,n>
+ if constexpr (AccC) {
+ output = ntt::outer_product(lhs(0), rhs(0)) + output;
+ } else {
+ output = ntt::outer_product(lhs(0), rhs(0));
+ }
+
+ for (size_t k = 1; k < T1::shape().at(0); k++) {
+ output = ntt::outer_product(lhs(k), rhs(k)) + output;
+ }
+ } else {
+ for (size_t m = 0; m < T1::shape().at(0); m++) {
+ for (size_t k = 0; k < T2::shape().at(0); k++) {
+ output(m) = (k != 0 || AccC)
+ ? ntt::mul_add(lhs(m, k), rhs(k), output(m))
+ : ntt::mul(lhs(m, k), rhs(k));
+ }
+ }
+ }
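The TransA branch above is the rank-1 update form of matmul: with lhs stored <k,m> and rhs stored <k,n>, each k contributes one outer product,

$$C \mathrel{+}= \sum_{k} \operatorname{lhs}(k) \otimes \operatorname{rhs}(k) = A^{\mathsf{T}} B,$$

so each iteration is a single ntt::outer_product call (the k = 0 term seeds the output when AccC is false), rather than the per-row scalar broadcasts of the non-transposed path.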

51 changes: 50 additions & 1 deletion src/Native/src/test.cpp
@@ -959,14 +959,63 @@ int main() {
ntt::matmul<false, false, true>(
ta, packb, tc2, ntt::fixed_shape<>{}, ntt::fixed_shape<>{},
ntt::fixed_shape<0>{}, ntt::fixed_shape<>{});

ntt::tensor<float, ntt::fixed_shape<8, 8>> tc2unpack;
ntt::unpack<1>(tc2, tc2unpack);

ntt::apply(tc.shape(), [&]([[maybe_unused]] auto index) {
assert(tc2unpack(index) == tc(index));
});
}

// A[m,k]<m,k> @ B[n,k]<k,n>
{
ntt::tensor<ntt::vector<float, 4, 4>, ntt::fixed_shape<2, 1>>
packb;
ntt::pack<1, 0>(tranb, packb); // [n,k]<k,n>
ntt::tensor<ntt::vector<float, 4, 4>, ntt::fixed_shape<2, 1>>
packa;
// note: actually A should be packed as [m,k]<k,m>
ntt::pack<0, 1>(ta, packa); // [m,k]<m,k>
// [m,n]<m,n>
ntt::tensor<ntt::vector<float, 4, 4>, ntt::fixed_shape<2, 2>>
tc2;
ntt::matmul<false, false, true>(
packa, packb, tc2, ntt::fixed_shape<0, 1>{},
ntt::fixed_shape<>{}, ntt::fixed_shape<1, 0>{},
ntt::fixed_shape<>{});

ntt::tensor<float, ntt::fixed_shape<8, 8>> tc2unpack;
ntt::unpack<0, 1>(tc2, tc2unpack);

ntt::apply(tc.shape(), [&]([[maybe_unused]] auto index) {
assert(tc2unpack(index) == tc(index));
});
}

// A[m,k]<k,m> @ B[n,k]<k,n>
{
ntt::tensor<ntt::vector<float, 4, 4>, ntt::fixed_shape<2, 1>>
packb;
ntt::pack<1, 0>(tranb, packb); // [n,k]<k,n>
ntt::tensor<ntt::vector<float, 4, 4>, ntt::fixed_shape<2, 1>>
packa;
ntt::pack<1, 0>(ta, packa); // [m,k]<k,m>
// [m,n]<m,n>
ntt::tensor<ntt::vector<float, 4, 4>, ntt::fixed_shape<2, 2>>
tc2;
ntt::matmul<false, false, true>(
packa, packb, tc2, ntt::fixed_shape<1, 0>{},
ntt::fixed_shape<>{}, ntt::fixed_shape<1, 0>{},
ntt::fixed_shape<>{});

ntt::tensor<float, ntt::fixed_shape<8, 8>> tc2unpack;
ntt::unpack<0, 1>(tc2, tc2unpack);

ntt::apply(tc.shape(), [&]([[maybe_unused]] auto index) {
assert(tc2unpack(index) == tc(index));
});
}
}
}
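To make the layout notation in these tests concrete, here is a scalar sketch of what ntt::pack<1, 0> does to the [m,k] = [8, 4] A above (shape inferred from packa's [2, 1] outer shape), assuming the pack axis list orders the lane dimensions so that <k,m> means lane dim 0 indexes K and lane dim 1 indexes M. This is an illustration of the layout, not the library's implementation:

```cpp
#include <cstddef>

// Hypothetical scalar model of ntt::pack<1, 0> on an [M, K] = [8, 4] matrix
// with 4x4 lanes: outer shape [M/4, K/4] = [2, 1], each tile stored <k, m>,
// i.e. tiles[i][j][k][m] = src[i*4 + m][j*4 + k].
void pack_a_km(const float src[8][4], float tiles[2][1][4][4]) {
    for (size_t i = 0; i < 2; ++i)             // tile index along M
        for (size_t j = 0; j < 1; ++j)         // tile index along K
            for (size_t k = 0; k < 4; ++k)     // lane dim 0 <- axis 1 (K)
                for (size_t m = 0; m < 4; ++m) // lane dim 1 <- axis 0 (M)
                    tiles[i][j][k][m] = src[i * 4 + m][j * 4 + k];
}
```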
