Skip to content

Commit

Permalink
feat(kernel): add GI Float16 NCHW88 Resize kernel
Browse files Browse the repository at this point in the history
GitOrigin-RevId: 3b9b11abf0ccafd741dcbfcc374dedd0727f5bb8
  • Loading branch information
megvii-mge committed May 10, 2024
1 parent b37dead commit 4bb59d2
Show file tree
Hide file tree
Showing 3 changed files with 117 additions and 28 deletions.
116 changes: 93 additions & 23 deletions compiler/lib/KernelGen/Common/Resize.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ class ResizeHelper {
static std::string GenCoordHelper(
const std::string& imode, const std::string& specifier) {
CC_ASSERT(imode == "LINEAR");
CC_ASSERT(specifier == "float");
CC_ASSERT(specifier == "float" || specifier == "gi_float16_t");
std::string body = R"(
#include <math.h>
static inline void get_coord(float scale, int size, int idx, float* ah0, int* ih0, float* ah1, int* ih1){
Expand All @@ -19,27 +19,29 @@ class ResizeHelper {
*ih0 = 0;
*ah1 = 0.f;
*ih1 = 0;
}
float alpha = (idx + 0.5f) / scale - 0.5f;
int origin_idx = (int)(floorf(alpha));
alpha -= origin_idx;
} else {
float alpha = (idx + 0.5f) / scale - 0.5f;
int origin_idx = (int)(floorf(alpha));
alpha -= origin_idx;
if (origin_idx < 0) {
origin_idx = 0;
alpha = 0;
} else if (origin_idx + 1 >= size) {
origin_idx = size - 2;
alpha = 1;
if (origin_idx < 0) {
origin_idx = 0;
alpha = 0;
} else if (origin_idx + 1 >= size) {
origin_idx = size - 2;
alpha = 1;
}
*ah0 = 1 - alpha;
*ih0 = origin_idx;
*ah1 = alpha;
*ih1 = origin_idx + 1;
}
*ah0 = 1 - alpha;
*ih0 = origin_idx;
*ah1 = alpha;
*ih1 = origin_idx + 1;
}
)";
return body;
}
static std::string GenNormImpl(const std::string& format) {
static std::string GenNormImpl(
const std::string& format, const std::string& specifier = "float") {
std::string ret;
ret = R"(
{
Expand All @@ -62,7 +64,13 @@ class ResizeHelper {
int out_batch_stride = C * oc_stride;
int ic_stride = IH * IW;
int in_batch_stride = C * ic_stride;
${core}
}
)";

std::string core;
if (specifier == "float") {
core = R"(
rep(n, N){
rep(c, C) {
rep(oh, OH) {
Expand All @@ -87,10 +95,56 @@ class ResizeHelper {
sptr += in_batch_stride;
dptr += out_batch_stride;
}
}
)";
} else if (specifier == "gi_float16_t") {
core = R"(
C /= 8;
rep(n, N){
rep(c, C) {
rep(oh, OH) {
int ih0 = ih0_cache[oh];
int ih1 = ih1_cache[oh];
gi_float16_t ah0 = (gi_float16_t)(ah0_cache[oh]);
gi_float16_t ah1 = (gi_float16_t)(ah1_cache[oh]);
GI_FLOAT16_t v_ah0 = GiBroadcastFloat16(ah0);
GI_FLOAT16_t v_ah1 = GiBroadcastFloat16(ah1);
rep(ow, OW) {
int iw0 = iw0_cache[ow];
int iw1 = iw1_cache[ow];
gi_float16_t aw0 = (gi_float16_t)(aw0_cache[ow]);
gi_float16_t aw1 = (gi_float16_t)(aw1_cache[ow]);
GI_FLOAT16_t v_aw0 = GiBroadcastFloat16(aw0);
GI_FLOAT16_t v_aw1 = GiBroadcastFloat16(aw1);
return ret;
GI_FLOAT16_t v_00 = GiLoadFloat16(sptr + get_offset(ih0, iw0, c, IH, IW, C));
GI_FLOAT16_t v_weight_00 = GiMultiplyFloat16(v_ah0, v_aw0);
GI_FLOAT16_t v_01 = GiLoadFloat16(sptr + get_offset(ih0, iw1, c, IH, IW, C));
GI_FLOAT16_t v_weight_01 = GiMultiplyFloat16(v_ah0, v_aw1);
GI_FLOAT16_t v_10 = GiLoadFloat16(sptr + get_offset(ih1, iw0, c, IH, IW, C));
GI_FLOAT16_t v_weight_10 = GiMultiplyFloat16(v_ah1, v_aw0);
GI_FLOAT16_t v_11 = GiLoadFloat16(sptr + get_offset(ih1, iw1, c, IH, IW, C));
GI_FLOAT16_t v_weight_11 = GiMultiplyFloat16(v_ah1, v_aw1);
GI_FLOAT16_t vr_00 = GiMultiplyFloat16(v_00, v_weight_00);
GI_FLOAT16_t vr_01 = GiMultiplyFloat16(v_01, v_weight_01);
GI_FLOAT16_t vr = GiAddFloat16(vr_00, vr_01);
GI_FLOAT16_t vr_10 = GiMultiplyFloat16(v_10, v_weight_10);
vr = GiAddFloat16(vr, vr_10);
GI_FLOAT16_t vr_11 = GiMultiplyFloat16(v_11, v_weight_11);
vr = GiAddFloat16(vr, vr_11);
GiStoreFloat16(dptr + get_offset(oh, ow, c, OH, OW, C), vr);
}
}
}
sptr += in_batch_stride;
dptr += out_batch_stride;
}
)";
} else {
CC_ASSERT(0);
}

return StringTemplate::StringTemplateArgs().add("core", core).render(ret);
}

static std::string GenNearestImpl() {
Expand Down Expand Up @@ -128,8 +182,7 @@ class ResizeHelper {
int S_IH = src_layout.stride[2];
int S_IW = src_layout.stride[3];
)";
} else {
CC_ASSERT(format == "NCHW44");
} else if (format == "NCHW44") {
ret = R"(
int N = src_layout.dims[0];
int C = src_layout.dims[1]*4;
Expand All @@ -138,6 +191,16 @@ class ResizeHelper {
int OH = dst_layout.dims[2];
int OW = dst_layout.dims[3];
)";
} else {
CC_ASSERT(format == "NCHW88");
ret = R"(
int N = src_layout.dims[0];
int C = src_layout.dims[1]*8;
int IH = src_layout.dims[2];
int IW = src_layout.dims[3];
int OH = dst_layout.dims[2];
int OW = dst_layout.dims[3];
)";
}
return ret;
}
Expand All @@ -150,14 +213,21 @@ class ResizeHelper {
return c * H * W + h * W + w;
}
)";
} else {
CC_ASSERT(format == "NCHW44");
} else if (format == "NCHW44") {
ret = R"(
static inline size_t get_offset(size_t h, size_t w, size_t c, size_t H, size_t W,
size_t C){
return (((c >> 2) * H * W + h * W + w) << 2) + (c & 3);
}
)";
} else {
CC_ASSERT(format == "NCHW88");
ret = R"(
static inline size_t get_offset(size_t h, size_t w, size_t c, size_t H, size_t W,
size_t C){
return ((c * H * W + h * W + w) << 3);
}
)";
}
return ret;
}
Expand Down
18 changes: 13 additions & 5 deletions compiler/lib/KernelGen/GeneralIntrinsic/Resize.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,14 @@ using namespace GeneralIntrinsic;

bool ResizeKernel::IsAvailable(TContext* context) const {
auto src_dtype = context->getAttrOprand("operand:0").dtype;
bool dtype_ok = src_dtype == "f32";
bool dtype_ok_f32 = src_dtype == "f32";
bool dtype_ok_f16 = src_dtype == "f16";
bool mode_ok = context->getAttrStr("imode") == "LINEAR";
bool format_ok = context->getAttrStr("format") == "NCHW" ||
context->getAttrStr("format") == "NCHW44";
return dtype_ok && mode_ok && format_ok;
bool format_ok_f32 = context->getAttrStr("format") == "NCHW" ||
context->getAttrStr("format") == "NCHW44";
bool format_ok_f16 = context->getAttrStr("format") == "NCHW88";
return mode_ok &&
((dtype_ok_f32 && format_ok_f32) || (dtype_ok_f16 && format_ok_f16));
}
//! kernel gen
std::string ResizeKernel::GetKernelSymbol(TContext* context) const {
Expand All @@ -38,6 +41,11 @@ std::string ResizeKernel::GetKernelBody(TContext* context) const {
#include <math.h>
#include <stdalign.h>
)";
if (specifier == "gi_float16_t") {
ss << R"(
#include "gi_float16.h"
)";
}
auto coord_str = ResizeHelper::GenCoordHelper(imode, specifier);
auto gen_layout_dims = ResizeHelper::GenLayoutDims(fmt);
auto get_offset = ResizeHelper::GenGetOffset(fmt);
Expand Down Expand Up @@ -70,7 +78,7 @@ std::string ResizeKernel::GetKernelBody(TContext* context) const {
${normal_impl}
return TinyNN_SUCCESS;
})";
auto normal_impl = ResizeHelper::GenNormImpl(fmt);
auto normal_impl = ResizeHelper::GenNormImpl(fmt, specifier);
ss << StringTemplate::StringTemplateArgs()
.add("specifier", specifier)
.add("normal_impl", normal_impl)
Expand Down
11 changes: 11 additions & 0 deletions compiler/test/kernel/opr/generalIntrinsic/resize.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,4 +16,15 @@ TEST(GI, Resize) {
checker.execs({{1, 1, 5, 6}, {1, 1, 7, 13}});
checker.execs({{1, 4, 5, 6}, {1, 4, 9, 12}});
checker.execs({{2, 3, 15, 16}, {2, 3, 9, 12}});
#ifdef ENABLE_KERNEL_FP16
param.format = megdnn::ResizeForward::Param::Format::NCHW88;
checker.set_param(param);
checker.set_dtype(0, dtype::Float16());
checker.set_dtype(1, dtype::Float16());
checker.set_epsilon(5e-2);
checker.execs({{1, 1, 5, 6, 8}, {1, 1, 7, 13, 8}});
checker.execs({{1, 4, 5, 6, 8}, {1, 4, 9, 12, 8}});
checker.execs({{2, 3, 15, 16, 8}, {2, 3, 9, 12, 8}});
checker.execs({{2, 3, 1, 1, 8}, {2, 3, 9, 12, 8}});
#endif
}

0 comments on commit 4bb59d2

Please sign in to comment.