// -----------------------------------------------------------------------------
// Reviewer overview (comment only; the patch text below is unchanged).
//
// This is a `git format-patch` mail for src/devices/cpu/cpudevice.cpp
// (diffstat: 69 insertions, 4 deletions). The Subject header is RFC 2047
// Q-encoded UTF-8 and decodes to "int4 optimization on ARM".
// NOTE(review): in this copy the whole patch is collapsed onto two physical
// lines, so the original line breaks and the leading whitespace of every diff
// line are lost. Re-export the patch from the repository before applying it.
//
// Hunk 1 (@@ -1172,7 +1172,11 @@)
//   Fills the previously empty scalar tail loop `for (; j < m; j++)`: it reads
//   one weight byte through weightWalk, and for each of the curBlock rows adds
//   curWeight * inputWalk[x * m] into values[x], then advances inputWalk.
//
// Hunk 2 (@@ -1498,11 +1502,65 @@)
//   Adds, only when __ARM_FEATURE_DOTPROD is defined, the static helper
//   RunSomeBlock(). For every i in [0, k) it:
//     * zeroes sum[0..curBlock),
//     * walks the inputs 16 at a time: vld2_u8 de-interleaves 16 input bytes
//       of each row into val[0]/val[1]; vld1_u8 loads 8 packed weight bytes
//       at (i * m + j) / 2; the low nibbles are multiplied with val[1] and
//       the high nibbles with val[0] through vdot_u32. This is the same
//       pairing as the scalar tail (high nibble with index j, low nibble
//       with index j + 1),
//     * folds the two 32-bit lanes of each sum[x] into values[x],
//     * handles the remaining inputs two at a time in scalar code,
//     * stores values[x] to c[(block + x) * kstride + i].
//   Run() then uses the local RUNBLOCK(x) macro to process rows in groups of
//   16 first, then 8, 7, ... 1, sharing the sum[16] / vi[16] scratch arrays.
//
// Hunk 3 (@@ -1549,9 +1607,16 @@)
//   The non-dot-product (#else) path now only stores the raw integer
//   accumulator into c[block * kstride + i]. After #endif a pass shared by
//   both paths re-reads that integer, subtracts weightSums[i] * zeroPoint,
//   and writes the dequantized float (scales, weightMins, inputSums[block]
//   and the optional bias) back into the same slot through (float*)c.
//
// NOTE(review): `std::vector values = std::vector (curBlock, 0);` has no
//   template arguments. The From: header has also lost its address, so
//   angle-bracketed text was probably stripped from this copy; presumably the
//   real line is std::vector<int> - confirm against the repository.
// NOTE(review): that vector is constructed once per i iteration inside
//   RunSomeBlock; curBlock never exceeds 16 in the visible callers, so a
//   fixed-size local array looks sufficient - confirm before changing.
// NOTE(review): the zeroing loop `for (int j = 0; j < curBlock; j++)`
//   shadows the outer `int j = 0;` that the following loops continue from.
// NOTE(review): the scalar tail `for (; j + 1 < m; j += 2)` never touches
//   the last input when m is odd - confirm that m is always even here.
// NOTE(review): c is written as int32_t and then rewritten in place through
//   a float* cast; this relies on both types sharing the same storage slot.
// -----------------------------------------------------------------------------
From 70d973915e1a3cdbc742c61fbf879f79f2f80507 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=BB=84=E5=AE=87=E6=89=AC?= Date: Tue, 28 May 2024 18:59:01 +0800 Subject: [PATCH] =?UTF-8?q?ARM=E4=B8=8A=E7=9A=84int4=E4=BC=98=E5=8C=96?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/devices/cpu/cpudevice.cpp | 73 +++++++++++++++++++++++++++++++++-- 1 file changed, 69 insertions(+), 4 deletions(-) diff --git a/src/devices/cpu/cpudevice.cpp b/src/devices/cpu/cpudevice.cpp index df96da52..20220a01 100644 --- a/src/devices/cpu/cpudevice.cpp +++ b/src/devices/cpu/cpudevice.cpp @@ -1172,7 +1172,11 @@ namespace fastllm { } for (; j < m; j++) { - + int curWeight = (int)(*(weightWalk++)); + for (int x = 0; x < curBlock; x++) { + values[x] += curWeight * (*(inputWalk + x * m)); + } + inputWalk++; } for (int x = 0; x < curBlock; x++) { @@ -1498,11 +1502,65 @@ namespace fastllm { int *inputSums) : a(a), b(b), c(c), n(n), m(m), k(k), kstride(kstride), weightSums(weightSums), weightMins(weightMins), scales(scales), bias(bias), config(config), inputSums(inputSums) {} - + +#ifdef __ARM_FEATURE_DOTPROD + inline static void RunSomeBlock(uint8_t *weightWalk, uint8_t *inputStart, int32_t *c, + int curBlock, uint32x2_t *sum, uint8x8x2_t *vi, + int block, int k, int m, int kstride) { + uint8x8_t maskHigh = vdup_n_u8(0xF0); + uint8x8_t maskLow = vdup_n_u8(0xF); + for (int i = 0; i < k; i++) { + std::vector values = std::vector (curBlock, 0); + uint8_t *inputWalk = inputStart; + int j = 0; + + for (int j = 0; j < curBlock; j++) { + sum[j][0] = sum[j][1] = 0; + } + for (; j + 15 < m; j += 16) { + for (int x = 0; x < curBlock; x++) { + vi[x] = vld2_u8(inputWalk + j + m * x); + } + uint8x8_t ori = vld1_u8(weightWalk + (i * m + j) / 2); + uint8x8_t va = vand_u8(ori, maskLow); + uint8x8_t vb = vshr_n_u8(vand_u8(ori, maskHigh), 4); + for (int x = 0; x < curBlock; x++) { + sum[x] = vdot_u32(sum[x], va, vi[x].val[1]); + sum[x] = vdot_u32(sum[x], vb, 
vi[x].val[0]); + } + } + for (int x = 0; x < curBlock; x++) { + values[x] += sum[x][0] + sum[x][1]; + } + + for (; j + 1 < m; j += 2) { + int id = (i * m + j) / 2; + for (int x = 0; x < curBlock; x++) { + values[x] += (weightWalk[id] >> 4) * inputWalk[j + x * m]; + values[x] += (weightWalk[id] & 0xF) * inputWalk[j + 1 + x * m]; + } + } + + for (int x = 0; x < curBlock; x++) { + c[(block + x) * kstride + i] = values[x]; + } + } + } +#endif void Run() { +#ifdef __ARM_FEATURE_DOTPROD +#define RUNBLOCK(x) for (; block + (x - 1) < n; block += (x)) RunSomeBlock(b, a + block * m, c, (x), sum, vi, block, k, m, kstride); int block = 0; + uint32x2_t sum[16]; + uint8x8x2_t vi[16]; + RUNBLOCK(16); + RUNBLOCK(8);RUNBLOCK(7);RUNBLOCK(6);RUNBLOCK(5); + RUNBLOCK(4);RUNBLOCK(3);RUNBLOCK(2);RUNBLOCK(1); +#undef RUNBLOCK +#else + int block = 0; + for (; block < n; block++) { - uint32_t inputSum = inputSums[block]; uint8_t *weightWalk = b; uint8_t *inputStart = a + block * m; @@ -1549,9 +1607,16 @@ namespace fastllm { value += (weightWalk[id] & 0xF) * inputWalk[j + 1]; } + c[block * kstride + i] = value; + } + } +#endif + for (int block = 0; block < n; block++) { + for (int i = 0; i < k; i++) { + int value = c[block * kstride + i]; value -= weightSums[i] * config[block].zeroPoint; ((float*)c)[block * kstride + i] = scales[i] * config[block].scale * value + - weightMins[i] * ((float)inputSum - (int)config[block].zeroPoint * m) * config[block].scale + + weightMins[i] * ((float)inputSums[block] - (int)config[block].zeroPoint * m) * config[block].scale + (bias == nullptr ? 0.0 : bias[i]); } }