From 6c0b6a967d7e45a58e04ca43c422f0e700726ef2 Mon Sep 17 00:00:00 2001 From: kqyhappy <2037278892@qq.com> Date: Sun, 22 Sep 2024 09:17:49 +0000 Subject: [PATCH] fix(kernel): fix matmul kernel --- .../WebAssembly/InternalKernel/Fp32MatMulM4N12.cpp | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/compiler/lib/KernelGen/WebAssembly/InternalKernel/Fp32MatMulM4N12.cpp b/compiler/lib/KernelGen/WebAssembly/InternalKernel/Fp32MatMulM4N12.cpp index f0936f74..c17c1205 100644 --- a/compiler/lib/KernelGen/WebAssembly/InternalKernel/Fp32MatMulM4N12.cpp +++ b/compiler/lib/KernelGen/WebAssembly/InternalKernel/Fp32MatMulM4N12.cpp @@ -51,7 +51,7 @@ static inline void interleave_helper( static inline void interleave_1( const float* inptr0, float* outptr, int unroll_k, int ksize, float val) { for (int k = 0; k < ksize; k += unroll_k) { - int size = min(unroll_k, ksize - k); + int size = unroll_k > (ksize - k)? (ksize - k) : unroll_k; interleave_helper(inptr0, outptr, unroll_k, size, val); inptr0 += size;outptr+=unroll_k; } @@ -61,7 +61,7 @@ static inline void interleave_4( const float* inptr0, const float* inptr1, const float* inptr2, const float* inptr3, float* outptr, int unroll_k, int ksize, float val) { for (int k = 0; k < ksize; k += unroll_k) { - int size = min(unroll_k, ksize - k); + int size = unroll_k > (ksize - k)? (ksize - k) : unroll_k; interleave_helper(inptr0, outptr, unroll_k, size, val); inptr0 += size;outptr+=unroll_k; interleave_helper(inptr1, outptr, unroll_k, size, val); @@ -413,7 +413,7 @@ static std::string kern_4x4(TContext* crx) { } std::string pack_A_n(const std::string kern_sym, TContext* ctx) { - return "void" + kern_sym + "_packa_n" + + return "void " + kern_sym + "_packa_n" + WebAssemblyMatmulInternal::GenPackACall(ctx) + R"({ float zerobuff[4]; @@ -586,15 +586,15 @@ std::string gen_kernel( const float* cur_pack_b = pack_b; for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) { kern_4x12(pack_a, cur_pack_b, K, output, LDC, - min(M - m, 4), bias_ptr); + (M - m) > 4 ? 4 : (M - m), bias_ptr); output += B_INTERLEAVE; cur_pack_b += K12; } for (; n < N; n += 4) { kern_4x4(pack_a, cur_pack_b, K, output, LDC, - min(M - m, 4), - min(N - n, 4), bias_ptr); + (M - m) > 4 ? 4 : (M - m), + (N - n) > 4 ? 4 : (N - n), bias_ptr); output += 4; cur_pack_b += K4; }