diff --git a/compiler/lib/KernelGen/Arm/Arm64/ConvKernel/Fp32/Fp32Conv1x1Mk4M8N12.cpp b/compiler/lib/KernelGen/Arm/Arm64/ConvKernel/Fp32/Fp32Conv1x1Mk4M8N12.cpp index ccae049c..be75e71e 100644 --- a/compiler/lib/KernelGen/Arm/Arm64/ConvKernel/Fp32/Fp32Conv1x1Mk4M8N12.cpp +++ b/compiler/lib/KernelGen/Arm/Arm64/ConvKernel/Fp32/Fp32Conv1x1Mk4M8N12.cpp @@ -35,7 +35,8 @@ bool Conv1x1FloatMk4::IsAvailable(TContext* ctx) const { ctx->getAttrOprand("operand:2").dtype == "f32"; bool layout_ok = ctx->getAttrOprand("operand:0").shape.size() == 5 && ctx->getAttrOprand("operand:0").shape[4] == 4; - bool bias_ok = !is_bias(ctx) || is_channel_broadcast_bias(ctx); + bool bias_ok = + !is_bias(ctx) || is_channel_broadcast_bias(ctx) || is_elemwise_bias(ctx); return param_value_ok && param_mode_ok && type_ok && noline_ok && layout_ok && bias_ok; } @@ -43,7 +44,12 @@ bool Conv1x1FloatMk4::IsAvailable(TContext* ctx) const { std::string Conv1x1FloatMk4::GetKernelSymbol(TContext* ctx) const { std::stringstream extra_ss; if (is_bias(ctx)) { - extra_ss << "_bias"; + if (is_channel_broadcast_bias(ctx)) + extra_ss << "_channel_broadcast_bias"; + else if (is_elemwise_bias(ctx)) + extra_ss << "_elemwise_bias"; + else + CC_ABORT << "only support channel broadcast and elemwise bias mode."; } if (ctx->haveAttr("nonlineMode") && ctx->getAttrStr("nonlineMode") != "IDENTITY") { extra_ss << "_" << ctx->getAttrStr("nonlineMode"); @@ -163,7 +169,13 @@ std::shared_ptr Conv1x1FloatMk4::GetInnerCtx(TContext* ctx) const { if (ctx->haveAttr("nonlineMode")) { inner_ctx->setAttr("nonlineMode", CCAttr(ctx->getAttrStr("nonlineMode"))); } - inner_ctx->setAttr("with_bias", ConvImpl::is_bias(ctx)); + std::string bias_mode = "NO_BIAS"; + if (is_channel_broadcast_bias(ctx)) { + bias_mode = "CHANNEL_BROADCAST_BIAS"; + } else if (is_elemwise_bias(ctx)) { + bias_mode = "ELEMWISE_BIAS"; + } + inner_ctx->setAttr("bias_mode", bias_mode); inner_ctx->setAttr("transposeA", false); inner_ctx->setAttr("transposeB", false); inner_ctx->setAttr("format", "MK4"); @@ -177,9 +189,22 @@ std::string Conv1x1FloatMk4::GetKernelBody(TContext* ctx) const { writer << m_inner_gemm.GetNakedKernelSignature(inner_ctx.get()) << ";\n"; writer << m_inner_gemm.GetPackBSignature(inner_ctx.get()) << ";\n"; writer << GenCommonRet() << " " << GetKernelSignature(ctx); - std::string bias_ptr_str = is_bias(ctx) ? "inputs[2]->ptr;" : "0;"; + std::string elemwise_bias_init, channel_broadcast_bias_init; + if (is_channel_broadcast_bias(ctx)) { + channel_broadcast_bias_init = "bias_data = inputs[2]->ptr;"; + } else if (is_elemwise_bias(ctx)) { + elemwise_bias_init = "bias_data = inputs[2]->ptr;"; + } + std::string bias_offset = "0"; + if (is_channel_broadcast_bias(ctx)) { + bias_offset = "ocpg"; + } else if (is_elemwise_bias(ctx)) { + bias_offset = "ocpg * out_h * out_w"; + } writer << StringTemplate::StringTemplateArgs() - .add("bias_ptr_str", bias_ptr_str) + .add("elemwise_bias_init", elemwise_bias_init) + .add("channel_broadcast_bias_init", channel_broadcast_bias_init) + .add("bias_offset", bias_offset) .add("packB", m_inner_gemm.GetPackBSymbol(inner_ctx.get())) .add("kern", m_inner_gemm.GetNakedKernelSymbol(inner_ctx.get())) .render(R"({ @@ -215,16 +240,18 @@ std::string Conv1x1FloatMk4::GetKernelBody(TContext* ctx) const { const int LDB = in_h * in_w * PACK_C_SIZE; void* workspace_ptr = workspace->ptr; + const float* bias_data = NULL; + ${elemwise_bias_init} for (int n_idx = 0; n_idx < in_n; ++n_idx) { float* weight_data = inputs[1]->ptr; - float* bias_data = ${bias_ptr_str}; + ${channel_broadcast_bias_init} for(int group_idx = 0; group_idx < group; ++group_idx){ ${packB}(workspace_ptr, input_data, LDB, 0, in_h * in_w, 0, icpg); ${kern}(weight_data, workspace_ptr, output_data, LDC, ocpg, N, icpg, bias_data); input_data += icpg * in_h * in_w; output_data += ocpg * out_h * out_w; weight_data += weight_layout.stride[0]; - bias_data += ocpg; + bias_data += ${bias_offset}; } } return TinyNN_SUCCESS; diff --git a/compiler/lib/KernelGen/Arm/Arm64/ConvKernel/Fp32/Fp32Im2col.cpp b/compiler/lib/KernelGen/Arm/Arm64/ConvKernel/Fp32/Fp32Im2col.cpp index 2bd0a055..f053db11 100644 --- a/compiler/lib/KernelGen/Arm/Arm64/ConvKernel/Fp32/Fp32Im2col.cpp +++ b/compiler/lib/KernelGen/Arm/Arm64/ConvKernel/Fp32/Fp32Im2col.cpp @@ -261,6 +261,10 @@ std::shared_ptr ConvIm2colFloat::GetInnerCtx(TContext* ctx) const { inner_ctx->setAttr("nonlineMode", CCAttr(ctx->getAttrStr("nonlineMode"))); } inner_ctx->setAttr("with_bias", ConvImpl::is_bias(ctx)); + inner_ctx->setAttr( + "bias_mode", ConvImpl::is_channel_broadcast_bias(ctx) + ? "CHANNEL_BROADCAST_BIAS" + : "NO_BIAS"); inner_ctx->setAttr("transposeA", false); inner_ctx->setAttr("transposeB", false); inner_ctx->setAttr("dtype", "f32"); diff --git a/compiler/lib/KernelGen/Arm/Arm64/InternalKernel/Fp32M8N12K4Matmul.cpp b/compiler/lib/KernelGen/Arm/Arm64/InternalKernel/Fp32M8N12K4Matmul.cpp index d9acc974..26f7b9c1 100644 --- a/compiler/lib/KernelGen/Arm/Arm64/InternalKernel/Fp32M8N12K4Matmul.cpp +++ b/compiler/lib/KernelGen/Arm/Arm64/InternalKernel/Fp32M8N12K4Matmul.cpp @@ -76,7 +76,6 @@ static std::string kern_8x12(TContext* ctx) { auto nonline_mode = ctx->haveAttr("nonlineMode") ? ctx->getAttrStr("nonlineMode") : "IDENTITY"; auto activation_gen = create_activation_gener(nonline_mode); - bool with_bias = ctx->getAttrBool("with_bias"); std::stringstream writer; //! kern_8x12 @@ -115,15 +114,17 @@ static std::string kern_8x12(TContext* ctx) { const float* b_ptr = packB; float* output0 = output; float* output1 = output0 + LDC; + const float* bias_ptr0 = bias_ptr; + const float* bias_ptr1 = bias_ptr0 + LDC; int oddk = (K & 1); K = ((K + 1) / 2) - 1; asm volatile()"; //! if convolution with bias - if (with_bias) { + if (ctx->getAttrStr("bias_mode") == "CHANNEL_BROADCAST_BIAS") { writer << R"( - "ld1 {v6.4s}, [%[bias_ptr]], #16\n" + "ld1 {v6.4s}, [%[bias_ptr0]], #16\n" "mov v8.16b, v6.16b \n" "mov v9.16b, v6.16b \n" "mov v10.16b, v6.16b \n" @@ -136,7 +137,7 @@ static std::string kern_8x12(TContext* ctx) { "mov v15.16b, v6.16b \n" "ld1 {v2.4s}, [%[b_ptr]], #16\n" "mov v16.16b, v6.16b \n" - "ld1 {v7.4s}, [%[bias_ptr]], #16\n" + "ld1 {v7.4s}, [%[bias_ptr0]], #16\n" "mov v17.16b, v6.16b \n" "mov v18.16b, v6.16b \n" "mov v19.16b, v6.16b \n" @@ -156,7 +157,38 @@ static std::string kern_8x12(TContext* ctx) { "mov v30.16b, v7.16b \n" "mov v31.16b, v7.16b \n")"; - //! if convolution without bias + } else if (ctx->getAttrStr("bias_mode") == "ELEMWISE_BIAS") { + writer << R"( + "ld1 {v8.4s}, [%[bias_ptr0]], #16\n" + "ld1 {v9.4s}, [%[bias_ptr0]], #16\n" + "ld1 {v10.4s}, [%[bias_ptr0]], #16\n" + "prfm pstl1keep, [%[output0]]\n" + "ld1 {v11.4s}, [%[bias_ptr0]], #16\n" + "ld1 {v12.4s}, [%[bias_ptr0]], #16\n" + "ld1 {v13.4s}, [%[bias_ptr0]], #16\n" + "prfm pstl1keep, [%[output1]]\n" + "ld1 {v14.4s}, [%[bias_ptr0]], #16\n" + "ld1 {v15.4s}, [%[bias_ptr0]], #16\n" + "ld1 {v2.4s}, [%[b_ptr]], #16\n" + "ld1 {v16.4s}, [%[bias_ptr0]], #16\n" + "ld1 {v17.4s}, [%[bias_ptr0]], #16\n" + "ld1 {v18.4s}, [%[bias_ptr0]], #16\n" + "ld1 {v19.4s}, [%[bias_ptr0]], #16\n" + "ld1 {v20.4s}, [%[bias_ptr1]], #16\n" + "ld1 {v3.4s}, [%[b_ptr]], #16\n" + "ld1 {v21.4s}, [%[bias_ptr1]], #16\n" + "ld1 {v22.4s}, [%[bias_ptr1]], #16\n" + "ld1 {v23.4s}, [%[bias_ptr1]], #16\n" + "ld1 {v4.4s}, [%[b_ptr]], #16\n" + "ld1 {v24.4s}, [%[bias_ptr1]], #16\n" + "ld1 {v25.4s}, [%[bias_ptr1]], #16\n" + "ld1 {v26.4s}, [%[bias_ptr1]], #16\n" + "ld1 {v27.4s}, [%[bias_ptr1]], #16\n" + "ld1 {v28.4s}, [%[bias_ptr1]], #16\n" + "ld1 {v0.4s}, [%[a_ptr]], #16\n" + "ld1 {v29.4s}, [%[bias_ptr1]], #16\n" + "ld1 {v30.4s}, [%[bias_ptr1]], #16\n" + "ld1 {v31.4s}, [%[bias_ptr1]], #16\n")"; } else { writer << R"( "eor v8.16b, v8.16b, v8.16b \n" @@ -438,7 +470,7 @@ static std::string kern_8x12(TContext* ctx) { "6:\n" : [ a_ptr ] "+r"(a_ptr), [ b_ptr ] "+r"(b_ptr), [ K ] "+r"(K), - [ bias_ptr ] "+r"(bias_ptr), [ oddk ] "+r"(oddk), + [ bias_ptr0 ] "+r"(bias_ptr0), [ bias_ptr1 ] "+r"(bias_ptr1), [ oddk ] "+r"(oddk), [ output0 ] "+r"(output0), [ output1 ] "+r"(output1) : : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", @@ -459,7 +491,6 @@ static std::string kern_4x12(TContext* ctx) { auto nonline_mode = ctx->haveAttr("nonlineMode") ? ctx->getAttrStr("nonlineMode") : "IDENTITY"; auto activation_gen = create_activation_gener(nonline_mode); - bool with_bias = ctx->getAttrBool("with_bias"); std::stringstream writer; writer << R"(static inline void kern_4x12_bias_relu(const float* packA, const float* packB, int K, float* output, int LDC, const float* bias_ptr) { @@ -473,7 +504,7 @@ static std::string kern_4x12(TContext* ctx) { asm volatile()"; //! if convolution with bias - if (with_bias) { + if (ctx->getAttrStr("bias_mode") == "CHANNEL_BROADCAST_BIAS") { writer << R"( "ld1 {v6.4s}, [%[bias_ptr]], #16\n" "mov v8.16b, v6.16b \n" @@ -492,6 +523,24 @@ static std::string kern_4x12(TContext* ctx) { "mov v18.16b, v6.16b \n" "mov v19.16b, v6.16b \n" )"; + } else if (ctx->getAttrStr("bias_mode") == "ELEMWISE_BIAS") { + writer << R"( + "ld1 {v8.4s}, [%[bias_ptr]], #16\n" + "ld1 {v9.4s}, [%[bias_ptr]], #16\n" + "ld1 {v10.4s}, [%[bias_ptr]], #16\n" + "prfm pstl1keep, [%[output0]]\n" + "ld1 {v11.4s}, [%[bias_ptr]], #16\n" + "ld1 {v12.4s}, [%[bias_ptr]], #16\n" + "ld1 {v13.4s}, [%[bias_ptr]], #16\n" + "ld1 {v2.4s, v3.4s, v4.4s}, [%[b_ptr]], #48\n" + "ld1 {v14.4s}, [%[bias_ptr]], #16\n" + "ld1 {v15.4s}, [%[bias_ptr]], #16\n" + "ld1 {v16.4s}, [%[bias_ptr]], #16\n" + "ld1 {v17.4s}, [%[bias_ptr]], #16\n" + "ld1 {v0.4s}, [%[a_ptr]], #16\n" + "ld1 {v18.4s}, [%[bias_ptr]], #16\n" + "ld1 {v19.4s}, [%[bias_ptr]], #16\n" + )"; } else { //! if convolution without bias writer << R"( @@ -639,7 +688,6 @@ static std::string kern_8x4(TContext* ctx) { auto nonline_mode = ctx->haveAttr("nonlineMode") ? ctx->getAttrStr("nonlineMode") : "IDENTITY"; auto activation_gen = create_activation_gener(nonline_mode); - bool with_bias = ctx->getAttrBool("with_bias"); std::stringstream writer; //! kern_8x4 @@ -677,6 +725,8 @@ static std::string kern_8x4(TContext* ctx) { const float* b_ptr = packB; float* output0 = output; float* output1 = output0 + LDC; + const float* bias_ptr0 = bias_ptr; + const float* bias_ptr1 = bias_ptr0 + LDC; int oddk = (K & 1); K = ((K + 1) / 2) - 1; @@ -707,10 +757,10 @@ static std::string kern_8x4(TContext* ctx) { //clang-format on asm volatile()"; - if (with_bias) { + if (ctx->getAttrStr("bias_mode") == "CHANNEL_BROADCAST_BIAS") { writer << R"( - "ld1 {v30.4s}, [%[bias_ptr]], #16\n" - "ld1 {v31.4s}, [%[bias_ptr]], #16\n" + "ld1 {v30.4s}, [%[bias_ptr0]], #16\n" + "ld1 {v31.4s}, [%[bias_ptr0]], #16\n" "mov v8.16b, v30.16b \n" "mov v9.16b, v30.16b \n" "ld1 {v0.4s}, [%[a_ptr]], #16\n" @@ -723,6 +773,20 @@ static std::string kern_8x4(TContext* ctx) { "mov v14.16b, v31.16b \n" "ld1 {v2.4s}, [%[b_ptr]], #16\n" "mov v15.16b, v31.16b \n")"; + } else if (ctx->getAttrStr("bias_mode") == "ELEMWISE_BIAS") { + writer << R"( + "ld1 {v8.4s}, [%[bias_ptr0]], #16\n" + "ld1 {v9.4s}, [%[bias_ptr0]], #16\n" + "ld1 {v0.4s}, [%[a_ptr]], #16\n" + "ld1 {v10.4s}, [%[bias_ptr0]], #16\n" + "prfm pstl1keep, [%[output0]]\n" + "ld1 {v11.4s}, [%[bias_ptr0]], #16\n" + "ld1 {v12.4s}, [%[bias_ptr1]], #16\n" + "prfm pstl1keep, [%[output1]]\n" + "ld1 {v13.4s}, [%[bias_ptr1]], #16\n" + "ld1 {v14.4s}, [%[bias_ptr1]], #16\n" + "ld1 {v2.4s}, [%[b_ptr]], #16\n" + "ld1 {v15.4s}, [%[bias_ptr1]], #16\n")"; } else { writer << R"( "eor v8.16b, v8.16b, v8.16b \n" @@ -818,7 +882,7 @@ static std::string kern_8x4(TContext* ctx) { STORE_C : [ a_ptr ] "+r"(a_ptr), [ b_ptr ] "+r"(b_ptr), [ K ] "+r"(K), - [ bias_ptr ] "+r"(bias_ptr), [ oddk ] "+r"(oddk), + [ bias_ptr0 ] "+r"(bias_ptr0), [ bias_ptr1 ] "+r"(bias_ptr1), [ oddk ] "+r"(oddk), [ output0 ] "+r"(output0), [ output1 ] "+r"(output1), [ n_remain ] "+r"(n_remain) : @@ -839,7 +903,6 @@ static std::string kern_4x4(TContext* ctx) { auto nonline_mode = ctx->haveAttr("nonlineMode") ? ctx->getAttrStr("nonlineMode") : "IDENTITY"; auto activation_gen = create_activation_gener(nonline_mode); - bool with_bias = ctx->getAttrBool("with_bias"); std::stringstream writer; //! kern_4x4 writer << R"( @@ -898,7 +961,7 @@ static std::string kern_4x4(TContext* ctx) { //clang-format on asm volatile( )"; - if (with_bias) { + if (ctx->getAttrStr("bias_mode") == "CHANNEL_BROADCAST_BIAS") { writer << R"( // load accumulator C "ld1 {v30.4s}, [%[bias_ptr]], #16\n" @@ -908,6 +971,15 @@ static std::string kern_4x4(TContext* ctx) { "ld1 {v0.4s}, [%[a_ptr]], #16\n" "mov v10.16b, v30.16b \n" "mov v11.16b, v30.16b \n")"; + } else if (ctx->getAttrStr("bias_mode") == "ELEMWISE_BIAS") { + writer << R"( + // load accumulator C + "ld1 {v8.4s}, [%[bias_ptr]], #16\n" + "ld1 {v2.4s}, [%[b_ptr]], #16\n" + "ld1 {v9.4s}, [%[bias_ptr]], #16\n" + "ld1 {v0.4s}, [%[a_ptr]], #16\n" + "ld1 {v10.4s}, [%[bias_ptr]], #16\n" + "ld1 {v11.4s}, [%[bias_ptr]], #16\n")"; } else { writer << R"( "eor v8.16b, v8.16b, v8.16b \n" @@ -1107,6 +1179,21 @@ std::string gen_kernel( const std::string& sig, TContext* ctx, const std::string& postprocess_call, const std::string& preset_str = "") { auto post_process_strs = gen_postprocess_inline(ctx); + std::string channel_broadcast_bias_init = + ctx->getAttrStr("bias_mode") == "CHANNEL_BROADCAST_BIAS" + ? "_bias_ptr = bias_ptr + m;" + : ""; + std::string elemwise_bias_init = ctx->getAttrStr("bias_mode") == "ELEMWISE_BIAS" + ? "_bias_ptr = bias_ptr + m * N;" + : ""; + std::string elemwise_bias_update_n12 = + ctx->getAttrStr("bias_mode") == "ELEMWISE_BIAS" + ? "_bias_ptr += n_block * pack_mk;" + : ""; + std::string elemwise_bias_update_n4 = + ctx->getAttrStr("bias_mode") == "ELEMWISE_BIAS" + ? "_bias_ptr += 4 * pack_mk;" + : ""; std::string keren_body = R"( ${kernel_sig}{ @@ -1118,46 +1205,54 @@ std::string gen_kernel( const int K12 = K * 12; const int K8 = K * 8; const int K4 = K * 4; - size_t m = 0; + size_t m = 0; + const float* _bias_ptr = NULL; for (; m + m_block <= M; m += m_block) { float* output = C + (m / pack_mk * LDC); + ${channel_broadcast_bias_init} + ${elemwise_bias_init} size_t n = 0; const float* cur_pack_b = pack_b; for (; n + n_block <= N; n += n_block) { kern_8x12_bias_relu(pack_a, cur_pack_b, K, output, LDC, - bias_ptr); + _bias_ptr); output += n_block * pack_mk; + ${elemwise_bias_update_n12}; cur_pack_b += K12; } for (; n < N; n += 4) { kern_8x4_bias_relu(pack_a, cur_pack_b, K, output, LDC, - bias_ptr, N - n > 4 ? 4 : N - n); + _bias_ptr, N - n > 4 ? 4 : N - n); output += 4 * pack_mk; + ${elemwise_bias_update_n4}; cur_pack_b += K4; } pack_a += K8; - bias_ptr += m_block; } for (; m < M; m += m_block_4) { float* output = C + (m / pack_mk * LDC); + ${channel_broadcast_bias_init} + ${elemwise_bias_init} + size_t n = 0; const float* cur_pack_b = pack_b; for (; n + n_block - 1 < N; n += n_block) { kern_4x12_bias_relu(pack_a, cur_pack_b, K, output, LDC, - bias_ptr); + _bias_ptr); output += n_block * pack_mk; + ${elemwise_bias_update_n12}; cur_pack_b += K12; } for (; n < N; n += 4) { kern_4x4_bias_relu(pack_a, cur_pack_b, K, output, LDC, - bias_ptr, N - n > 4 ? 4 : N - n); + _bias_ptr, N - n > 4 ? 4 : N - n); output += 4 * pack_mk; + ${elemwise_bias_update_n4}; cur_pack_b += K4; } pack_a += K4; - bias_ptr += m_block_4; } ${postprocess_call} } @@ -1166,6 +1261,10 @@ std::string gen_kernel( .add("postprocess_call", postprocess_call) .add("preset_str", preset_str) .add("kernel_sig", sig) + .add("channel_broadcast_bias_init", channel_broadcast_bias_init) + .add("elemwise_bias_init", elemwise_bias_init) + .add("elemwise_bias_update_n12", elemwise_bias_update_n12) + .add("elemwise_bias_update_n4", elemwise_bias_update_n4) .render(keren_body); } @@ -1174,9 +1273,7 @@ std::string gen_kernel( std::string MatmulM8N12MK4Kernel::GetKernelSymbol(TContext* ctx) const { std::stringstream ss; ss << "Arm64_fp32_m8_n12_mk4_matmul"; - if (ctx->getAttrBool("with_bias")) { - ss << "_bias"; - } + ss << "_" << ctx->getAttrStr("bias_mode"); if (ctx->haveAttr("nonlineMode") && ctx->getAttrStr("nonlineMode") != "IDENTITY") { ss << "_" << ctx->getAttrStr("nonlineMode"); } diff --git a/compiler/lib/KernelGen/Arm/Arm64/MatMulKernel/Fp32MatMulM8N12K4.cpp b/compiler/lib/KernelGen/Arm/Arm64/MatMulKernel/Fp32MatMulM8N12K4.cpp index ae44eb4c..6f419ae1 100644 --- a/compiler/lib/KernelGen/Arm/Arm64/MatMulKernel/Fp32MatMulM8N12K4.cpp +++ b/compiler/lib/KernelGen/Arm/Arm64/MatMulKernel/Fp32MatMulM8N12K4.cpp @@ -10,7 +10,7 @@ using namespace Arm64; std::shared_ptr Fp32MatMulM8N12K4::GetInnerCtx(TContext* ctx) const { auto inner_ctx = std::make_shared(); inner_ctx->setAttr("format", "MK4"); - inner_ctx->setAttr("with_bias", false); + inner_ctx->setAttr("bias_mode", "NO_BIAS"); inner_ctx->setAttr("transposeA", false); inner_ctx->setAttr("transposeB", false); inner_ctx->setAttr("dtype", "f32"); diff --git a/compiler/lib/KernelGen/Common/ConvKernel.h b/compiler/lib/KernelGen/Common/ConvKernel.h index 43420c64..ce514b2e 100644 --- a/compiler/lib/KernelGen/Common/ConvKernel.h +++ b/compiler/lib/KernelGen/Common/ConvKernel.h @@ -19,6 +19,21 @@ class ConvImpl : public KernelFunc { } return false; } + static bool is_elemwise_bias(TContext* ctx) { + if (is_bias(ctx)) { + CCOperand bias = ctx->getAttrOprand("operand:2"); + CCOperand dst = Utils::get_last_operand(ctx); + if (bias.shape.size() != dst.shape.size()) + return false; + size_t len = bias.shape.size(); + for (size_t i = 0; i < len; ++i) { + if (bias.shape[i] != dst.shape[i]) + return false; + } + return true; + } + return false; + } static bool is_no_pad(TContext* ctx) { auto pad_h = ctx->getAttrInt("pad_h"); auto pad_w = ctx->getAttrInt("pad_w"); diff --git a/compiler/test/kernel/opr/arm/conv.cpp b/compiler/test/kernel/opr/arm/conv.cpp index a1fb6214..42095854 100644 --- a/compiler/test/kernel/opr/arm/conv.cpp +++ b/compiler/test/kernel/opr/arm/conv.cpp @@ -333,7 +333,10 @@ TEST(AARCH64, ConvBias1x1NCHW44) { ConvBiasForward::Param::NonlineMode::H_SWISH}) { param.nonlineMode = noline; checker.set_param(param); + checker.execs({{1, 3, 1, 1, 4}, {5, 3, 1, 1, 4, 4}, {1, 5, 1, 1, 4}, {}, {}}); checker.execs({{2, 3, 5, 11, 4}, {5, 3, 1, 1, 4, 4}, {1, 5, 1, 1, 4}, {}, {}}); + checker.execs({{2, 3, 5, 11, 4}, {5, 3, 1, 1, 4, 4}, {2, 5, 5, 11, 4}, {}, {}}); + checker.execs({{2, 3, 5, 11, 4}, {5, 3, 1, 1, 4, 4}, {}, {}, {}}); } param.sparse = ConvolutionForward::Param::Sparse::GROUP; @@ -345,6 +348,8 @@ TEST(AARCH64, ConvBias1x1NCHW44) { checker.set_param(param); checker.execs( {{2, 6, 17, 19, 4}, {2, 4, 3, 1, 1, 4, 4}, {1, 8, 1, 1, 4}, {}, {}}); + checker.execs( + {{2, 6, 17, 19, 4}, {2, 4, 3, 1, 1, 4, 4}, {2, 8, 17, 19, 4}, {}, {}}); } } diff --git a/script/build_and_test_not_standard_os.sh b/script/build_and_test_not_standard_os.sh index b9f3506b..135a98d6 100755 --- a/script/build_and_test_not_standard_os.sh +++ b/script/build_and_test_not_standard_os.sh @@ -28,7 +28,7 @@ cmake --build "$MEGCC_BUILD_DIR" -j$(nproc) --target mgb-to-tinynn --target mgb- function check_key_words() { #elf self mangle words, we do not care!! - white_list="@MEGW mgb1 5Mbg6 MGBi O:MgBnWk Yr]< 4emUi0B >HMgE kMEG RmEg MbGV4 MEgIy @MEg mGe#S BMgb MGB( mBg: MBgr8C A&mGB mEg; mGb>/ mEg= .strtab .shstrtab A=MgE= mgb=g MGe= g=MgE MGE< 8