Skip to content

Commit

Permalink
feat(kernel): support elemwise bias in Arm64 Fp32 NCHW44 Conv1x1
Browse files: browse the repository at this point in the history
GitOrigin-RevId: 419e87f87bfc4ee60af99665bc022f0bffe6df55
  • Loading branch information
megvii-mge committed May 11, 2024
1 parent 6708597 commit 486dc65
Show file tree
Hide file tree
Showing 7 changed files with 182 additions and 34 deletions.
Original file line number | Diff line number | Diff line change
Expand Up @@ -35,15 +35,21 @@ bool Conv1x1FloatMk4::IsAvailable(TContext* ctx) const {
ctx->getAttrOprand("operand:2").dtype == "f32";
bool layout_ok = ctx->getAttrOprand("operand:0").shape.size() == 5 &&
ctx->getAttrOprand("operand:0").shape[4] == 4;
bool bias_ok = !is_bias(ctx) || is_channel_broadcast_bias(ctx);
bool bias_ok =
!is_bias(ctx) || is_channel_broadcast_bias(ctx) || is_elemwise_bias(ctx);
return param_value_ok && param_mode_ok && type_ok && noline_ok && layout_ok &&
bias_ok;
}

std::string Conv1x1FloatMk4::GetKernelSymbol(TContext* ctx) const {
std::stringstream extra_ss;
if (is_bias(ctx)) {
extra_ss << "_bias";
if (is_channel_broadcast_bias(ctx))
extra_ss << "_channel_broadcast_bias";
else if (is_elemwise_bias(ctx))
extra_ss << "_elemwise_bias";
else
CC_ABORT << "only support channel broadcast and elemwise bias mode.";
}
if (ctx->haveAttr("nonlineMode") && ctx->getAttrStr("nonlineMode") != "IDENTITY") {
extra_ss << "_" << ctx->getAttrStr("nonlineMode");
Expand Down Expand Up @@ -163,7 +169,13 @@ std::shared_ptr<TContext> Conv1x1FloatMk4::GetInnerCtx(TContext* ctx) const {
if (ctx->haveAttr("nonlineMode")) {
inner_ctx->setAttr("nonlineMode", CCAttr(ctx->getAttrStr("nonlineMode")));
}
inner_ctx->setAttr("with_bias", ConvImpl::is_bias(ctx));
std::string bias_mode = "NO_BIAS";
if (is_channel_broadcast_bias(ctx)) {
bias_mode = "CHANNEL_BROADCAST_BIAS";
} else if (is_elemwise_bias(ctx)) {
bias_mode = "ELEMWISE_BIAS";
}
inner_ctx->setAttr("bias_mode", bias_mode);
inner_ctx->setAttr("transposeA", false);
inner_ctx->setAttr("transposeB", false);
inner_ctx->setAttr("format", "MK4");
Expand All @@ -177,9 +189,22 @@ std::string Conv1x1FloatMk4::GetKernelBody(TContext* ctx) const {
writer << m_inner_gemm.GetNakedKernelSignature(inner_ctx.get()) << ";\n";
writer << m_inner_gemm.GetPackBSignature(inner_ctx.get()) << ";\n";
writer << GenCommonRet() << " " << GetKernelSignature(ctx);
std::string bias_ptr_str = is_bias(ctx) ? "inputs[2]->ptr;" : "0;";
std::string elemwise_bias_init, channel_broadcast_bias_init;
if (is_channel_broadcast_bias(ctx)) {
channel_broadcast_bias_init = "bias_data = inputs[2]->ptr;";
} else if (is_elemwise_bias(ctx)) {
elemwise_bias_init = "bias_data = inputs[2]->ptr;";
}
std::string bias_offset = "0";
if (is_channel_broadcast_bias(ctx)) {
bias_offset = "ocpg";
} else if (is_elemwise_bias(ctx)) {
bias_offset = "ocpg * out_h * out_w";
}
writer << StringTemplate::StringTemplateArgs()
.add("bias_ptr_str", bias_ptr_str)
.add("elemwise_bias_init", elemwise_bias_init)
.add("channel_broadcast_bias_init", channel_broadcast_bias_init)
.add("bias_offset", bias_offset)
.add("packB", m_inner_gemm.GetPackBSymbol(inner_ctx.get()))
.add("kern", m_inner_gemm.GetNakedKernelSymbol(inner_ctx.get()))
.render(R"({
Expand Down Expand Up @@ -215,16 +240,18 @@ std::string Conv1x1FloatMk4::GetKernelBody(TContext* ctx) const {
const int LDB = in_h * in_w * PACK_C_SIZE;
void* workspace_ptr = workspace->ptr;
const float* bias_data = NULL;
${elemwise_bias_init}
for (int n_idx = 0; n_idx < in_n; ++n_idx) {
float* weight_data = inputs[1]->ptr;
float* bias_data = ${bias_ptr_str};
${channel_broadcast_bias_init}
for(int group_idx = 0; group_idx < group; ++group_idx){
${packB}(workspace_ptr, input_data, LDB, 0, in_h * in_w, 0, icpg);
${kern}(weight_data, workspace_ptr, output_data, LDC, ocpg, N, icpg, bias_data);
input_data += icpg * in_h * in_w;
output_data += ocpg * out_h * out_w;
weight_data += weight_layout.stride[0];
bias_data += ocpg;
bias_data += ${bias_offset};
}
}
return TinyNN_SUCCESS;
Expand Down
Original file line number | Diff line number | Diff line change
Expand Up @@ -261,6 +261,10 @@ std::shared_ptr<TContext> ConvIm2colFloat::GetInnerCtx(TContext* ctx) const {
inner_ctx->setAttr("nonlineMode", CCAttr(ctx->getAttrStr("nonlineMode")));
}
inner_ctx->setAttr("with_bias", ConvImpl::is_bias(ctx));
inner_ctx->setAttr(
"bias_mode", ConvImpl::is_channel_broadcast_bias(ctx)
? "CHANNEL_BROADCAST_BIAS"
: "NO_BIAS");
inner_ctx->setAttr("transposeA", false);
inner_ctx->setAttr("transposeB", false);
inner_ctx->setAttr("dtype", "f32");
Expand Down
Loading

0 comments on commit 486dc65

Please sign in to comment.