From 4bb59d278eef88346832b00de3fd2801e70b94a4 Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Tue, 16 Apr 2024 17:02:16 +0800 Subject: [PATCH] feat(kernel): add GI Float16 NCHW88 Resize kernel GitOrigin-RevId: 3b9b11abf0ccafd741dcbfcc374dedd0727f5bb8 --- compiler/lib/KernelGen/Common/Resize.h | 116 ++++++++++++++---- .../lib/KernelGen/GeneralIntrinsic/Resize.cpp | 18 ++- .../kernel/opr/generalIntrinsic/resize.cpp | 11 ++ 3 files changed, 117 insertions(+), 28 deletions(-) diff --git a/compiler/lib/KernelGen/Common/Resize.h b/compiler/lib/KernelGen/Common/Resize.h index 2e7ad3b1..c8f58adc 100644 --- a/compiler/lib/KernelGen/Common/Resize.h +++ b/compiler/lib/KernelGen/Common/Resize.h @@ -10,7 +10,7 @@ class ResizeHelper { static std::string GenCoordHelper( const std::string& imode, const std::string& specifier) { CC_ASSERT(imode == "LINEAR"); - CC_ASSERT(specifier == "float"); + CC_ASSERT(specifier == "float" || specifier == "gi_float16_t"); std::string body = R"( #include static inline void get_coord(float scale, int size, int idx, float* ah0, int* ih0, float* ah1, int* ih1){ @@ -19,27 +19,29 @@ class ResizeHelper { *ih0 = 0; *ah1 = 0.f; *ih1 = 0; - } - float alpha = (idx + 0.5f) / scale - 0.5f; - int origin_idx = (int)(floorf(alpha)); - alpha -= origin_idx; + } else { + float alpha = (idx + 0.5f) / scale - 0.5f; + int origin_idx = (int)(floorf(alpha)); + alpha -= origin_idx; - if (origin_idx < 0) { - origin_idx = 0; - alpha = 0; - } else if (origin_idx + 1 >= size) { - origin_idx = size - 2; - alpha = 1; + if (origin_idx < 0) { + origin_idx = 0; + alpha = 0; + } else if (origin_idx + 1 >= size) { + origin_idx = size - 2; + alpha = 1; + } + *ah0 = 1 - alpha; + *ih0 = origin_idx; + *ah1 = alpha; + *ih1 = origin_idx + 1; } - *ah0 = 1 - alpha; - *ih0 = origin_idx; - *ah1 = alpha; - *ih1 = origin_idx + 1; } )"; return body; } - static std::string GenNormImpl(const std::string& format) { + static std::string GenNormImpl( + const std::string& format, const std::string& specifier = "float") { std::string ret; ret = R"( { @@ -62,7 +64,13 @@ class ResizeHelper { int out_batch_stride = C * oc_stride; int ic_stride = IH * IW; int in_batch_stride = C * ic_stride; - + ${core} + } + )"; + + std::string core; + if (specifier == "float") { + core = R"( rep(n, N){ rep(c, C) { rep(oh, OH) { @@ -87,10 +95,56 @@ class ResizeHelper { sptr += in_batch_stride; dptr += out_batch_stride; } - } )"; + } else if (specifier == "gi_float16_t") { + core = R"( + C /= 8; + rep(n, N){ + rep(c, C) { + rep(oh, OH) { + int ih0 = ih0_cache[oh]; + int ih1 = ih1_cache[oh]; + gi_float16_t ah0 = (gi_float16_t)(ah0_cache[oh]); + gi_float16_t ah1 = (gi_float16_t)(ah1_cache[oh]); + GI_FLOAT16_t v_ah0 = GiBroadcastFloat16(ah0); + GI_FLOAT16_t v_ah1 = GiBroadcastFloat16(ah1); + rep(ow, OW) { + int iw0 = iw0_cache[ow]; + int iw1 = iw1_cache[ow]; + gi_float16_t aw0 = (gi_float16_t)(aw0_cache[ow]); + gi_float16_t aw1 = (gi_float16_t)(aw1_cache[ow]); + GI_FLOAT16_t v_aw0 = GiBroadcastFloat16(aw0); + GI_FLOAT16_t v_aw1 = GiBroadcastFloat16(aw1); - return ret; + GI_FLOAT16_t v_00 = GiLoadFloat16(sptr + get_offset(ih0, iw0, c, IH, IW, C)); + GI_FLOAT16_t v_weight_00 = GiMultiplyFloat16(v_ah0, v_aw0); + GI_FLOAT16_t v_01 = GiLoadFloat16(sptr + get_offset(ih0, iw1, c, IH, IW, C)); + GI_FLOAT16_t v_weight_01 = GiMultiplyFloat16(v_ah0, v_aw1); + GI_FLOAT16_t v_10 = GiLoadFloat16(sptr + get_offset(ih1, iw0, c, IH, IW, C)); + GI_FLOAT16_t v_weight_10 = GiMultiplyFloat16(v_ah1, v_aw0); + GI_FLOAT16_t v_11 = GiLoadFloat16(sptr + get_offset(ih1, iw1, c, IH, IW, C)); + GI_FLOAT16_t v_weight_11 = GiMultiplyFloat16(v_ah1, v_aw1); + + GI_FLOAT16_t vr_00 = GiMultiplyFloat16(v_00, v_weight_00); + GI_FLOAT16_t vr_01 = GiMultiplyFloat16(v_01, v_weight_01); + GI_FLOAT16_t vr = GiAddFloat16(vr_00, vr_01); + GI_FLOAT16_t vr_10 = GiMultiplyFloat16(v_10, v_weight_10); + vr = GiAddFloat16(vr, vr_10); + GI_FLOAT16_t vr_11 = GiMultiplyFloat16(v_11, v_weight_11); + vr = GiAddFloat16(vr, vr_11); + GiStoreFloat16(dptr + get_offset(oh, ow, c, OH, OW, C), vr); + } + } + } + sptr += in_batch_stride; + dptr += out_batch_stride; + } + )"; + } else { + CC_ASSERT(0); + } + + return StringTemplate::StringTemplateArgs().add("core", core).render(ret); } static std::string GenNearestImpl() { @@ -128,8 +182,7 @@ class ResizeHelper { int S_IH = src_layout.stride[2]; int S_IW = src_layout.stride[3]; )"; - } else { - CC_ASSERT(format == "NCHW44"); + } else if (format == "NCHW44") { ret = R"( int N = src_layout.dims[0]; int C = src_layout.dims[1]*4; @@ -138,6 +191,16 @@ class ResizeHelper { int OH = dst_layout.dims[2]; int OW = dst_layout.dims[3]; )"; + } else { + CC_ASSERT(format == "NCHW88"); + ret = R"( + int N = src_layout.dims[0]; + int C = src_layout.dims[1]*8; + int IH = src_layout.dims[2]; + int IW = src_layout.dims[3]; + int OH = dst_layout.dims[2]; + int OW = dst_layout.dims[3]; + )"; } return ret; } @@ -150,14 +213,21 @@ class ResizeHelper { return c * H * W + h * W + w; } )"; - } else { - CC_ASSERT(format == "NCHW44"); + } else if (format == "NCHW44") { ret = R"( static inline size_t get_offset(size_t h, size_t w, size_t c, size_t H, size_t W, size_t C){ return (((c >> 2) * H * W + h * W + w) << 2) + (c & 3); } )"; + } else { + CC_ASSERT(format == "NCHW88"); + ret = R"( + static inline size_t get_offset(size_t h, size_t w, size_t c, size_t H, size_t W, + size_t C){ + return ((c * H * W + h * W + w) << 3); + } + )"; } return ret; } diff --git a/compiler/lib/KernelGen/GeneralIntrinsic/Resize.cpp b/compiler/lib/KernelGen/GeneralIntrinsic/Resize.cpp index 3d679a21..2fd26b3e 100644 --- a/compiler/lib/KernelGen/GeneralIntrinsic/Resize.cpp +++ b/compiler/lib/KernelGen/GeneralIntrinsic/Resize.cpp @@ -12,11 +12,14 @@ using namespace GeneralIntrinsic; bool ResizeKernel::IsAvailable(TContext* context) const { auto src_dtype = context->getAttrOprand("operand:0").dtype; - bool dtype_ok = src_dtype == "f32"; + bool dtype_ok_f32 = src_dtype == "f32"; + bool dtype_ok_f16 = src_dtype == "f16"; bool mode_ok = context->getAttrStr("imode") == "LINEAR"; - bool format_ok = context->getAttrStr("format") == "NCHW" || - context->getAttrStr("format") == "NCHW44"; - return dtype_ok && mode_ok && format_ok; + bool format_ok_f32 = context->getAttrStr("format") == "NCHW" || + context->getAttrStr("format") == "NCHW44"; + bool format_ok_f16 = context->getAttrStr("format") == "NCHW88"; + return mode_ok && + ((dtype_ok_f32 && format_ok_f32) || (dtype_ok_f16 && format_ok_f16)); } //! kernel gen std::string ResizeKernel::GetKernelSymbol(TContext* context) const { @@ -38,6 +41,11 @@ std::string ResizeKernel::GetKernelBody(TContext* context) const { #include #include )"; + if (specifier == "gi_float16_t") { + ss << R"( + #include "gi_float16.h" + )"; + } auto coord_str = ResizeHelper::GenCoordHelper(imode, specifier); auto gen_layout_dims = ResizeHelper::GenLayoutDims(fmt); auto get_offset = ResizeHelper::GenGetOffset(fmt); @@ -70,7 +78,7 @@ std::string ResizeKernel::GetKernelBody(TContext* context) const { ${normal_impl} return TinyNN_SUCCESS; })"; - auto normal_impl = ResizeHelper::GenNormImpl(fmt); + auto normal_impl = ResizeHelper::GenNormImpl(fmt, specifier); ss << StringTemplate::StringTemplateArgs() .add("specifier", specifier) .add("normal_impl", normal_impl) diff --git a/compiler/test/kernel/opr/generalIntrinsic/resize.cpp b/compiler/test/kernel/opr/generalIntrinsic/resize.cpp index 7a2eb4f0..eb087516 100644 --- a/compiler/test/kernel/opr/generalIntrinsic/resize.cpp +++ b/compiler/test/kernel/opr/generalIntrinsic/resize.cpp @@ -16,4 +16,15 @@ TEST(GI, Resize) { checker.execs({{1, 1, 5, 6}, {1, 1, 7, 13}}); checker.execs({{1, 4, 5, 6}, {1, 4, 9, 12}}); checker.execs({{2, 3, 15, 16}, {2, 3, 9, 12}}); +#ifdef ENABLE_KERNEL_FP16 + param.format = megdnn::ResizeForward::Param::Format::NCHW88; + checker.set_param(param); + checker.set_dtype(0, dtype::Float16()); + checker.set_dtype(1, dtype::Float16()); + checker.set_epsilon(5e-2); + checker.execs({{1, 1, 5, 6, 8}, {1, 1, 7, 13, 8}}); + checker.execs({{1, 4, 5, 6, 8}, {1, 4, 9, 12, 8}}); + checker.execs({{2, 3, 15, 16, 8}, {2, 3, 9, 12, 8}}); + checker.execs({{2, 3, 1, 1, 8}, {2, 3, 9, 12, 8}}); +#endif } \ No newline at end of file