Skip to content

Commit

Permalink
临时修复仅有AVX而无AVX2时的编译错误(#64 #194)
Browse files Browse the repository at this point in the history
需要进一步优化DotU8U8等计算逻辑
  • Loading branch information
cgli committed Sep 24, 2023
1 parent 2325d87 commit e3f1db0
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 32 deletions.
59 changes: 29 additions & 30 deletions src/devices/cpu/cpudevice.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ namespace fastllm {
return true;
}

#ifdef __AVX__

#ifdef __AVX2__
int DotU8U8(uint8_t *a, uint8_t *b, int n) {
__m256i acc = _mm256_setzero_si256();
Expand Down Expand Up @@ -107,32 +107,31 @@ namespace fastllm {

return ans + I32sum(acc);
};
#else
/**
 * Dot product of two unsigned 8-bit vectors for builds that have AVX but
 * not AVX2.
 *
 * The previous body of this branch used 256-bit integer intrinsics
 * (_mm256_cvtepu8_epi16, _mm256_madd_epi16, _mm256_add_epi32). Those are
 * AVX2 instructions, so the code could not compile in exactly the
 * configuration this branch is meant for. This version restricts itself to
 * 128-bit SSE2 integer intrinsics, which are available on every AVX target.
 *
 * @param a  first operand, at least n bytes
 * @param b  second operand, at least n bytes
 * @param n  number of elements; values <= 0 yield 0
 * @return   sum of a[i] * b[i] for i in [0, n), accumulated in 32-bit ints
 */
int DotU8U8(uint8_t *a, uint8_t *b, int n) {
    const __m128i zero = _mm_setzero_si128();
    __m128i acc = _mm_setzero_si128();

    int i = 0;
    for (; i + 15 < n; i += 16) {
        const __m128i va = _mm_loadu_si128((const __m128i *) (a + i));
        const __m128i vb = _mm_loadu_si128((const __m128i *) (b + i));

        // Zero-extend the 16 bytes of each operand into two vectors of eight
        // 16-bit lanes. Both operands are unpacked the same way, so lane k of
        // aLo/aHi always pairs with lane k of bLo/bHi.
        const __m128i aLo = _mm_unpacklo_epi8(va, zero);
        const __m128i aHi = _mm_unpackhi_epi8(va, zero);
        const __m128i bLo = _mm_unpacklo_epi8(vb, zero);
        const __m128i bHi = _mm_unpackhi_epi8(vb, zero);

        // madd multiplies 16-bit lanes and adds adjacent pairs into 32-bit
        // lanes; inputs are <= 255, so each pair sum is <= 130050 and the
        // signed 16-bit interpretation is harmless.
        acc = _mm_add_epi32(acc, _mm_madd_epi16(aLo, bLo));
        acc = _mm_add_epi32(acc, _mm_madd_epi16(aHi, bHi));
    }

    // Horizontal sum of the four 32-bit partial sums.
    int32_t lanes[4];
    _mm_storeu_si128((__m128i *) lanes, acc);
    int ans = lanes[0] + lanes[1] + lanes[2] + lanes[3];

    // Scalar tail for the remaining (n % 16) elements.
    for (; i < n; i++) {
        ans += a[i] * b[i];
    }

    return ans;
}
#endif
//#else
// int DotU8U8(uint8_t *a, uint8_t *b, int n) {
// __m256i acc = _mm256_setzero_si256();

// int i = 0;
// int ans = 0;
// for (; i + 31 < n; i += 32) {
// __m256i bx = _mm256_loadu_si256((const __m256i *) (a + i));
// __m256i by = _mm256_loadu_si256((const __m256i *) (b + i));

// __m256i mx0 = _mm256_cvtepu8_epi16(_mm256_extractf128_si256(bx, 0));
// __m256i mx1 = _mm256_cvtepu8_epi16(_mm256_extractf128_si256(bx, 1));

// __m256i my0 = _mm256_cvtepu8_epi16(_mm256_extractf128_si256(by, 0));
// __m256i my1 = _mm256_cvtepu8_epi16(_mm256_extractf128_si256(by, 1));

// acc = _mm256_add_epi32(acc, _mm256_madd_epi16(mx0, my0));
// //acc = _mm256_add_epi32(acc, _mm256_madd_epi16(mx1, my1));
// }
// for (; i < n; i++) {
// ans += a[i] * b[i];
// }

// return ans + I32sum(acc);
// };
int DotU4U8(uint8_t *a, uint8_t *b, int n) {
__m256i acc = _mm256_setzero_si256();

Expand Down Expand Up @@ -920,7 +919,7 @@ namespace fastllm {
c[block * kstride + i] = value;
}
}
#elif defined(__AVX__)
#elif defined(__AVX2__)
int block = 0;
for (; block < n; block++) {
uint8_t *weightWalk = b;
Expand Down Expand Up @@ -994,7 +993,7 @@ namespace fastllm {
sum0 = vpadalq_u16(sum0, vmull_u8(vb, in.val[0]));
}
value += sum0[0] + sum0[1] + sum0[2] + sum0[3];
#elif defined(__AVX__)
#elif defined(__AVX2__)
value += DotU4U8(weightWalk + i * m / 2, inputWalk, m);
j += m;
#endif
Expand Down Expand Up @@ -1065,7 +1064,7 @@ namespace fastllm {
sum0 = vpadalq_u16(sum0, vmull_u8(vb, in.val[0]));
}
value += sum0[0] + sum0[1] + sum0[2] + sum0[3];
#elif defined(__AVX__)
#elif defined(__AVX2__)
value += DotU4U8(weightWalk + i * m / 2, inputWalk, m);
j += m;
#endif
Expand Down
4 changes: 2 additions & 2 deletions src/fastllm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -561,7 +561,7 @@ namespace fastllm {
weightSum.resize(n);
for (int i = 0; i < n; i++) {
int j = 0;
#ifdef __AVX__
#ifdef __AVX2__
__m256i acc = _mm256_setzero_si256();
const __m256i ones = _mm256_set1_epi16(1);
for (; j + 31 < m; j += 32) {
Expand Down Expand Up @@ -607,7 +607,7 @@ namespace fastllm {
}
weightSum[i] += sum0[0] + sum0[1] + sum0[2] + sum0[3];
#endif
#ifdef __AVX__
#ifdef __AVX2__
__m256i acc = _mm256_setzero_si256();
const __m256i lowMask = _mm256_set1_epi8(0xf);
const __m256i ones = _mm256_set1_epi16(1);
Expand Down

0 comments on commit e3f1db0

Please sign in to comment.