From 46b01fecf9e8dd58043721037083fced9c0e5f43 Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Thu, 14 Sep 2023 13:44:14 +0800 Subject: [PATCH] opt(kernel): fuse im2col and pack of stride 2 im2col convbias based i8mm m8n12k8 gemm GitOrigin-RevId: 3a2f1bffd2f4465a987c1f62201e0c069b6db2ee --- .../Int8/Int8I8mmIm2colS1NCHW44M8N12K8.cpp | 1393 ++++++++++++++++- compiler/test/kernel/opr/arm/conv.cpp | 12 +- 2 files changed, 1371 insertions(+), 34 deletions(-) diff --git a/compiler/lib/KernelGen/Arm/Arm64/ConvKernel/Int8/Int8I8mmIm2colS1NCHW44M8N12K8.cpp b/compiler/lib/KernelGen/Arm/Arm64/ConvKernel/Int8/Int8I8mmIm2colS1NCHW44M8N12K8.cpp index 2ac91afe..ae1798e6 100644 --- a/compiler/lib/KernelGen/Arm/Arm64/ConvKernel/Int8/Int8I8mmIm2colS1NCHW44M8N12K8.cpp +++ b/compiler/lib/KernelGen/Arm/Arm64/ConvKernel/Int8/Int8I8mmIm2colS1NCHW44M8N12K8.cpp @@ -96,6 +96,1294 @@ static inline void im2col(const int8_t *src, int8_t *dst, const int IC, const in } )"; } + +std::string fuse_im2col_packB_s2(TContext* ctx) { + const int FH = ctx->getAttrInt("kernel_h"), FW = ctx->getAttrInt("kernel_w"); + CC_ASSERT(FH == FW); + std::string res = R"( +static void fuse_im2col_packB_s2(const int8_t *src, int8_t *dst, const int IC, + const int IH, const int IW, const int OH, const int OW) { + TINYNN_ASSERT(OW >= 12);)"; + if (FW % 2) { + res += R"( + int32x4x2_t d[15]; + int32_t buffer[6][24];)"; + } + res += R"( + int32_t *dst_ptr = (int32_t *)dst; + const int N = OH * OW; + int n = 0; + for (; n + 11 < N; n += 12) { + const int oh = n / OW, ow = n % OW; + const int32_t *src_base = + (const int32_t *)(src + oh * 2 * IW * 4 + ow * 2 * 4); + if (OW - ow >= 12) {)"; + if (FW % 2) { + res += R"( + int ic = 0; + for (; ic + 1 < IC; ic += 2) {)"; + int idx = 0; + for (int fh = 0; fh < FH / 2; ++fh) { + for (int fw = 0; fw < FW / 2; ++fw) { + res += std::string("memcpy(dst_ptr + ") + + std::to_string(fh * FW * 24 + fw * 24) + + std::string(", src_base + IW * ") + std::to_string(fh * 2) + + std::string(" + ") + std::to_string(fw * 2) + + std::string(", 96);"); + } + res += StringTemplate::StringTemplateArgs() + .add("idx0", idx % 15) + .add("idx1", (idx + 1) % 15) + .add("idx2", (idx + 2) % 15) + .add("idx3", (idx + 3) % 15) + .add("idx4", (idx + 4) % 15) + .add("idx5", (idx + 5) % 15) + .add("idx6", (idx + 6) % 15) + .add("idx7", (idx + 7) % 15) + .add("idx8", (idx + 8) % 15) + .add("src_offset", FW / 2 * 2) + .add("dst_offset", fh * FW * 24 + FW / 2 * 24) + .add("row0_idx", fh * 2) + .add("row1_idx", fh * 2 + 1) + .render(R"( + d[${idx0}] = vld2q_s32(src_base + IW * ${row0_idx} + ${src_offset}); + d[${idx1}] = vld2q_s32(src_base + IW * ${row0_idx} + ${src_offset} + 8); + d[${idx2}] = vld2q_s32(src_base + IW * ${row0_idx} + ${src_offset} + 16); + d[${idx3}] = vld2q_s32(src_base + IW * ${row1_idx}); + d[${idx4}] = vld2q_s32(src_base + IW * ${row1_idx} + 8); + d[${idx5}] = vld2q_s32(src_base + IW * ${row1_idx} + 16); + d[${idx6}] = vzipq_s32(d[${idx0}].val[0], d[${idx3}].val[0]); + d[${idx7}] = vzipq_s32(d[${idx1}].val[0], d[${idx4}].val[0]); + d[${idx8}] = vzipq_s32(d[${idx2}].val[0], d[${idx5}].val[0]); + vst1q_s32(dst_ptr + ${dst_offset}, d[${idx6}].val[0]); + vst1q_s32(dst_ptr + ${dst_offset} + 4, d[${idx6}].val[1]); + vst1q_s32(dst_ptr + ${dst_offset} + 8, d[${idx7}].val[0]); + vst1q_s32(dst_ptr + ${dst_offset} + 12, d[${idx7}].val[1]); + vst1q_s32(dst_ptr + ${dst_offset} + 16, d[${idx8}].val[0]); + vst1q_s32(dst_ptr + ${dst_offset} + 20, d[${idx8}].val[1]); + )"); + idx = (idx + 9) % 15; + for (int fw = 0; fw < FW / 2; ++fw) { + res += std::string("memcpy(dst_ptr + ") + + std::to_string(fh * FW * 24 + (FW + 1) / 2 * 24 + fw * 24) + + std::string(", src_base + IW * ") + std::to_string(fh * 2 + 1) + + std::string(" + ") + std::to_string(fw * 2 + 1) + + std::string(", 96);"); + } + } + int fh = FH / 2; + for (int fw = 0; fw < FW / 2; ++fw) { + res += std::string("memcpy(dst_ptr + ") + + std::to_string(fh * FW * 24 + fw * 24) + + std::string(", src_base + IW * ") + std::to_string(fh * 2) + + std::string(" + ") + std::to_string(fw * 2) + std::string(", 96);"); + } + res += StringTemplate::StringTemplateArgs() + .add("idx0", idx % 15) + .add("idx1", (idx + 1) % 15) + .add("idx2", (idx + 2) % 15) + .add("idx3", (idx + 3) % 15) + .add("idx4", (idx + 4) % 15) + .add("idx5", (idx + 5) % 15) + .add("idx6", (idx + 6) % 15) + .add("idx7", (idx + 7) % 15) + .add("idx8", (idx + 8) % 15) + .add("src_offset", FW / 2 * 2) + .add("dst_offset", fh * FW * 24 + FW / 2 * 24) + .add("row0_idx", fh * 2) + .render(R"( + d[${idx0}] = vld2q_s32(src_base + IW * ${row0_idx} + ${src_offset}); + d[${idx1}] = vld2q_s32(src_base + IW * ${row0_idx} + ${src_offset} + 8); + d[${idx2}] = vld2q_s32(src_base + IW * ${row0_idx} + ${src_offset} + 16); + src_base += (IH * IW); + d[${idx3}] = vld2q_s32(src_base); + d[${idx4}] = vld2q_s32(src_base + 8); + d[${idx5}] = vld2q_s32(src_base + 16); + d[${idx6}] = vzipq_s32(d[${idx0}].val[0], d[${idx3}].val[0]); + d[${idx7}] = vzipq_s32(d[${idx1}].val[0], d[${idx4}].val[0]); + d[${idx8}] = vzipq_s32(d[${idx2}].val[0], d[${idx5}].val[0]); + vst1q_s32(dst_ptr + ${dst_offset}, d[${idx6}].val[0]); + vst1q_s32(dst_ptr + ${dst_offset} + 4, d[${idx6}].val[1]); + vst1q_s32(dst_ptr + ${dst_offset} + 8, d[${idx7}].val[0]); + vst1q_s32(dst_ptr + ${dst_offset} + 12, d[${idx7}].val[1]); + vst1q_s32(dst_ptr + ${dst_offset} + 16, d[${idx8}].val[0]); + vst1q_s32(dst_ptr + ${dst_offset} + 20, d[${idx8}].val[1]); + )"); + idx = (idx + 9) % 15; + for (int fw = 0; fw < FW / 2; ++fw) { + res += std::string("memcpy(dst_ptr + ") + + std::to_string(fh * FW * 24 + (FW + 1) / 2 * 24 + fw * 24) + + std::string(", src_base + ") + std::to_string(fw * 2 + 1) + + std::string(", 96);"); + } + for (int fh = 0; fh < FH / 2; ++fh) { + for (int fw = 0; fw < FW / 2; ++fw) { + res += std::string("memcpy(dst_ptr + ") + + std::to_string((FH + 1) / 2 * FW * 24 + fh * FW * 24 + fw * 24) + + std::string(", src_base + IW * ") + std::to_string(fh * 2 + 1) + + std::string(" + ") + std::to_string(fw * 2) + + std::string(", 96);"); + } + res += StringTemplate::StringTemplateArgs() + .add("idx0", idx % 15) + .add("idx1", (idx + 1) % 15) + .add("idx2", (idx + 2) % 15) + .add("idx3", (idx + 3) % 15) + .add("idx4", (idx + 4) % 15) + .add("idx5", (idx + 5) % 15) + .add("idx6", (idx + 6) % 15) + .add("idx7", (idx + 7) % 15) + .add("idx8", (idx + 8) % 15) + .add("src_offset", FW / 2 * 2) + .add("dst_offset", + (FH + 1) / 2 * FW * 24 + fh * FW * 24 + FW / 2 * 24) + .add("row0_idx", fh * 2 + 1) + .add("row1_idx", fh * 2 + 2) + .render(R"( + d[${idx0}] = vld2q_s32(src_base + IW * ${row0_idx} + ${src_offset}); + d[${idx1}] = vld2q_s32(src_base + IW * ${row0_idx} + ${src_offset} + 8); + d[${idx2}] = vld2q_s32(src_base + IW * ${row0_idx} + ${src_offset} + 16); + d[${idx3}] = vld2q_s32(src_base + IW * ${row1_idx}); + d[${idx4}] = vld2q_s32(src_base + IW * ${row1_idx} + 8); + d[${idx5}] = vld2q_s32(src_base + IW * ${row1_idx} + 16); + d[${idx6}] = vzipq_s32(d[${idx0}].val[0], d[${idx3}].val[0]); + d[${idx7}] = vzipq_s32(d[${idx1}].val[0], d[${idx4}].val[0]); + d[${idx8}] = vzipq_s32(d[${idx2}].val[0], d[${idx5}].val[0]); + vst1q_s32(dst_ptr + ${dst_offset}, d[${idx6}].val[0]); + vst1q_s32(dst_ptr + ${dst_offset} + 4, d[${idx6}].val[1]); + vst1q_s32(dst_ptr + ${dst_offset} + 8, d[${idx7}].val[0]); + vst1q_s32(dst_ptr + ${dst_offset} + 12, d[${idx7}].val[1]); + vst1q_s32(dst_ptr + ${dst_offset} + 16, d[${idx8}].val[0]); + vst1q_s32(dst_ptr + ${dst_offset} + 20, d[${idx8}].val[1]); + )"); + idx = (idx + 9) % 15; + for (int fw = 0; fw < FW / 2; ++fw) { + res += std::string("memcpy(dst_ptr + ") + + std::to_string( + (FH + 1) / 2 * FW * 24 + fh * FW * 24 + + (FW + 1) / 2 * 24 + fw * 24) + + std::string(", src_base + IW * ") + std::to_string(fh * 2 + 2) + + std::string(" + ") + std::to_string(fw * 2 + 1) + + std::string(", 96);"); + } + } + res += StringTemplate::StringTemplateArgs() + .add("offset", FH * FW * 24) + .render(R"( + dst_ptr += ${offset}; + src_base += (IH * IW); + } + if (ic < IC) { + )"); + for (int fh = 0; fh < FH / 2; ++fh) { + for (int fw = 0; fw < FW / 2; ++fw) { + res += std::string("memcpy(dst_ptr + ") + + std::to_string(fh * FW * 24 + fw * 24) + + std::string(", src_base + IW * ") + std::to_string(fh * 2) + + std::string(" + ") + std::to_string(fw * 2) + + std::string(", 96);"); + } + res += StringTemplate::StringTemplateArgs() + .add("idx0", idx % 15) + .add("idx1", (idx + 1) % 15) + .add("idx2", (idx + 2) % 15) + .add("idx3", (idx + 3) % 15) + .add("idx4", (idx + 4) % 15) + .add("idx5", (idx + 5) % 15) + .add("idx6", (idx + 6) % 15) + .add("idx7", (idx + 7) % 15) + .add("idx8", (idx + 8) % 15) + .add("src_offset", FW / 2 * 2) + .add("dst_offset", fh * FW * 24 + FW / 2 * 24) + .add("row0_idx", fh * 2) + .add("row1_idx", fh * 2 + 1) + .render(R"( + d[${idx0}] = vld2q_s32(src_base + IW * ${row0_idx} + ${src_offset}); + d[${idx1}] = vld2q_s32(src_base + IW * ${row0_idx} + ${src_offset} + 8); + d[${idx2}] = vld2q_s32(src_base + IW * ${row0_idx} + ${src_offset} + 16); + d[${idx3}] = vld2q_s32(src_base + IW * ${row1_idx}); + d[${idx4}] = vld2q_s32(src_base + IW * ${row1_idx} + 8); + d[${idx5}] = vld2q_s32(src_base + IW * ${row1_idx} + 16); + d[${idx6}] = vzipq_s32(d[${idx0}].val[0], d[${idx3}].val[0]); + d[${idx7}] = vzipq_s32(d[${idx1}].val[0], d[${idx4}].val[0]); + d[${idx8}] = vzipq_s32(d[${idx2}].val[0], d[${idx5}].val[0]); + vst1q_s32(dst_ptr + ${dst_offset}, d[${idx6}].val[0]); + vst1q_s32(dst_ptr + ${dst_offset} + 4, d[${idx6}].val[1]); + vst1q_s32(dst_ptr + ${dst_offset} + 8, d[${idx7}].val[0]); + vst1q_s32(dst_ptr + ${dst_offset} + 12, d[${idx7}].val[1]); + vst1q_s32(dst_ptr + ${dst_offset} + 16, d[${idx8}].val[0]); + vst1q_s32(dst_ptr + ${dst_offset} + 20, d[${idx8}].val[1]); + )"); + idx = (idx + 9) % 15; + for (int fw = 0; fw < FW / 2; ++fw) { + res += std::string("memcpy(dst_ptr + ") + + std::to_string(fh * FW * 24 + (FW + 1) / 2 * 24 + fw * 24) + + std::string(", src_base + IW * ") + std::to_string(fh * 2 + 1) + + std::string(" + ") + std::to_string(fw * 2 + 1) + + std::string(", 96);"); + } + } + fh = FH / 2; + for (int fw = 0; fw < FW / 2; ++fw) { + res += std::string("memcpy(dst_ptr + ") + + std::to_string(fh * FW * 24 + fw * 24) + + std::string(", src_base + IW * ") + std::to_string(fh * 2) + + std::string(" + ") + std::to_string(fw * 2) + std::string(", 96);"); + } + res += StringTemplate::StringTemplateArgs() + .add("idx0", idx % 15) + .add("idx1", (idx + 1) % 15) + .add("idx2", (idx + 2) % 15) + .add("idx3", (idx + 3) % 15) + .add("idx6", (idx + 6) % 15) + .add("idx7", (idx + 7) % 15) + .add("idx8", (idx + 8) % 15) + .add("src_offset", FW / 2 * 2) + .add("dst_offset", fh * FW * 24 + FW / 2 * 24) + .add("row0_idx", fh * 2) + .render(R"( + d[${idx0}] = vld2q_s32(src_base + IW * ${row0_idx} + ${src_offset}); + d[${idx1}] = vld2q_s32(src_base + IW * ${row0_idx} + ${src_offset} + 8); + d[${idx2}] = vld2q_s32(src_base + IW * ${row0_idx} + ${src_offset} + 15); + src_base += (IH * IW); + d[${idx3}].val[0] = vdupq_n_s32(0); + d[${idx6}] = vzipq_s32(d[${idx0}].val[0], d[${idx3}].val[0]); + d[${idx7}] = vzipq_s32(d[${idx1}].val[0], d[${idx3}].val[0]); + d[${idx8}] = vzipq_s32(d[${idx2}].val[1], d[${idx3}].val[0]); + vst1q_s32(dst_ptr + ${dst_offset}, d[${idx6}].val[0]); + vst1q_s32(dst_ptr + ${dst_offset} + 4, d[${idx6}].val[1]); + vst1q_s32(dst_ptr + ${dst_offset} + 8, d[${idx7}].val[0]); + vst1q_s32(dst_ptr + ${dst_offset} + 12, d[${idx7}].val[1]); + vst1q_s32(dst_ptr + ${dst_offset} + 16, d[${idx8}].val[0]); + vst1q_s32(dst_ptr + ${dst_offset} + 20, d[${idx8}].val[1]); + + dst_ptr += (${dst_offset} + 24); + } + } else { + const int32_t* src_base_next = (const int32_t*)(src + (oh + 1) * 2 * IW * 4); + const int part0 = OW - ow, part1 = 12 - part0; + int ic = 0; + for (; ic + 1 < IC; ic += 2) { + )"); + idx = (idx + 9) % 15; + int buffer_idx = 0; + for (int fh = 0; fh < FH / 2; ++fh) { + for (int fw = 0; fw < FW / 2; ++fw) { + res += std::string("memcpy(dst_ptr + ") + + std::to_string(fh * FW * 24 + fw * 24) + + std::string(", src_base + IW * ") + std::to_string(fh * 2) + + std::string(" + ") + std::to_string(fw * 2) + + std::string(", part0 * 2 * 4);"); + res += std::string("memcpy(dst_ptr + part0 * 2 + ") + + std::to_string(fh * FW * 24 + fw * 24) + + std::string(", src_base_next + IW * ") + std::to_string(fh * 2) + + std::string(" + ") + std::to_string(fw * 2) + + std::string(", part1 * 2 * 4);"); + } + res += StringTemplate::StringTemplateArgs() + .add("idx0", idx % 15) + .add("idx1", (idx + 1) % 15) + .add("idx2", (idx + 2) % 15) + .add("idx3", (idx + 3) % 15) + .add("idx4", (idx + 4) % 15) + .add("idx5", (idx + 5) % 15) + .add("idx6", (idx + 6) % 15) + .add("idx7", (idx + 7) % 15) + .add("idx8", (idx + 8) % 15) + .add("buffer_idx0", buffer_idx % 6) + .add("buffer_idx1", (buffer_idx + 1) % 6) + .add("src_offset", FW / 2 * 2) + .add("dst_offset", fh * FW * 24 + FW / 2 * 24) + .add("row0_idx", fh * 2) + .add("row1_idx", fh * 2 + 1) + .render(R"( + memcpy(buffer[${buffer_idx0}], src_base + IW * ${row0_idx} + ${src_offset}, part0 * 2 * 4); + memcpy(buffer[${buffer_idx0}] + part0 * 2, src_base_next + IW * ${row0_idx} + ${src_offset}, part1 * 2 * 4); + memcpy(buffer[${buffer_idx1}], src_base + IW * ${row1_idx}, part0 * 2 * 4); + memcpy(buffer[${buffer_idx1}] + part0 * 2, src_base_next + IW * ${row1_idx}, part1 * 2 * 4); + d[${idx0}] = vld2q_s32(buffer[${buffer_idx0}]); + d[${idx1}] = vld2q_s32(buffer[${buffer_idx0}] + 8); + d[${idx2}] = vld2q_s32(buffer[${buffer_idx0}] + 16); + d[${idx3}] = vld2q_s32(buffer[${buffer_idx1}]); + d[${idx4}] = vld2q_s32(buffer[${buffer_idx1}] + 8); + d[${idx5}] = vld2q_s32(buffer[${buffer_idx1}] + 16); + d[${idx6}] = vzipq_s32(d[${idx0}].val[0], d[${idx3}].val[0]); + d[${idx7}] = vzipq_s32(d[${idx1}].val[0], d[${idx4}].val[0]); + d[${idx8}] = vzipq_s32(d[${idx2}].val[0], d[${idx5}].val[0]); + vst1q_s32(dst_ptr + ${dst_offset}, d[${idx6}].val[0]); + vst1q_s32(dst_ptr + ${dst_offset} + 4, d[${idx6}].val[1]); + vst1q_s32(dst_ptr + ${dst_offset} + 8, d[${idx7}].val[0]); + vst1q_s32(dst_ptr + ${dst_offset} + 12, d[${idx7}].val[1]); + vst1q_s32(dst_ptr + ${dst_offset} + 16, d[${idx8}].val[0]); + vst1q_s32(dst_ptr + ${dst_offset} + 20, d[${idx8}].val[1]); + )"); + idx = (idx + 9) % 15; + buffer_idx = (buffer_idx + 2) % 6; + for (int fw = 0; fw < FW / 2; ++fw) { + res += std::string("memcpy(dst_ptr + ") + + std::to_string(fh * FW * 24 + (FW + 1) / 2 * 24 + fw * 24) + + std::string(", src_base + IW * ") + std::to_string(fh * 2 + 1) + + std::string(" + ") + std::to_string(fw * 2 + 1) + + std::string(", part0 * 2 * 4);"); + res += std::string("memcpy(dst_ptr + part0 * 2 + ") + + std::to_string(fh * FW * 24 + (FW + 1) / 2 * 24 + fw * 24) + + std::string(", src_base_next + IW * ") + + std::to_string(fh * 2 + 1) + std::string(" + ") + + std::to_string(fw * 2 + 1) + std::string(", part1 * 2 * 4);"); + } + } + fh = FH / 2; + for (int fw = 0; fw < FW / 2; ++fw) { + res += std::string("memcpy(dst_ptr + ") + + std::to_string(fh * FW * 24 + fw * 24) + + std::string(", src_base + IW * ") + std::to_string(fh * 2) + + std::string(" + ") + std::to_string(fw * 2) + + std::string(", part0 * 2 * 4);"); + res += std::string("memcpy(dst_ptr + part0 * 2 + ") + + std::to_string(fh * FW * 24 + fw * 24) + + std::string(", src_base_next + IW * ") + std::to_string(fh * 2) + + std::string(" + ") + std::to_string(fw * 2) + + std::string(", part1 * 2 * 4);"); + } + res += StringTemplate::StringTemplateArgs() + .add("idx0", idx % 15) + .add("idx1", (idx + 1) % 15) + .add("idx2", (idx + 2) % 15) + .add("idx3", (idx + 3) % 15) + .add("idx4", (idx + 4) % 15) + .add("idx5", (idx + 5) % 15) + .add("idx6", (idx + 6) % 15) + .add("idx7", (idx + 7) % 15) + .add("idx8", (idx + 8) % 15) + .add("buffer_idx0", buffer_idx % 6) + .add("buffer_idx1", (buffer_idx + 1) % 6) + .add("src_offset", FW / 2 * 2) + .add("dst_offset", fh * FW * 24 + FW / 2 * 24) + .add("row0_idx", fh * 2) + .render(R"( + memcpy(buffer[${buffer_idx0}], src_base + IW * ${row0_idx} + ${src_offset}, part0 * 2 * 4); + memcpy(buffer[${buffer_idx0}] + part0 * 2, src_base_next + IW * ${row0_idx} + ${src_offset}, part1 * 2 * 4); + src_base += (IH * IW); + src_base_next += (IH * IW); + memcpy(buffer[${buffer_idx1}], src_base, part0 * 2 * 4); + memcpy(buffer[${buffer_idx1}] + part0 * 2, src_base_next, part1 * 2 * 4); + d[${idx0}] = vld2q_s32(buffer[${buffer_idx0}]); + d[${idx1}] = vld2q_s32(buffer[${buffer_idx0}] + 8); + d[${idx2}] = vld2q_s32(buffer[${buffer_idx0}] + 16); + d[${idx3}] = vld2q_s32(buffer[${buffer_idx1}]); + d[${idx4}] = vld2q_s32(buffer[${buffer_idx1}] + 8); + d[${idx5}] = vld2q_s32(buffer[${buffer_idx1}] + 16); + d[${idx6}] = vzipq_s32(d[${idx0}].val[0], d[${idx3}].val[0]); + d[${idx7}] = vzipq_s32(d[${idx1}].val[0], d[${idx4}].val[0]); + d[${idx8}] = vzipq_s32(d[${idx2}].val[0], d[${idx5}].val[0]); + vst1q_s32(dst_ptr + ${dst_offset}, d[${idx6}].val[0]); + vst1q_s32(dst_ptr + ${dst_offset} + 4, d[${idx6}].val[1]); + vst1q_s32(dst_ptr + ${dst_offset} + 8, d[${idx7}].val[0]); + vst1q_s32(dst_ptr + ${dst_offset} + 12, d[${idx7}].val[1]); + vst1q_s32(dst_ptr + ${dst_offset} + 16, d[${idx8}].val[0]); + vst1q_s32(dst_ptr + ${dst_offset} + 20, d[${idx8}].val[1]); + )"); + idx = (idx + 9) % 15; + buffer_idx = (buffer_idx + 2) % 6; + for (int fw = 0; fw < FW / 2; ++fw) { + res += std::string("memcpy(dst_ptr + ") + + std::to_string(fh * FW * 24 + (FW + 1) / 2 * 24 + fw * 24) + + std::string(", src_base + ") + std::to_string(fw * 2 + 1) + + std::string(", part0 * 2 * 4);"); + res += std::string("memcpy(dst_ptr + part0 * 2 + ") + + std::to_string(fh * FW * 24 + (FW + 1) / 2 * 24 + fw * 24) + + std::string(", src_base_next + ") + std::to_string(fw * 2 + 1) + + std::string(", part1 * 2 * 4);"); + } + for (int fh = 0; fh < FH / 2; ++fh) { + for (int fw = 0; fw < FW / 2; ++fw) { + res += std::string("memcpy(dst_ptr + ") + + std::to_string((FH + 1) / 2 * FW * 24 + fh * FW * 24 + fw * 24) + + std::string(", src_base + IW * ") + std::to_string(fh * 2 + 1) + + std::string(" + ") + std::to_string(fw * 2) + + std::string(", part0 * 2 * 4);"); + res += std::string("memcpy(dst_ptr + part0 * 2 + ") + + std::to_string((FH + 1) / 2 * FW * 24 + fh * FW * 24 + fw * 24) + + std::string(", src_base_next + IW * ") + + std::to_string(fh * 2 + 1) + std::string(" + ") + + std::to_string(fw * 2) + std::string(", part1 * 2 * 4);"); + } + res += StringTemplate::StringTemplateArgs() + .add("idx0", idx % 15) + .add("idx1", (idx + 1) % 15) + .add("idx2", (idx + 2) % 15) + .add("idx3", (idx + 3) % 15) + .add("idx4", (idx + 4) % 15) + .add("idx5", (idx + 5) % 15) + .add("idx6", (idx + 6) % 15) + .add("idx7", (idx + 7) % 15) + .add("idx8", (idx + 8) % 15) + .add("buffer_idx0", buffer_idx % 6) + .add("buffer_idx1", (buffer_idx + 1) % 6) + .add("src_offset", FW / 2 * 2) + .add("dst_offset", + (FH + 1) / 2 * FW * 24 + fh * FW * 24 + FW / 2 * 24) + .add("row0_idx", fh * 2 + 1) + .add("row1_idx", fh * 2 + 2) + .render(R"( + memcpy(buffer[${buffer_idx0}], src_base + IW * ${row0_idx} + ${src_offset}, part0 * 2 * 4); + memcpy(buffer[${buffer_idx0}] + part0 * 2, src_base_next + IW * ${row0_idx} + ${src_offset}, part1 * 2 * 4); + memcpy(buffer[${buffer_idx1}], src_base + IW * ${row1_idx}, part0 * 2 * 4); + memcpy(buffer[${buffer_idx1}] + part0 * 2, src_base_next + IW * ${row1_idx}, part1 * 2 * 4); + d[${idx0}] = vld2q_s32(buffer[${buffer_idx0}]); + d[${idx1}] = vld2q_s32(buffer[${buffer_idx0}] + 8); + d[${idx2}] = vld2q_s32(buffer[${buffer_idx0}] + 16); + d[${idx3}] = vld2q_s32(buffer[${buffer_idx1}]); + d[${idx4}] = vld2q_s32(buffer[${buffer_idx1}] + 8); + d[${idx5}] = vld2q_s32(buffer[${buffer_idx1}] + 16); + d[${idx6}] = vzipq_s32(d[${idx0}].val[0], d[${idx3}].val[0]); + d[${idx7}] = vzipq_s32(d[${idx1}].val[0], d[${idx4}].val[0]); + d[${idx8}] = vzipq_s32(d[${idx2}].val[0], d[${idx5}].val[0]); + vst1q_s32(dst_ptr + ${dst_offset}, d[${idx6}].val[0]); + vst1q_s32(dst_ptr + ${dst_offset} + 4, d[${idx6}].val[1]); + vst1q_s32(dst_ptr + ${dst_offset} + 8, d[${idx7}].val[0]); + vst1q_s32(dst_ptr + ${dst_offset} + 12, d[${idx7}].val[1]); + vst1q_s32(dst_ptr + ${dst_offset} + 16, d[${idx8}].val[0]); + vst1q_s32(dst_ptr + ${dst_offset} + 20, d[${idx8}].val[1]); + )"); + idx = (idx + 9) % 15; + buffer_idx = (buffer_idx + 2) % 6; + for (int fw = 0; fw < FW / 2; ++fw) { + res += std::string("memcpy(dst_ptr + ") + + std::to_string( + (FH + 1) / 2 * FW * 24 + fh * FW * 24 + + (FW + 1) / 2 * 24 + fw * 24) + + std::string(", src_base + IW * ") + std::to_string(fh * 2 + 2) + + std::string(" + ") + std::to_string(fw * 2 + 1) + + std::string(", part0 * 2 * 4);"); + res += std::string("memcpy(dst_ptr + part0 * 2 + ") + + std::to_string( + (FH + 1) / 2 * FW * 24 + fh * FW * 24 + + (FW + 1) / 2 * 24 + fw * 24) + + std::string(", src_base_next + IW * ") + + std::to_string(fh * 2 + 2) + std::string(" + ") + + std::to_string(fw * 2 + 1) + std::string(", part1 * 2 * 4);"); + } + } + res += StringTemplate::StringTemplateArgs() + .add("offset", FH * FW * 24) + .render(R"( + dst_ptr += ${offset}; + src_base += (IH * IW); + src_base_next += (IH * IW); + } + if (ic < IC) { + )"); + for (int fh = 0; fh < FH / 2; ++fh) { + for (int fw = 0; fw < FW / 2; ++fw) { + res += std::string("memcpy(dst_ptr + ") + + std::to_string(fh * FW * 24 + fw * 24) + + std::string(", src_base + IW * ") + std::to_string(fh * 2) + + std::string(" + ") + std::to_string(fw * 2) + + std::string(", part0 * 2 * 4);"); + res += std::string("memcpy(dst_ptr + part0 * 2 + ") + + std::to_string(fh * FW * 24 + fw * 24) + + std::string(", src_base_next + IW * ") + std::to_string(fh * 2) + + std::string(" + ") + std::to_string(fw * 2) + + std::string(", part1 * 2 * 4);"); + } + res += StringTemplate::StringTemplateArgs() + .add("idx0", idx % 15) + .add("idx1", (idx + 1) % 15) + .add("idx2", (idx + 2) % 15) + .add("idx3", (idx + 3) % 15) + .add("idx4", (idx + 4) % 15) + .add("idx5", (idx + 5) % 15) + .add("idx6", (idx + 6) % 15) + .add("idx7", (idx + 7) % 15) + .add("idx8", (idx + 8) % 15) + .add("buffer_idx0", buffer_idx % 6) + .add("buffer_idx1", (buffer_idx + 1) % 6) + .add("src_offset", FW / 2 * 2) + .add("dst_offset", fh * FW * 24 + FW / 2 * 24) + .add("row0_idx", fh * 2) + .add("row1_idx", fh * 2 + 1) + .render(R"( + memcpy(buffer[${buffer_idx0}], src_base + IW * ${row0_idx} + ${src_offset}, part0 * 2 * 4); + memcpy(buffer[${buffer_idx0}] + part0 * 2, src_base_next + IW * ${row0_idx} + ${src_offset}, part1 * 2 * 4); + memcpy(buffer[${buffer_idx1}], src_base + IW * ${row1_idx}, part0 * 2 * 4); + memcpy(buffer[${buffer_idx1}] + part0 * 2, src_base_next + IW * ${row1_idx}, part1 * 2 * 4); + d[${idx0}] = vld2q_s32(buffer[${buffer_idx0}]); + d[${idx1}] = vld2q_s32(buffer[${buffer_idx0}] + 8); + d[${idx2}] = vld2q_s32(buffer[${buffer_idx0}] + 16); + d[${idx3}] = vld2q_s32(buffer[${buffer_idx1}]); + d[${idx4}] = vld2q_s32(buffer[${buffer_idx1}] + 8); + d[${idx5}] = vld2q_s32(buffer[${buffer_idx1}] + 16); + d[${idx6}] = vzipq_s32(d[${idx0}].val[0], d[${idx3}].val[0]); + d[${idx7}] = vzipq_s32(d[${idx1}].val[0], d[${idx4}].val[0]); + d[${idx8}] = vzipq_s32(d[${idx2}].val[0], d[${idx5}].val[0]); + vst1q_s32(dst_ptr + ${dst_offset}, d[${idx6}].val[0]); + vst1q_s32(dst_ptr + ${dst_offset} + 4, d[${idx6}].val[1]); + vst1q_s32(dst_ptr + ${dst_offset} + 8, d[${idx7}].val[0]); + vst1q_s32(dst_ptr + ${dst_offset} + 12, d[${idx7}].val[1]); + vst1q_s32(dst_ptr + ${dst_offset} + 16, d[${idx8}].val[0]); + vst1q_s32(dst_ptr + ${dst_offset} + 20, d[${idx8}].val[1]); + )"); + idx = (idx + 9) % 15; + buffer_idx = (buffer_idx + 2) % 6; + for (int fw = 0; fw < FW / 2; ++fw) { + res += std::string("memcpy(dst_ptr + ") + + std::to_string(fh * FW * 24 + (FW + 1) / 2 * 24 + fw * 24) + + std::string(", src_base + IW * ") + std::to_string(fh * 2 + 1) + + std::string(" + ") + std::to_string(fw * 2 + 1) + + std::string(", part0 * 2 * 4);"); + res += std::string("memcpy(dst_ptr + part0 * 2 + ") + + std::to_string(fh * FW * 24 + (FW + 1) / 2 * 24 + fw * 24) + + std::string(", src_base_next + IW * ") + + std::to_string(fh * 2 + 1) + std::string(" + ") + + std::to_string(fw * 2 + 1) + std::string(", part1 * 2 * 4);"); + } + } + fh = FH / 2; + for (int fw = 0; fw < FW / 2; ++fw) { + res += std::string("memcpy(dst_ptr + ") + + std::to_string(fh * FW * 24 + fw * 24) + + std::string(", src_base + IW * ") + std::to_string(fh * 2) + + std::string(" + ") + std::to_string(fw * 2) + + std::string(", part0 * 2 * 4);"); + res += std::string("memcpy(dst_ptr + part0 * 2 + ") + + std::to_string(fh * FW * 24 + fw * 24) + + std::string(", src_base_next + IW * ") + std::to_string(fh * 2) + + std::string(" + ") + std::to_string(fw * 2) + + std::string(", part1 * 2 * 4);"); + } + res += StringTemplate::StringTemplateArgs() + .add("idx0", idx % 15) + .add("idx1", (idx + 1) % 15) + .add("idx2", (idx + 2) % 15) + .add("idx3", (idx + 3) % 15) + .add("idx4", (idx + 4) % 15) + .add("idx5", (idx + 5) % 15) + .add("idx6", (idx + 6) % 15) + .add("idx7", (idx + 7) % 15) + .add("idx8", (idx + 8) % 15) + .add("buffer_idx0", buffer_idx % 6) + .add("src_offset", FW / 2 * 2) + .add("dst_offset", fh * FW * 24 + FW / 2 * 24) + .add("row0_idx", fh * 2) + .render(R"( + memcpy(buffer[${buffer_idx0}], src_base + IW * ${row0_idx} + ${src_offset}, part0 * 2 * 4); + memcpy(buffer[${buffer_idx0}] + part0 * 2, src_base_next + IW * ${row0_idx} + ${src_offset}, part1 * 2 * 4); + src_base += (IH * IW); + src_base_next += (IH * IW); + d[${idx0}] = vld2q_s32(buffer[${buffer_idx0}]); + d[${idx1}] = vld2q_s32(buffer[${buffer_idx0}] + 8); + d[${idx2}] = vld2q_s32(buffer[${buffer_idx0}] + 16); + d[${idx3}].val[0] = vdupq_n_s32(0); + d[${idx6}] = vzipq_s32(d[${idx0}].val[0], d[${idx3}].val[0]); + d[${idx7}] = vzipq_s32(d[${idx1}].val[0], d[${idx3}].val[0]); + d[${idx8}] = vzipq_s32(d[${idx2}].val[0], d[${idx3}].val[0]); + vst1q_s32(dst_ptr + ${dst_offset}, d[${idx6}].val[0]); + vst1q_s32(dst_ptr + ${dst_offset} + 4, d[${idx6}].val[1]); + vst1q_s32(dst_ptr + ${dst_offset} + 8, d[${idx7}].val[0]); + vst1q_s32(dst_ptr + ${dst_offset} + 12, d[${idx7}].val[1]); + vst1q_s32(dst_ptr + ${dst_offset} + 16, d[${idx8}].val[0]); + vst1q_s32(dst_ptr + ${dst_offset} + 20, d[${idx8}].val[1]); + + dst_ptr += (${dst_offset} + 24); + } + } + } + for (; n + 7 < N; n += 8) { + const int oh = n / OW, ow = n % OW; + const int32_t *src_base = + (const int32_t *)(src + oh * 2 * IW * 4 + ow * 2 * 4); + TINYNN_ASSERT(OW - ow >= 8); + int ic = 0; + for (; ic + 1 < IC; ic += 2) { + )"); + idx = (idx + 9) % 15; + buffer_idx = (buffer_idx + 2) % 6; + for (int fh = 0; fh < FH / 2; ++fh) { + for (int fw = 0; fw < FW / 2; ++fw) { + res += std::string("memcpy(dst_ptr + ") + + std::to_string(fh * FW * 16 + fw * 16) + + std::string(", src_base + IW * ") + std::to_string(fh * 2) + + std::string(" + ") + std::to_string(fw * 2) + + std::string(", 64);"); + } + res += StringTemplate::StringTemplateArgs() + .add("idx0", idx % 15) + .add("idx1", (idx + 1) % 15) + .add("idx3", (idx + 3) % 15) + .add("idx4", (idx + 4) % 15) + .add("idx6", (idx + 6) % 15) + .add("idx7", (idx + 7) % 15) + .add("src_offset", FW / 2 * 2) + .add("dst_offset", fh * FW * 16 + FW / 2 * 16) + .add("row0_idx", fh * 2) + .add("row1_idx", fh * 2 + 1) + .render(R"( + d[${idx0}] = vld2q_s32(src_base + IW * ${row0_idx} + ${src_offset}); + d[${idx1}] = vld2q_s32(src_base + IW * ${row0_idx} + ${src_offset} + 8); + d[${idx3}] = vld2q_s32(src_base + IW * ${row1_idx}); + d[${idx4}] = vld2q_s32(src_base + IW * ${row1_idx} + 8); + d[${idx6}] = vzipq_s32(d[${idx0}].val[0], d[${idx3}].val[0]); + d[${idx7}] = vzipq_s32(d[${idx1}].val[0], d[${idx4}].val[0]); + vst1q_s32(dst_ptr + ${dst_offset}, d[${idx6}].val[0]); + vst1q_s32(dst_ptr + ${dst_offset} + 4, d[${idx6}].val[1]); + vst1q_s32(dst_ptr + ${dst_offset} + 8, d[${idx7}].val[0]); + vst1q_s32(dst_ptr + ${dst_offset} + 12, d[${idx7}].val[1]); + )"); + idx = (idx + 9) % 15; + for (int fw = 0; fw < FW / 2; ++fw) { + res += std::string("memcpy(dst_ptr + ") + + std::to_string(fh * FW * 16 + (FW + 1) / 2 * 16 + fw * 16) + + std::string(", src_base + IW * ") + std::to_string(fh * 2 + 1) + + std::string(" + ") + std::to_string(fw * 2 + 1) + + std::string(", 64);"); + } + } + fh = FH / 2; + for (int fw = 0; fw < FW / 2; ++fw) { + res += std::string("memcpy(dst_ptr + ") + + std::to_string(fh * FW * 16 + fw * 16) + + std::string(", src_base + IW * ") + std::to_string(fh * 2) + + std::string(" + ") + std::to_string(fw * 2) + std::string(", 64);"); + } + res += StringTemplate::StringTemplateArgs() + .add("idx0", idx % 15) + .add("idx1", (idx + 1) % 15) + .add("idx3", (idx + 3) % 15) + .add("idx4", (idx + 4) % 15) + .add("idx6", (idx + 6) % 15) + .add("idx7", (idx + 7) % 15) + .add("src_offset", FW / 2 * 2) + .add("dst_offset", fh * FW * 16 + FW / 2 * 16) + .add("row0_idx", fh * 2) + .render(R"( + d[${idx0}] = vld2q_s32(src_base + IW * ${row0_idx} + ${src_offset}); + d[${idx1}] = vld2q_s32(src_base + IW * ${row0_idx} + ${src_offset} + 8); + src_base += (IH * IW); + d[${idx3}] = vld2q_s32(src_base); + d[${idx4}] = vld2q_s32(src_base + 8); + d[${idx6}] = vzipq_s32(d[${idx0}].val[0], d[${idx3}].val[0]); + d[${idx7}] = vzipq_s32(d[${idx1}].val[0], d[${idx4}].val[0]); + vst1q_s32(dst_ptr + ${dst_offset}, d[${idx6}].val[0]); + vst1q_s32(dst_ptr + ${dst_offset} + 4, d[${idx6}].val[1]); + vst1q_s32(dst_ptr + ${dst_offset} + 8, d[${idx7}].val[0]); + vst1q_s32(dst_ptr + ${dst_offset} + 12, d[${idx7}].val[1]); + )"); + idx = (idx + 9) % 15; + for (int fw = 0; fw < FW / 2; ++fw) { + res += std::string("memcpy(dst_ptr + ") + + std::to_string(fh * FW * 16 + (FW + 1) / 2 * 16 + fw * 16) + + std::string(", src_base + ") + std::to_string(fw * 2 + 1) + + std::string(", 64);"); + } + for (int fh = 0; fh < FH / 2; ++fh) { + for (int fw = 0; fw < FW / 2; ++fw) { + res += std::string("memcpy(dst_ptr + ") + + std::to_string((FH + 1) / 2 * FW * 16 + fh * FW * 16 + fw * 16) + + std::string(", src_base + IW * ") + std::to_string(fh * 2 + 1) + + std::string(" + ") + std::to_string(fw * 2) + + std::string(", 64);"); + } + res += StringTemplate::StringTemplateArgs() + .add("idx0", idx % 15) + .add("idx1", (idx + 1) % 15) + .add("idx3", (idx + 3) % 15) + .add("idx4", (idx + 4) % 15) + .add("idx6", (idx + 6) % 15) + .add("idx7", (idx + 7) % 15) + .add("src_offset", FW / 2 * 2) + .add("dst_offset", + (FH + 1) / 2 * FW * 16 + fh * FW * 16 + FW / 2 * 16) + .add("row0_idx", fh * 2 + 1) + .add("row1_idx", fh * 2 + 2) + .render(R"( + d[${idx0}] = vld2q_s32(src_base + IW * ${row0_idx} + ${src_offset}); + d[${idx1}] = vld2q_s32(src_base + IW * ${row0_idx} + ${src_offset} + 8); + d[${idx3}] = vld2q_s32(src_base + IW * ${row1_idx}); + d[${idx4}] = vld2q_s32(src_base + IW * ${row1_idx} + 8); + d[${idx6}] = vzipq_s32(d[${idx0}].val[0], d[${idx3}].val[0]); + d[${idx7}] = vzipq_s32(d[${idx1}].val[0], d[${idx4}].val[0]); + vst1q_s32(dst_ptr + ${dst_offset}, d[${idx6}].val[0]); + vst1q_s32(dst_ptr + ${dst_offset} + 4, d[${idx6}].val[1]); + vst1q_s32(dst_ptr + ${dst_offset} + 8, d[${idx7}].val[0]); + vst1q_s32(dst_ptr + ${dst_offset} + 12, d[${idx7}].val[1]); + )"); + idx = (idx + 9) % 15; + for (int fw = 0; fw < FW / 2; ++fw) { + res += std::string("memcpy(dst_ptr + ") + + std::to_string( + (FH + 1) / 2 * FW * 16 + fh * FW * 16 + + (FW + 1) / 2 * 16 + fw * 16) + + std::string(", src_base + IW * ") + std::to_string(fh * 2 + 2) + + std::string(" + ") + std::to_string(fw * 2 + 1) + + std::string(", 64);"); + } + } + res += StringTemplate::StringTemplateArgs() + .add("offset", FH * FW * 16) + .render(R"( + dst_ptr += ${offset}; + src_base += (IH * IW); + } + if (ic < IC) { + )"); + for (int fh = 0; fh < FH / 2; ++fh) { + for (int fw = 0; fw < FW / 2; ++fw) { + res += std::string("memcpy(dst_ptr + ") + + std::to_string(fh * FW * 16 + fw * 16) + + std::string(", src_base + IW * ") + std::to_string(fh * 2) + + std::string(" + ") + std::to_string(fw * 2) + + std::string(", 64);"); + } + res += StringTemplate::StringTemplateArgs() + .add("idx0", idx % 15) + .add("idx1", (idx + 1) % 15) + .add("idx3", (idx + 3) % 15) + .add("idx4", (idx + 4) % 15) + .add("idx6", (idx + 6) % 15) + .add("idx7", (idx + 7) % 15) + .add("src_offset", FW / 2 * 2) + .add("dst_offset", fh * FW * 16 + FW / 2 * 16) + .add("row0_idx", fh * 2) + .add("row1_idx", fh * 2 + 1) + .render(R"( + d[${idx0}] = vld2q_s32(src_base + IW * ${row0_idx} + ${src_offset}); + d[${idx1}] = vld2q_s32(src_base + IW * ${row0_idx} + ${src_offset} + 8); + d[${idx3}] = vld2q_s32(src_base + IW * ${row1_idx}); + d[${idx4}] = vld2q_s32(src_base + IW * ${row1_idx} + 8); + d[${idx6}] = vzipq_s32(d[${idx0}].val[0], d[${idx3}].val[0]); + d[${idx7}] = vzipq_s32(d[${idx1}].val[0], d[${idx4}].val[0]); + vst1q_s32(dst_ptr + ${dst_offset}, d[${idx6}].val[0]); + vst1q_s32(dst_ptr + ${dst_offset} + 4, d[${idx6}].val[1]); + vst1q_s32(dst_ptr + ${dst_offset} + 8, d[${idx7}].val[0]); + vst1q_s32(dst_ptr + ${dst_offset} + 12, d[${idx7}].val[1]); + )"); + idx = (idx + 9) % 15; + for (int fw = 0; fw < FW / 2; ++fw) { + res += std::string("memcpy(dst_ptr + ") + + std::to_string(fh * FW * 16 + (FW + 1) / 2 * 16 + fw * 16) + + std::string(", src_base + IW * ") + std::to_string(fh * 2 + 1) + + std::string(" + ") + std::to_string(fw * 2 + 1) + + std::string(", 64);"); + } + } + fh = FH / 2; + for (int fw = 0; fw < FW / 2; ++fw) { + res += std::string("memcpy(dst_ptr + ") + + std::to_string(fh * FW * 16 + fw * 16) + + std::string(", src_base + IW * ") + std::to_string(fh * 2) + + std::string(" + ") + std::to_string(fw * 2) + std::string(", 64);"); + } + res += StringTemplate::StringTemplateArgs() + .add("idx0", idx % 15) + .add("idx1", (idx + 1) % 15) + .add("idx3", (idx + 3) % 15) + .add("idx6", (idx + 6) % 15) + .add("idx7", (idx + 7) % 15) + .add("src_offset", FW / 2 * 2) + .add("dst_offset", fh * FW * 16 + FW / 2 * 16) + .add("row0_idx", fh * 2) + .render(R"( + d[${idx0}] = vld2q_s32(src_base + IW * ${row0_idx} + ${src_offset}); + d[${idx1}] = vld2q_s32(src_base + IW * ${row0_idx} + ${src_offset} + 7); + src_base += (IH * IW); + d[${idx3}].val[0] = vdupq_n_s32(0); + d[${idx6}] = vzipq_s32(d[${idx0}].val[0], d[${idx3}].val[0]); + d[${idx7}] = vzipq_s32(d[${idx1}].val[1], d[${idx3}].val[0]); + vst1q_s32(dst_ptr + ${dst_offset}, d[${idx6}].val[0]); + vst1q_s32(dst_ptr + ${dst_offset} + 4, d[${idx6}].val[1]); + vst1q_s32(dst_ptr + ${dst_offset} + 8, d[${idx7}].val[0]); + vst1q_s32(dst_ptr + ${dst_offset} + 12, d[${idx7}].val[1]); + + dst_ptr += (${dst_offset} + 16); + } + } + for (; n < N; n += 4) { + const int oh = n / OW, ow = n % OW; + const int32_t* src_base = (const int32_t*)(src + oh * 2 * IW * 4 + ow * 2 * 4); + TINYNN_ASSERT(oh + 1 == OH); + if (OW - ow >= 4) { + int ic = 0; + for (; ic + 1 < IC; ic += 2) { + )"); + idx = (idx + 9) % 15; + for (int fh = 0; fh < FH / 2; ++fh) { + for (int fw = 0; fw < FW / 2; ++fw) { + res += std::string("memcpy(dst_ptr + ") + + std::to_string(fh * FW * 8 + fw * 8) + + std::string(", src_base + IW * ") + std::to_string(fh * 2) + + std::string(" + ") + std::to_string(fw * 2) + + std::string(", 32);"); + } + res += StringTemplate::StringTemplateArgs() + .add("idx0", idx % 15) + .add("idx3", (idx + 3) % 15) + .add("idx6", (idx + 6) % 15) + .add("src_offset", FW / 2 * 2) + .add("dst_offset", fh * FW * 8 + FW / 2 * 8) + .add("row0_idx", fh * 2) + .add("row1_idx", fh * 2 + 1) + .render(R"( + d[${idx0}] = vld2q_s32(src_base + IW * ${row0_idx} + ${src_offset}); + d[${idx3}] = vld2q_s32(src_base + IW * ${row1_idx}); + d[${idx6}] = vzipq_s32(d[${idx0}].val[0], d[${idx3}].val[0]); + vst1q_s32(dst_ptr + ${dst_offset}, d[${idx6}].val[0]); + vst1q_s32(dst_ptr + ${dst_offset} + 4, d[${idx6}].val[1]); + )"); + idx = (idx + 9) % 15; + for (int fw = 0; fw < FW / 2; ++fw) { + res += std::string("memcpy(dst_ptr + ") + + std::to_string(fh * FW * 8 + (FW + 1) / 2 * 8 + fw * 8) + + std::string(", src_base + IW * ") + std::to_string(fh * 2 + 1) + + std::string(" + ") + std::to_string(fw * 2 + 1) + + std::string(", 32);"); + } + } + fh = FH / 2; + for (int fw = 0; fw < FW / 2; ++fw) { + res += std::string("memcpy(dst_ptr + ") + + std::to_string(fh * FW * 8 + fw * 8) + + std::string(", src_base + IW * ") + std::to_string(fh * 2) + + std::string(" + ") + std::to_string(fw * 2) + std::string(", 32);"); + } + res += StringTemplate::StringTemplateArgs() + .add("idx0", idx % 15) + .add("idx3", (idx + 3) % 15) + .add("idx6", (idx + 6) % 15) + .add("src_offset", FW / 2 * 2) + .add("dst_offset", fh * FW * 8 + FW / 2 * 8) + .add("row0_idx", fh * 2) + .render(R"( + d[${idx0}] = vld2q_s32(src_base + IW * ${row0_idx} + ${src_offset}); + src_base += (IH * IW); + d[${idx3}] = vld2q_s32(src_base); + d[${idx6}] = vzipq_s32(d[${idx0}].val[0], d[${idx3}].val[0]); + vst1q_s32(dst_ptr + ${dst_offset}, d[${idx6}].val[0]); + vst1q_s32(dst_ptr + ${dst_offset} + 4, d[${idx6}].val[1]); + )"); + idx = (idx + 9) % 15; + for (int fw = 0; fw < FW / 2; ++fw) { + res += std::string("memcpy(dst_ptr + ") + + std::to_string(fh * FW * 8 + (FW + 1) / 2 * 8 + fw * 8) + + std::string(", src_base + ") + std::to_string(fw * 2 + 1) + + std::string(", 32);"); + } + for (int fh = 0; fh < FH / 2; ++fh) { + for (int fw = 0; fw < FW / 2; ++fw) { + res += std::string("memcpy(dst_ptr + ") + + std::to_string((FH + 1) / 2 * FW * 8 + fh * FW * 8 + fw * 8) + + std::string(", src_base + IW * ") + std::to_string(fh * 2 + 1) + + std::string(" + ") + std::to_string(fw * 2) + + std::string(", 32);"); + } + res += StringTemplate::StringTemplateArgs() + .add("idx0", idx % 15) + .add("idx3", (idx + 3) % 15) + .add("idx6", (idx + 6) % 15) + .add("src_offset", FW / 2 * 2) + .add("dst_offset", + (FH + 1) / 2 * FW * 8 + fh * FW * 8 + FW / 2 * 8) + .add("row0_idx", fh * 2 + 1) + .add("row1_idx", fh * 2 + 2) + .render(R"( + d[${idx0}] = vld2q_s32(src_base + IW * ${row0_idx} + ${src_offset}); + d[${idx3}] = vld2q_s32(src_base + IW * ${row1_idx}); + d[${idx6}] = vzipq_s32(d[${idx0}].val[0], d[${idx3}].val[0]); + vst1q_s32(dst_ptr + ${dst_offset}, d[${idx6}].val[0]); + vst1q_s32(dst_ptr + ${dst_offset} + 4, d[${idx6}].val[1]); + )"); + idx = (idx + 9) % 15; + for (int fw = 0; fw < FW / 2; ++fw) { + res += std::string("memcpy(dst_ptr + ") + + std::to_string( + (FH + 1) / 2 * FW * 8 + fh * FW * 8 + (FW + 1) / 2 * 8 + + fw * 8) + + std::string(", src_base + IW * ") + std::to_string(fh * 2 + 2) + + std::string(" + ") + std::to_string(fw * 2 + 1) + + std::string(", 32);"); + } + } + res += StringTemplate::StringTemplateArgs() + .add("offset", FH * FW * 8) + .render(R"( + dst_ptr += ${offset}; + src_base += (IH * IW); + } + if (ic < IC) { + )"); + for (int fh = 0; fh < FH / 2; ++fh) { + for (int fw = 0; fw < FW / 2; ++fw) { + res += std::string("memcpy(dst_ptr + ") + + std::to_string(fh * FW * 8 + fw * 8) + + std::string(", src_base + IW * ") + std::to_string(fh * 2) + + std::string(" + ") + std::to_string(fw * 2) + + std::string(", 32);"); + } + res += StringTemplate::StringTemplateArgs() + .add("idx0", idx % 15) + .add("idx3", (idx + 3) % 15) + .add("idx6", (idx + 6) % 15) + .add("src_offset", FW / 2 * 2) + .add("dst_offset", fh * FW * 8 + FW / 2 * 8) + .add("row0_idx", fh * 2) + .add("row1_idx", fh * 2 + 1) + .render(R"( + d[${idx0}] = vld2q_s32(src_base + IW * ${row0_idx} + ${src_offset}); + d[${idx3}] = vld2q_s32(src_base + IW * ${row1_idx}); + d[${idx6}] = vzipq_s32(d[${idx0}].val[0], d[${idx3}].val[0]); + vst1q_s32(dst_ptr + ${dst_offset}, d[${idx6}].val[0]); + vst1q_s32(dst_ptr + ${dst_offset} + 4, d[${idx6}].val[1]); + )"); + idx = (idx + 9) % 15; + for (int fw = 0; fw < FW / 2; ++fw) { + res += std::string("memcpy(dst_ptr + ") + + std::to_string(fh * FW * 8 + (FW + 1) / 2 * 8 + fw * 8) + + std::string(", src_base + IW * ") + std::to_string(fh * 2 + 1) + + std::string(" + ") + std::to_string(fw * 2 + 1) + + std::string(", 32);"); + } + } + fh = FH / 2; + for (int fw = 0; fw < FW / 2; ++fw) { + res += std::string("memcpy(dst_ptr + ") + + std::to_string(fh * FW * 8 + fw * 8) + + std::string(", src_base + IW * ") + std::to_string(fh * 2) + + std::string(" + ") + std::to_string(fw * 2) + std::string(", 32);"); + } + res += StringTemplate::StringTemplateArgs() + .add("idx0", idx % 15) + .add("idx3", (idx + 3) % 15) + .add("idx6", (idx + 6) % 15) + .add("src_offset", FW / 2 * 2) + .add("dst_offset", fh * FW * 8 + FW / 2 * 8) + .add("row0_idx", fh * 2) + .render(R"( + d[${idx0}] = vld2q_s32(src_base + IW * ${row0_idx} + ${src_offset}); + src_base += (IH * IW); + d[${idx3}].val[0] = vdupq_n_s32(0); + d[${idx6}] = vzipq_s32(d[${idx0}].val[0], d[${idx3}].val[0]); + vst1q_s32(dst_ptr + ${dst_offset}, d[${idx6}].val[0]); + vst1q_s32(dst_ptr + ${dst_offset} + 4, d[${idx6}].val[1]); + + dst_ptr += (${dst_offset} + 8); + } + } else { + const int part0 = OW - ow; + int ic = 0; + for (; ic + 1 < IC; ic += 2) { + )"); + idx = (idx + 9) % 15; + for (int fh = 0; fh < FH / 2; ++fh) { + for (int fw = 0; fw < FW / 2; ++fw) { + res += std::string("memcpy(dst_ptr + ") + + std::to_string(fh * FW * 8 + fw * 8) + + std::string(", src_base + IW * ") + std::to_string(fh * 2) + + std::string(" + ") + std::to_string(fw * 2) + + std::string(", part0 * 2 * 4);"); + } + res += StringTemplate::StringTemplateArgs() + .add("idx0", idx % 15) + .add("idx3", (idx + 3) % 15) + .add("idx6", (idx + 6) % 15) + .add("buffer_idx0", buffer_idx % 6) + .add("buffer_idx1", (buffer_idx + 1) % 6) + .add("src_offset", FW / 2 * 2) + .add("dst_offset", fh * FW * 8 + FW / 2 * 8) + .add("row0_idx", fh * 2) + .add("row1_idx", fh * 2 + 1) + .render(R"( + memcpy(buffer[${buffer_idx0}], src_base + IW * ${row0_idx} + ${src_offset}, part0 * 2 * 4); + memcpy(buffer[${buffer_idx1}], src_base + IW * ${row1_idx}, part0 * 2 * 4); + d[${idx0}] = vld2q_s32(buffer[${buffer_idx0}]); + d[${idx3}] = vld2q_s32(buffer[${buffer_idx1}]); + d[${idx6}] = vzipq_s32(d[${idx0}].val[0], d[${idx3}].val[0]); + vst1q_s32(dst_ptr + ${dst_offset}, d[${idx6}].val[0]); + vst1q_s32(dst_ptr + ${dst_offset} + 4, d[${idx6}].val[1]); + )"); + idx = (idx + 9) % 15; + buffer_idx = (buffer_idx + 2) % 6; + for (int fw = 0; fw < FW / 2; ++fw) { + res += std::string("memcpy(dst_ptr + ") + + std::to_string(fh * FW * 8 + (FW + 1) / 2 * 8 + fw * 8) + + std::string(", src_base + IW * ") + std::to_string(fh * 2 + 1) + + std::string(" + ") + std::to_string(fw * 2 + 1) + + std::string(", part0 * 2 * 4);"); + } + } + fh = FH / 2; + for (int fw = 0; fw < FW / 2; ++fw) { + res += std::string("memcpy(dst_ptr + ") + + std::to_string(fh * FW * 8 + fw * 8) + + std::string(", src_base + IW * ") + std::to_string(fh * 2) + + std::string(" + ") + std::to_string(fw * 2) + + std::string(", part0 * 2 * 4);"); + } + res += StringTemplate::StringTemplateArgs() + .add("idx0", idx % 15) + .add("idx3", (idx + 3) % 15) + .add("idx6", (idx + 6) % 15) + .add("buffer_idx0", buffer_idx % 6) + .add("buffer_idx1", (buffer_idx + 1) % 6) + .add("src_offset", FW / 2 * 2) + .add("dst_offset", fh * FW * 8 + FW / 2 * 8) + .add("row0_idx", fh * 2) + .render(R"( + memcpy(buffer[${buffer_idx0}], src_base + IW * ${row0_idx} + ${src_offset}, part0 * 2 * 4); + src_base += (IH * IW); + memcpy(buffer[${buffer_idx1}], src_base, part0 * 2 * 4); + d[${idx0}] = vld2q_s32(buffer[${buffer_idx0}]); + d[${idx3}] = vld2q_s32(buffer[${buffer_idx1}]); + d[${idx6}] = vzipq_s32(d[${idx0}].val[0], d[${idx3}].val[0]); + vst1q_s32(dst_ptr + ${dst_offset}, d[${idx6}].val[0]); + vst1q_s32(dst_ptr + ${dst_offset} + 4, d[${idx6}].val[1]); + )"); + idx = (idx + 9) % 15; + buffer_idx = (buffer_idx + 2) % 6; + for (int fw = 0; fw < FW / 2; ++fw) { + res += std::string("memcpy(dst_ptr + ") + + std::to_string(fh * FW * 8 + (FW + 1) / 2 * 8 + fw * 8) + + std::string(", src_base + ") + std::to_string(fw * 2 + 1) + + std::string(", part0 * 2 * 4);"); + } + for (int fh = 0; fh < FH / 2; ++fh) { + for (int fw = 0; fw < FW / 2; ++fw) { + res += std::string("memcpy(dst_ptr + ") + + std::to_string((FH + 1) / 2 * FW * 8 + fh * FW * 8 + fw * 8) + + std::string(", src_base + IW * ") + std::to_string(fh * 2 + 1) + + std::string(" + ") + std::to_string(fw * 2) + + std::string(", part0 * 2 * 4);"); + } + res += StringTemplate::StringTemplateArgs() + .add("idx0", idx % 15) + .add("idx3", (idx + 3) % 15) + .add("idx6", (idx + 6) % 15) + .add("buffer_idx0", buffer_idx % 6) + .add("buffer_idx1", (buffer_idx + 1) % 6) + .add("src_offset", FW / 2 * 2) + .add("dst_offset", + (FH + 1) / 2 * FW * 8 + fh * FW * 8 + FW / 2 * 8) + .add("row0_idx", fh * 2 + 1) + .add("row1_idx", fh * 2 + 2) + .render(R"( + memcpy(buffer[${buffer_idx0}], src_base + IW * ${row0_idx} + ${src_offset}, part0 * 2 * 4); + memcpy(buffer[${buffer_idx1}], src_base + IW * ${row1_idx}, part0 * 2 * 4); + d[${idx0}] = vld2q_s32(buffer[${buffer_idx0}]); + d[${idx3}] = vld2q_s32(buffer[${buffer_idx1}]); + d[${idx6}] = vzipq_s32(d[${idx0}].val[0], d[${idx3}].val[0]); + vst1q_s32(dst_ptr + ${dst_offset}, d[${idx6}].val[0]); + vst1q_s32(dst_ptr + ${dst_offset} + 4, d[${idx6}].val[1]); + )"); + idx = (idx + 9) % 15; + buffer_idx = (buffer_idx + 2) % 6; + for (int fw = 0; fw < FW / 2; ++fw) { + res += std::string("memcpy(dst_ptr + ") + + std::to_string( + (FH + 1) / 2 * FW * 8 + fh * FW * 8 + (FW + 1) / 2 * 8 + + fw * 8) + + std::string(", src_base + IW * ") + std::to_string(fh * 2 + 2) + + std::string(" + ") + std::to_string(fw * 2 + 1) + + std::string(", part0 * 2 * 4);"); + } + } + res += StringTemplate::StringTemplateArgs() + .add("offset", FH * FW * 8) + .render(R"( + dst_ptr += ${offset}; + src_base += (IH * IW); + } + if (ic < IC) { + )"); + for (int fh = 0; fh < FH / 2; ++fh) { + for (int fw = 0; fw < FW / 2; ++fw) { + res += std::string("memcpy(dst_ptr + ") + + std::to_string(fh * FW * 8 + fw * 8) + + std::string(", src_base + IW * ") + std::to_string(fh * 2) + + std::string(" + ") + std::to_string(fw * 2) + + std::string(", part0 * 2 * 4);"); + } + res += StringTemplate::StringTemplateArgs() + .add("idx0", idx % 15) + .add("idx3", (idx + 3) % 15) + .add("idx6", (idx + 6) % 15) + .add("buffer_idx0", buffer_idx % 6) + .add("buffer_idx1", (buffer_idx + 1) % 6) + .add("src_offset", FW / 2 * 2) + .add("dst_offset", fh * FW * 8 + FW / 2 * 8) + .add("row0_idx", fh * 2) + .add("row1_idx", fh * 2 + 1) + .render(R"( + memcpy(buffer[${buffer_idx0}], src_base + IW * ${row0_idx} + ${src_offset}, part0 * 2 * 4); + memcpy(buffer[${buffer_idx1}], src_base + IW * ${row1_idx}, part0 * 2 * 4); + d[${idx0}] = vld2q_s32(buffer[${buffer_idx0}]); + d[${idx3}] = vld2q_s32(buffer[${buffer_idx1}]); + d[${idx6}] = vzipq_s32(d[${idx0}].val[0], d[${idx3}].val[0]); + vst1q_s32(dst_ptr + ${dst_offset}, d[${idx6}].val[0]); + vst1q_s32(dst_ptr + ${dst_offset} + 4, d[${idx6}].val[1]); + )"); + idx = (idx + 9) % 15; + buffer_idx = (buffer_idx + 2) % 6; + for (int fw = 0; fw < FW / 2; ++fw) { + res += std::string("memcpy(dst_ptr + ") + + std::to_string(fh * FW * 8 + (FW + 1) / 2 * 8 + fw * 8) + + std::string(", src_base + IW * ") + std::to_string(fh * 2 + 1) + + std::string(" + ") + std::to_string(fw * 2 + 1) + + std::string(", part0 * 2 * 4);"); + } + } + fh = FH / 2; + for (int fw = 0; fw < FW / 2; ++fw) { + res += std::string("memcpy(dst_ptr + ") + + std::to_string(fh * FW * 8 + fw * 8) + + std::string(", src_base + IW * ") + std::to_string(fh * 2) + + std::string(" + ") + std::to_string(fw * 2) + + std::string(", part0 * 2 * 4);"); + } + res += StringTemplate::StringTemplateArgs() + .add("idx0", idx % 15) + .add("idx3", (idx + 3) % 15) + .add("idx6", (idx + 6) % 15) + .add("buffer_idx0", buffer_idx % 6) + .add("src_offset", FW / 2 * 2) + .add("dst_offset", fh * FW * 8 + FW / 2 * 8) + .add("row0_idx", fh * 2) + .render(R"( + memcpy(buffer[${buffer_idx0}], src_base + IW * ${row0_idx} + ${src_offset}, part0 * 2 * 4); + src_base += (IH * IW); + d[${idx0}] = vld2q_s32(buffer[${buffer_idx0}]); + d[${idx3}].val[0] = vdupq_n_s32(0); + d[${idx6}] = vzipq_s32(d[${idx0}].val[0], d[${idx3}].val[0]); + vst1q_s32(dst_ptr + ${dst_offset}, d[${idx6}].val[0]); + vst1q_s32(dst_ptr + ${dst_offset} + 4, d[${idx6}].val[1]); + + dst_ptr += (${dst_offset} + 8); + } + } + } + } + )"); + idx = (idx + 9) % 15; + buffer_idx = (buffer_idx + 2) % 6; + } else { + res += R"( + for (int ic = 0; ic < IC; ++ic) {)"; + for (int fh = 0; fh < FH; ++fh) { + for (int fw = 0; fw < FW / 2; ++fw) { + res += std::string("memcpy(dst_ptr + ") + + std::to_string(fh * (FW / 2) * 24 + fw * 24) + + std::string(", src_base + IW * ") + std::to_string(fh) + + std::string(" + ") + std::to_string(fw * 2) + + std::string(", 96);"); + } + } + res += StringTemplate::StringTemplateArgs() + .add("offset", FH * FW * 12) + .render(R"( + dst_ptr += ${offset}; + src_base += (IH * IW); + } + )"); + res += R"( + } else { + const int32_t* src_base_next = (const int32_t*)(src + (oh + 1) * 2 * IW * 4); + const int part0 = OW - ow, part1 = 12 - part0; + for (int ic = 0; ic < IC; ++ic) { + )"; + for (int fh = 0; fh < FH; ++fh) { + for (int fw = 0; fw < FW / 2; ++fw) { + res += std::string("memcpy(dst_ptr + ") + + std::to_string(fh * (FW / 2) * 24 + fw * 24) + + std::string(", src_base + IW * ") + std::to_string(fh) + + std::string(" + ") + std::to_string(fw * 2) + + std::string(", part0 * 2 * 4);"); + res += std::string("memcpy(dst_ptr + part0 * 2 + ") + + std::to_string(fh * (FW / 2) * 24 + fw * 24) + + std::string(", src_base_next + IW * ") + std::to_string(fh) + + std::string(" + ") + std::to_string(fw * 2) + + std::string(", part1 * 2 * 4);"); + } + } + res += StringTemplate::StringTemplateArgs() + .add("offset", FH * FW * 12) + .render(R"( + dst_ptr += ${offset}; + src_base += (IH * IW); + src_base_next += (IH * IW); + } + } + } + for (; n + 7 < N; n += 8) { + const int oh = n / OW, ow = n % OW; + const int32_t *src_base = + (const int32_t *)(src + oh * 2 * IW * 4 + ow * 2 * 4); + TINYNN_ASSERT(OW - ow >= 8); + for (int ic = 0; ic < IC; ++ic) {)"); + for (int fh = 0; fh < FH; ++fh) { + for (int fw = 0; fw < FW / 2; ++fw) { + res += std::string("memcpy(dst_ptr + ") + + std::to_string(fh * (FW / 2) * 16 + fw * 16) + + std::string(", src_base + IW * ") + std::to_string(fh) + + std::string(" + ") + std::to_string(fw * 2) + + std::string(", 64);"); + } + } + res += StringTemplate::StringTemplateArgs() + .add("offset", FH * FW * 8) + .render(R"( + dst_ptr += ${offset}; + src_base += (IH * IW); + } + } + for (; n < N; n += 4) { + const int oh = n / OW, ow = n % OW; + const int32_t *src_base = + (const int32_t *)(src + oh * 2 * IW * 4 + ow * 2 * 4); + TINYNN_ASSERT(oh + 1 == OH); + if (OW - ow >= 4) { + for (int ic = 0; ic < IC; ++ic) { + )"); + for (int fh = 0; fh < FH; ++fh) { + for (int fw = 0; fw < FW / 2; ++fw) { + res += std::string("memcpy(dst_ptr + ") + + std::to_string(fh * (FW / 2) * 8 + fw * 8) + + std::string(", src_base + IW * ") + std::to_string(fh) + + std::string(" + ") + std::to_string(fw * 2) + + std::string(", 32);"); + } + } + res += StringTemplate::StringTemplateArgs() + .add("offset", FH * FW * 4) + .render(R"( + dst_ptr += ${offset}; + src_base += (IH * IW); + } + )"); + res += R"( + } else { + const int part0 = OW - ow; + for (int ic = 0; ic < IC; ++ic) { + )"; + for (int fh = 0; fh < FH; ++fh) { + for (int fw = 0; fw < FW / 2; ++fw) { + res += std::string("memcpy(dst_ptr + ") + + std::to_string(fh * (FW / 2) * 8 + fw * 8) + + std::string(", src_base + IW * ") + std::to_string(fh) + + std::string(" + ") + std::to_string(fw * 2) + + std::string(", part0 * 2 * 4);"); + } + } + res += StringTemplate::StringTemplateArgs() + .add("offset", FH * FW * 4) + .render(R"( + dst_ptr += ${offset}; + src_base += (IH * IW); + } + } + } +} + )"); + } + return res; +} } // namespace bool ConvBiasIm2colI8mmNCHW44::IsAvailable(TContext* ctx) const { @@ -201,6 +1489,8 @@ std::string ConvBiasIm2colI8mmNCHW44::GetWorkspaceBodyCondition( ss << GenCommonRet() << " " << GetWorkspaceSignature(ctx); std::string workspace_temp = R"({ TINYNN_ASSERT(workspace); + TINYNN_ASSERT(${stride_h} == ${stride_w}); + TINYNN_ASSERT(${kernel_h} == ${kernel_w}); ${group} const Layout src_layout = inputs[0]->layout; const size_t IC = src_layout.dims[1] / group * 4; @@ -217,7 +1507,7 @@ std::string ConvBiasIm2colI8mmNCHW44::GetWorkspaceBodyCondition( const size_t OW = (padded_IW - ${kernel_w}) / ${stride_w} + 1; const size_t K = IC * ${kernel_h} * ${kernel_w}, N = OH * OW; size_t im2col_size = 0; - if (${kernel_h} != 1 || ${kernel_w} != 1 || ${stride_h} != 1 || ${stride_w} != 1){ + if ((${kernel_h} != 1 && ${stride_h} == 1) || (${stride_h} == 2 && OW < 12)){ im2col_size = K * N * sizeof(int8_t); } @@ -290,6 +1580,7 @@ std::string ConvBiasIm2colI8mmNCHW44::GetKernelBody(TContext* ctx) const { ctx->getAttrInt("stride_w") == 2); writer << "#include \n"; writer << im2col_s2(); + writer << fuse_im2col_packB_s2(ctx); } } writer << GenCommonRet() << " " << GetKernelSignature(ctx); @@ -297,21 +1588,7 @@ std::string ConvBiasIm2colI8mmNCHW44::GetKernelBody(TContext* ctx) const { auto last_dtype = Utils::get_last_operand(ctx).dtype; auto last_dtype_str = SymbolHelper::gen_valid_dtype(last_dtype); std::string dst_specifier = Utils::cvt_dtype_specifier(last_dtype_str); - writer << StringTemplate::StringTemplateArgs(ctx) - .add("bias_ptr_str", bias_ptr_str) - .add("packb_size_sym", - m_inner_gemm.GetPackBWorkspaceSymbol(inner_ctx.get())) - .add("packb_sym", m_inner_gemm.GetPackBSymbol(inner_ctx.get())) - .add("naked_kern_sym", - m_inner_gemm.GetNakedKernelSymbol(inner_ctx.get())) - .add("dst_specifier", dst_specifier) - .add_ctx_int("pad_h") - .add_ctx_int("pad_w") - .add_ctx_int("kernel_h") - .add_ctx_int("kernel_w") - .add_ctx_int("stride_h") - .add_ctx_int("stride_w") - .render(R"({ + std::string temp_body = R"({ int8_t* input_data = inputs[0]->ptr; ${dst_specifier}* output_data = outputs[0]->ptr; @@ -344,25 +1621,21 @@ std::string ConvBiasIm2colI8mmNCHW44::GetKernelBody(TContext* ctx) const { if ((${pad_h} != 0) || (${pad_w} != 0)) { pad_size = in_c * padded_ih * padded_iw * sizeof(int8_t); } - if (${kernel_h} != 1 || ${kernel_w} != 1 || ${stride_h} != 1 || ${stride_w} != 1){ + if ((${kernel_h} != 1 && ${stride_h} == 1) || (${stride_h} == 2 && out_w < 12)){ im2col_size = K * N * sizeof(int8_t); } void *pad_ws = workspace->ptr; void *im2col_ws = pad_ws + pad_size; void *packb_ws = im2col_ws + im2col_size; + const int pad_h = ${pad_h}, pad_w = ${pad_w}, kernel_h = ${kernel_h}, kernel_w = ${kernel_w};)"; + if (ctx->getAttrInt("stride_h") == 1) { + temp_body += R"( for (int n_idx = 0; n_idx < in_n; ++n_idx) { int32_t* bias_data = ${bias_ptr_str}; int8_t* weight_data = inputs[1]->ptr; - for (int g = 0; g < group; ++g) {)" + - (need_pad ? std::string(R"( - pad_src(input_data, pad_ws, in_c / 4, in_h, in_w, ${pad_h}, ${pad_w});)") - : std::string(R"( - pad_ws = input_data;)")) + - (need_im2col ? std::string(R"( - im2col(pad_ws, im2col_ws, in_c / 4, out_h, out_w, ${kernel_h}, ${kernel_w}, padded_ih, padded_iw);)") - : std::string(R"( - im2col_ws = pad_ws;)")) + - std::string(R"( + for (int g = 0; g < group; ++g) { + ${exec_pad} + ${exec_im2col} ${packb_sym}(packb_ws, im2col_ws, LDB, 0, N, 0, K); ${naked_kern_sym}(weight_data, packb_ws, output_data, LDC, M, N, K, bias_data, NULL, scale, temp_scale, dst_scale_inv); weight_data += weight_layout.stride[0]; @@ -372,8 +1645,72 @@ std::string ConvBiasIm2colI8mmNCHW44::GetKernelBody(TContext* ctx) const { } } return TinyNN_SUCCESS; -})")); +})"; + } else { + CC_ASSERT(ctx->getAttrInt("stride_h") == 2); + temp_body += R"( + //! Because of the implementation of function `fuse_im2col_packB_s2` assumes that the 12 elements span at most two lines. + if (out_w < 12) { + for (int n_idx = 0; n_idx < in_n; ++n_idx) { + int32_t* bias_data = ${bias_ptr_str}; + int8_t* weight_data = inputs[1]->ptr; + for (int g = 0; g < group; ++g) { + ${exec_pad} + ${exec_im2col} + ${packb_sym}(packb_ws, im2col_ws, LDB, 0, N, 0, K); + ${naked_kern_sym}(weight_data, packb_ws, output_data, LDC, M, N, K, bias_data, NULL, scale, temp_scale, dst_scale_inv); + weight_data += weight_layout.stride[0]; + bias_data += out_c; + input_data += in_c * in_h * in_w; + output_data += out_c * out_h * out_w; + } + } + } else { + for (int n_idx = 0; n_idx < in_n; ++n_idx) { + int32_t* bias_data = ${bias_ptr_str}; + int8_t* weight_data = inputs[1]->ptr; + for (int g = 0; g < group; ++g) { + ${exec_pad} + fuse_im2col_packB_s2(pad_ws, packb_ws, in_c / 4, padded_ih, padded_iw, out_h, out_w); + ${naked_kern_sym}(weight_data, packb_ws, output_data, LDC, M, N, K, bias_data, NULL, scale, temp_scale, dst_scale_inv); + weight_data += weight_layout.stride[0]; + bias_data += out_c; + input_data += in_c * in_h * in_w; + output_data += out_c * out_h * out_w; + } + } + } + return TinyNN_SUCCESS; +})"; + } + std::string exec_pad = + (need_pad ? std::string(R"( + pad_src(input_data, pad_ws, in_c / 4, in_h, in_w, pad_h, pad_w);)") + : std::string(R"( + pad_ws = input_data;)")); + std::string exec_im2col = + (need_im2col ? std::string(R"( + im2col(pad_ws, im2col_ws, in_c / 4, out_h, out_w, kernel_h, kernel_w, + padded_ih, padded_iw);)") + : std::string(R"( im2col_ws = pad_ws;)")); + writer << StringTemplate::StringTemplateArgs(ctx) + .add("bias_ptr_str", bias_ptr_str) + .add("packb_size_sym", + m_inner_gemm.GetPackBWorkspaceSymbol(inner_ctx.get())) + .add("packb_sym", m_inner_gemm.GetPackBSymbol(inner_ctx.get())) + .add("naked_kern_sym", + m_inner_gemm.GetNakedKernelSymbol(inner_ctx.get())) + .add("dst_specifier", dst_specifier) + .add_ctx_int("pad_h") + .add_ctx_int("pad_w") + .add_ctx_int("kernel_h") + .add_ctx_int("kernel_w") + .add_ctx_int("stride_h") + .add_ctx_int("stride_w") + .add("exec_pad", exec_pad) + .add("exec_im2col", exec_im2col) + .render(temp_body); return writer.str(); } diff --git a/compiler/test/kernel/opr/arm/conv.cpp b/compiler/test/kernel/opr/arm/conv.cpp index 0615f61d..eed1bc27 100644 --- a/compiler/test/kernel/opr/arm/conv.cpp +++ b/compiler/test/kernel/opr/arm/conv.cpp @@ -145,7 +145,7 @@ TEST(AARCH64, ConvBias1x1NCHW44I8mm) { param.nonlineMode = noline; checker.set_param(param); for (size_t ic : {3, 4, 5}) { - for (size_t ohw = 7; ohw < 27; ++ohw) { + for (size_t ohw = 7 * stride; ohw < 35 * stride; ohw += stride) { checker.execs( {{2, ic, 1, ohw, 4}, {5, ic, 1, 1, 4, 4}, {}, {}, {}}); checker.execs( @@ -167,7 +167,7 @@ TEST(AARCH64, ConvBias1x1NCHW44I8mm) { param.nonlineMode = noline; checker.set_param(param); for (size_t ic : {3, 4, 5}) { - for (size_t ohw = 7; ohw < 27; ++ohw) { + for (size_t ohw = 7 * stride; ohw < 27 * stride; ohw += stride) { checker.execs( {{2, ic * group, 1, ohw, 4}, {group, 5, ic, 1, 1, 4, 4}, @@ -212,12 +212,12 @@ TEST(AARCH64, ConvBiasIm2colNCHW44I8mm) { ConvBiasForward::Param::NonlineMode::RELU, ConvBiasForward::Param::NonlineMode::H_SWISH}) { param.nonlineMode = noline; - for (size_t kernel : {2, 3, 5, 7}) { + for (size_t kernel : {2, 3, 4, 5, 7}) { param.pad_h = kernel / 2; param.pad_w = kernel / 2; checker.set_param(param); for (size_t ic : {3, 4, 5}) { - for (size_t ohw = 7; ohw < 27; ++ohw) { + for (size_t ohw = 7 * stride; ohw < 35 * stride; ohw += stride) { checker.execs( {{2, ic, 1, ohw, 4}, {5, ic, kernel, kernel, 4, 4}, @@ -242,12 +242,12 @@ TEST(AARCH64, ConvBiasIm2colNCHW44I8mm) { ConvBiasForward::Param::NonlineMode::RELU, ConvBiasForward::Param::NonlineMode::H_SWISH}) { param.nonlineMode = noline; - for (size_t kernel : {2, 3, 5, 7}) { + for (size_t kernel : {2, 3, 4, 5, 7}) { param.pad_h = kernel / 2; param.pad_w = kernel / 2; checker.set_param(param); for (size_t ic : {3, 4, 5}) { - for (size_t ohw = 7; ohw < 27; ++ohw) { + for (size_t ohw = 7 * stride; ohw < 27 * stride; ohw += stride) { checker.execs( {{2, ic * group, 3, ohw, 4}, {group, 5, ic, kernel, kernel, 4, 4},