feat(kernel): add AArch64 MLA int8 NCHW44 dense ConvBias 1x1 kernel
GitOrigin-RevId: ab214072652f162e638d036b86824f81c15fcf99
megvii-mge committed Apr 7, 2024
1 parent d01fbe7 commit 9f5e6e3
Showing 6 changed files with 1,198 additions and 0 deletions.
23 changes: 23 additions & 0 deletions compiler/lib/KernelGen/Arm/Arm64/ConvKernel.h
@@ -57,6 +57,29 @@ class Conv1x1DotMk4 : public Arm64ConvImpl {
    MatmulInt8DotM8N12MK4Kernel m_inner_gemm;
};

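//! AArch64 MLA int8 dense NCHW44 1x1 ConvBias kernel generator: the 1x1
//! convolution is lowered onto the MatmulInt8M4N4K16MK4Kernel inner GEMM
//! (MK4 format), with the weights pre-packed at init time.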
class Int8Conv1x1NCHW44 : public Arm64ConvImpl {
public:
    std::string GetKernelSymbol(TContext* context) const override;
    bool IsAvailable(TContext* context) const override;
    //! kernel gen
    std::string GetKernelBody(TContext* context) const override;
    //! init gen
    std::string GetInitBody(TContext* context) const override;
    std::string GetWorkspaceBody(TContext* ctx) const override {
        return GetWorkspaceBodyCondition(ctx, false);
    }
    std::string GetWorkspaceBodyAndJitExec(TContext* ctx) const override {
        return GetWorkspaceBodyCondition(ctx, true);
    }
    std::vector<KernelObj> GetDependInternalSymbol(TContext* context) const override;

private:
    std::string GetWorkspaceBodyCondition(TContext* ctx, bool jit) const;
    bool need_temp_dst(TContext* ctx) const;
    std::shared_ptr<TContext> GetInnerCtx(TContext* ctx) const;
    MatmulInt8M4N4K16MK4Kernel m_inner_gemm;
};

class ConvBiasIm2colI8mmNCHW44 : public Arm64ConvImpl {
public:
std::string GetKernelSymbol(TContext* context) const override;
@@ -0,0 +1,255 @@
#include <sstream>
#include <string>
#include "Arm/Arm64/Activation.h"
#include "Arm/Arm64/ConvKernel.h"
#include "Arm/Arm64/InternalKernel/InternalKernel.h"
#include "Utils/SymbolHelper.h"
#include "Utils/Utils.h"
#include "compiler/KernelGen/KernelGen.h"

using namespace megcc;
using namespace KernelGen;
using namespace Arm64;
using namespace ArmCommon;
namespace megcc {
namespace KernelGen {
namespace Arm64 {

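//! Availability: only 1x1 kernels with stride 1, no padding, no dilation,
//! DENSE NCHW44 cross-correlation, qint8 input/output and (optionally) a
//! channel-broadcast bias; the activation may be IDENTITY, RELU or H_SWISH.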
bool Int8Conv1x1NCHW44::IsAvailable(TContext* ctx) const {
    bool param_value_ok =
            ctx->getAttrUInt("kernel_h") == 1 && ctx->getAttrUInt("kernel_w") == 1 &&
            ctx->getAttrUInt("stride_h") == 1 && ctx->getAttrUInt("stride_w") == 1 &&
            ctx->getAttrUInt("pad_h") == 0 && ctx->getAttrUInt("pad_w") == 0 &&
            ctx->getAttrUInt("dilate_h") == 1 && ctx->getAttrUInt("dilate_w") == 1;
    bool param_mode_ok = ctx->getAttrStr("sparse") == "DENSE" &&
                         ctx->getAttrStr("format") == "NCHW44" &&
                         ctx->getAttrStr("mode") == "CROSS_CORRELATION";
    bool noline_ok = !ctx->haveAttr("nonlineMode") ||
                     ctx->getAttrStr("nonlineMode") == "IDENTITY" ||
                     ctx->getAttrStr("nonlineMode") == "RELU" ||
                     ctx->getAttrStr("nonlineMode") == "H_SWISH";

    bool type_ok = is_qint8_conv_dtype(ctx);

    bool layout_ok = ctx->getAttrOprand("operand:0").shape.size() == 5 &&
                     ctx->getAttrOprand("operand:0").shape[4] == 4;
    bool bias_ok = !is_bias(ctx) || is_channel_broadcast_bias(ctx);
    return param_value_ok && param_mode_ok && type_ok && noline_ok && layout_ok &&
           bias_ok;
}

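//! The kernel symbol encodes every attribute that affects code generation
//! (format, kernel/pad/stride/dilation sizes, bias, io dtypes, activation)
//! so that different configurations map to different symbols.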
std::string Int8Conv1x1NCHW44::GetKernelSymbol(TContext* ctx) const {
    std::stringstream extra_ss;
    if (is_bias(ctx)) {
        extra_ss << "_bias";
    }
    extra_ss << "_" << SymbolHelper::gen_io_str(ctx);
    if (ctx->haveAttr("nonlineMode") && ctx->getAttrStr("nonlineMode") != "IDENTITY") {
        extra_ss << "_" << ctx->getAttrStr("nonlineMode");
    }
    std::string name_temp =
            "Arm64_kernel_conv2d_conv1x1_${format}_${kernel_h}x${kernel_w}_"
            "${"
            "sparse}_p${pad_h}x${pad_w}_s${stride_h}x${stride_w}_d${"
            "dilate_h}x${dilate_w}${extra}";
    return StringTemplate::StringTemplateArgs(ctx)
            .add_ctx_int("kernel_h")
            .add_ctx_int("kernel_w")
            .add_ctx_str("format")
            .add_ctx_str("sparse")
            .add_ctx_int("pad_h")
            .add_ctx_int("pad_w")
            .add_ctx_int("stride_h")
            .add_ctx_int("stride_w")
            .add_ctx_int("dilate_h")
            .add_ctx_int("dilate_w")
            .add("extra", extra_ss.str())
            .render(name_temp);
}

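//! Init stage: pre-pack the weights into the inner GEMM's packed-A layout so
//! that only the activations (matrix B) have to be packed at runtime.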
std::string Int8Conv1x1NCHW44::GetInitBody(TContext* ctx) const {
    std::stringstream writer;
    auto inner_ctx = GetInnerCtx(ctx);
    writer << m_inner_gemm.GetPackASignature(inner_ctx.get()) << ";\n";
    writer << m_inner_gemm.GetPackAWorkspaceSignature(inner_ctx.get()) << ";\n";
    writer << GenCommonRet() << " " << GetInitSignature(ctx);
    uint32_t nr_out_weight = 1;
    std::string common_def = R"(
int PACK_C_SIZE = 4;
Tensor* in_weights = inputs[1];
int ymax = in_weights->layout.dims[0] * PACK_C_SIZE;
int kmax = in_weights->layout.dims[1] * PACK_C_SIZE;
int ldin = kmax * PACK_C_SIZE;
)";
    std::string fill_weight_attr =
            R"(
out_weights->layout.nr_dim = 1;
out_weights->layout.dims[0] = )" +
            m_inner_gemm.GetPackAWorkspaceSymbol(inner_ctx.get()) +
            R"((0, ymax, 0, kmax);
out_weights->layout.stride[0] = 1;
out_weights->dtype.type_enum=TinyNN_QINT8;
out_weights->name = in_weights->name;
out_weights->dtype.param.scale = in_weights->dtype.param.scale;
)";
    std::string fill_weight_transform =
            R"(
int8_t* outptr = out_weights->ptr;
int8_t* inptr = in_weights->ptr;
)" + m_inner_gemm.GetPackASymbol(inner_ctx.get()) +
"(outptr, inptr, ldin, 0, ymax, 0, kmax);";
    writer << StringTemplate::render_init_body(
            nr_out_weight, fill_weight_attr, fill_weight_transform, common_def);

    return writer.str();
}

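//! Workspace: the packed-B buffer of the inner GEMM, plus 128 bytes of
//! padding and an oc * hw * sizeof(int32_t) temporary destination buffer when
//! the GEMM result still needs a post process.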
std::string Int8Conv1x1NCHW44::GetWorkspaceBodyCondition(
        TContext* ctx, bool jit) const {
    std::stringstream ss;
    auto inner_ctx = GetInnerCtx(ctx);
    if (jit) {
        ss << m_inner_gemm.GetPackBWorkspaceBody(inner_ctx.get()) << ";\n";
    } else {
        ss << "extern " << m_inner_gemm.GetPackBWorkspaceSignature(inner_ctx.get())
           << ";\n";
    }
    ss << GenCommonRet() << " " << GetWorkspaceSignature(ctx);
    std::string temp_dst_workspace;
    if (need_temp_dst(ctx)) {
        //! NOTE: for conv1x1 the src HW shape is equal to the dst HW shape
        temp_dst_workspace = R"(
const Layout flt_layout = inputs[1]->layout;
uint32_t oc = flt_layout.dims[0] * 4;
res += 128 + oc * hw * sizeof(int32_t);
)";
    }
    std::string workspace_temp =
            R"({
TINYNN_ASSERT(workspace);
const Layout in_layout = inputs[0]->layout;
const uint32_t ic = in_layout.dims[1] * 4;
const uint32_t ih = in_layout.dims[2];
const uint32_t iw = in_layout.dims[3];
const uint32_t hw = ih * iw;
size_t res = ${packb_size_sym}(0, hw, 0, ic);
${temp_dst_workspace}
*workspace = res;
return TinyNN_SUCCESS;
})";
    ss << StringTemplate::StringTemplateArgs()
               .add("packb_size_sym",
                    m_inner_gemm.GetPackBWorkspaceSymbol(inner_ctx.get()))
               .add("temp_dst_workspace", temp_dst_workspace)
               .render(workspace_temp);

    return ss.str();
}

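//! The generated conv kernel depends on the inner GEMM, which is emitted as a
//! separate internal kernel object.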
std::vector<KernelObj> Int8Conv1x1NCHW44::GetDependInternalSymbol(TContext* ctx) const {
    auto inner_ctx = GetInnerCtx(ctx);

    return {
            {m_inner_gemm.GetKernelSymbol(inner_ctx.get()),
             m_inner_gemm.GetKernelBody(inner_ctx.get()),
             m_inner_gemm.GetBodyGuardBegin(inner_ctx.get()),
             m_inner_gemm.GetBodyGuardEnd(inner_ctx.get()),
             m_inner_gemm.GetDependInternalSymbol(inner_ctx.get())}};
}

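//! A temporary destination is only needed when the inner GEMM reports that
//! its output requires a post process before being stored to dst.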
bool Int8Conv1x1NCHW44::need_temp_dst(TContext* ctx) const {
    auto inner_ctx = GetInnerCtx(ctx);
    return m_inner_gemm.need_post_process(inner_ctx.get());
}

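//! Context handed to the inner GEMM: MK4 format, no transposition, and the
//! conv's bias/activation/dtype attributes forwarded.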
std::shared_ptr<TContext> Int8Conv1x1NCHW44::GetInnerCtx(TContext* ctx) const {
    auto inner_ctx = std::make_shared<CodeGenContext>();
    if (ctx->haveAttr("nonlineMode")) {
        inner_ctx->setAttr("nonlineMode", CCAttr(ctx->getAttrStr("nonlineMode")));
    }
    inner_ctx->setAttr("with_bias", ConvImpl::is_bias(ctx));
    inner_ctx->setAttr("transposeA", false);
    inner_ctx->setAttr("transposeB", false);
    inner_ctx->setAttr("format", "MK4");
    inner_ctx->setAttr("dtype", ctx->getAttrOprand("operand:0").dtype);
    auto last_dtype = Utils::get_last_operand(ctx).dtype;
    auto last_dtype_str = SymbolHelper::gen_valid_dtype(last_dtype);
    inner_ctx->setAttr("last_dtype", last_dtype_str);
    return inner_ctx;
}

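//! Runtime body: per batch, pack the NCHW44 input as matrix B into the
//! workspace, then call the naked GEMM with the pre-packed weights, the bias
//! pointer and the requantization scales (src_scale * flt_scale / dst_scale)
//! to write the output feature map.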
std::string Int8Conv1x1NCHW44::GetKernelBody(TContext* ctx) const {
    std::stringstream writer;
    auto inner_ctx = GetInnerCtx(ctx);
    writer << m_inner_gemm.GetPackBWorkspaceSignature(inner_ctx.get()) << ";\n";
    writer << m_inner_gemm.GetNakedKernelSignature(inner_ctx.get()) << ";\n";
    writer << m_inner_gemm.GetPackBSignature(inner_ctx.get()) << ";\n";
    writer << GenCommonRet() << " " << GetKernelSignature(ctx);
    std::string bias_ptr_str = is_bias(ctx) ? "inputs[2]->ptr;" : "0;";
    std::string gen_temp_dst = "void* temp_dst = NULL;";
    if (need_temp_dst(ctx)) {
        gen_temp_dst = "void* temp_dst = (int8_t*) workspace_ptr + pack_b_align;";
    }
    auto last_dtype = Utils::get_last_operand(ctx).dtype;
    auto last_dtype_str = SymbolHelper::gen_valid_dtype(last_dtype);
    std::string dst_specifier = Utils::cvt_dtype_specifier(last_dtype_str);
    writer << StringTemplate::StringTemplateArgs()
                      .add("bias_ptr_str", bias_ptr_str)
                      .add("packb_size_sym",
                           m_inner_gemm.GetPackBWorkspaceSymbol(inner_ctx.get()))
                      .add("packb_sym", m_inner_gemm.GetPackBSymbol(inner_ctx.get()))
                      .add("naked_kern_sym",
                           m_inner_gemm.GetNakedKernelSymbol(inner_ctx.get()))
                      .add("gen_temp_dst", gen_temp_dst)
                      .add("dst_specifier", dst_specifier)
                      .render(R"({
int8_t* input_data = inputs[0]->ptr;
${dst_specifier}* output_data = outputs[0]->ptr;
Layout in_layout = inputs[0]->layout;
Layout out_layout = outputs[0]->layout;
const int in_n = in_layout.dims[0];
const int in_c = in_layout.dims[1] * in_layout.dims[4];
const int in_h = in_layout.dims[2];
const int in_w = in_layout.dims[3];
const int PACK_C_SIZE = 4;
const float src_scale = inputs[0]->dtype.param.scale;
const float flt_scale = inputs[1]->dtype.param.scale;
const float dst_scale = outputs[0]->dtype.param.scale;
const float temp_scale = src_scale * flt_scale;
const float dst_scale_inv = 1.f / dst_scale;
const float scale = src_scale * flt_scale * dst_scale_inv;
const int out_c = out_layout.dims[1] * out_layout.dims[4];
const int out_h = out_layout.dims[2];
const int out_w = out_layout.dims[3];
const size_t N = out_h * out_w;
const int LDC = out_h * out_w * PACK_C_SIZE;
const int LDB = in_h * in_w * PACK_C_SIZE;
const size_t pack_b_size = ${packb_size_sym}(0, in_h * in_w, 0, in_c);
const size_t pack_b_align = (pack_b_size + 63) / 64 * 64;
void* workspace_ptr = workspace->ptr;
${gen_temp_dst}
for (int n_idx = 0; n_idx < in_n; ++n_idx) {
int8_t* weight_data = inputs[1]->ptr;
int32_t* bias_data = ${bias_ptr_str};
${packb_sym}(workspace_ptr, input_data, LDB, 0, in_h * in_w, 0, in_c);
${naked_kern_sym}(weight_data, workspace_ptr, output_data, LDC, out_c, N, in_c, bias_data, temp_dst, scale, temp_scale, dst_scale_inv);
input_data += in_c * in_h * in_w;
output_data += out_c * out_h * out_w;
}
return TinyNN_SUCCESS;
})");

    return writer.str();
}

} // namespace Arm64
} // namespace KernelGen
} // namespace megcc

// vim: syntax=cpp.doxygen