Skip to content

Commit

Permalink
ARM上的int4优化
Browse files: browse the repository at this point in the history
  • Loading branch information
黄宇扬 committed May 28, 2024
1 parent 37853af commit 70d9739
Showing 1 changed file with 69 additions and 4 deletions.
73 changes: 69 additions & 4 deletions src/devices/cpu/cpudevice.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1172,7 +1172,11 @@ namespace fastllm {
}

for (; j < m; j++) {

int curWeight = (int)(*(weightWalk++));
for (int x = 0; x < curBlock; x++) {
values[x] += curWeight * (*(inputWalk + x * m));
}
inputWalk++;
}

for (int x = 0; x < curBlock; x++) {
Expand Down Expand Up @@ -1498,11 +1502,65 @@ namespace fastllm {
int *inputSums) :
a(a), b(b), c(c), n(n), m(m), k(k), kstride(kstride),
weightSums(weightSums), weightMins(weightMins), scales(scales), bias(bias), config(config), inputSums(inputSums) {}


#ifdef __ARM_FEATURE_DOTPROD
inline static void RunSomeBlock(uint8_t *weightWalk, uint8_t *inputStart, int32_t *c,
int curBlock, uint32x2_t *sum, uint8x8x2_t *vi,
int block, int k, int m, int kstride) {
uint8x8_t maskHigh = vdup_n_u8(0xF0);
uint8x8_t maskLow = vdup_n_u8(0xF);
for (int i = 0; i < k; i++) {
std::vector <int> values = std::vector <int> (curBlock, 0);
uint8_t *inputWalk = inputStart;
int j = 0;

for (int j = 0; j < curBlock; j++) {
sum[j][0] = sum[j][1] = 0;
}
for (; j + 15 < m; j += 16) {
for (int x = 0; x < curBlock; x++) {
vi[x] = vld2_u8(inputWalk + j + m * x);
}
uint8x8_t ori = vld1_u8(weightWalk + (i * m + j) / 2);
uint8x8_t va = vand_u8(ori, maskLow);
uint8x8_t vb = vshr_n_u8(vand_u8(ori, maskHigh), 4);
for (int x = 0; x < curBlock; x++) {
sum[x] = vdot_u32(sum[x], va, vi[x].val[1]);
sum[x] = vdot_u32(sum[x], vb, vi[x].val[0]);
}
}
for (int x = 0; x < curBlock; x++) {
values[x] += sum[x][0] + sum[x][1];
}

for (; j + 1 < m; j += 2) {
int id = (i * m + j) / 2;
for (int x = 0; x < curBlock; x++) {
values[x] += (weightWalk[id] >> 4) * inputWalk[j + x * m];
values[x] += (weightWalk[id] & 0xF) * inputWalk[j + 1 + x * m];
}
}

for (int x = 0; x < curBlock; x++) {
c[(block + x) * kstride + i] = values[x];
}
}
}
#endif
void Run() {
#ifdef __ARM_FEATURE_DOTPROD
#define RUNBLOCK(x) for (; block + (x - 1) < n; block += (x)) RunSomeBlock(b, a + block * m, c, (x), sum, vi, block, k, m, kstride);
int block = 0;
uint32x2_t sum[16];
uint8x8x2_t vi[16];
RUNBLOCK(16);
RUNBLOCK(8);RUNBLOCK(7);RUNBLOCK(6);RUNBLOCK(5);
RUNBLOCK(4);RUNBLOCK(3);RUNBLOCK(2);RUNBLOCK(1);
#undef RUNBLOCK
#else
int block = 0;

for (; block < n; block++) {
uint32_t inputSum = inputSums[block];
uint8_t *weightWalk = b;
uint8_t *inputStart = a + block * m;

Expand Down Expand Up @@ -1549,9 +1607,16 @@ namespace fastllm {
value += (weightWalk[id] & 0xF) * inputWalk[j + 1];
}

c[block * kstride + i] = value;
}
}
#endif
for (int block = 0; block < n; block++) {
for (int i = 0; i < k; i++) {
int value = c[block * kstride + i];
value -= weightSums[i] * config[block].zeroPoint;
((float*)c)[block * kstride + i] = scales[i] * config[block].scale * value +
weightMins[i] * ((float)inputSum - (int)config[block].zeroPoint * m) * config[block].scale +
weightMins[i] * ((float)inputSums[block] - (int)config[block].zeroPoint * m) * config[block].scale +
(bias == nullptr ? 0.0 : bias[i]);
}
}
Expand Down

0 comments on commit 70d9739

Please sign in to comment.