Skip to content

Commit

Permalink
Merge pull request #333 from TylunasLi/bug_fix_att_opt
Browse files Browse the repository at this point in the history
支持GCC 7.x编译,以及在没有AVX2的CPU上执行
  • Loading branch information
ztxz16 authored Sep 26, 2023
2 parents 909b4d9 + e3f1db0 commit 73c15c0
Show file tree
Hide file tree
Showing 3 changed files with 37 additions and 32 deletions.
6 changes: 6 additions & 0 deletions include/utils/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,12 @@

#ifdef __AVX__
#include "immintrin.h"
// Compatibility shim for compilers that report __GNUC__ < 8: provide
// _mm256_set_m128i(upper, lower) in terms of AVX cast + insert, placing
// `lower` in bits 0..127 and `upper` in bits 128..255 of the result.
// (Presumably the intrinsic is missing from those toolchains' headers.)
#if defined(__GNUC__) && (__GNUC__ < 8)
#define _mm256_set_m128i(upper, lower) \
    _mm256_insertf128_si256(_mm256_castsi128_si256(lower), (upper), 0x1)
#endif
#endif

namespace fastllm {
Expand Down
59 changes: 29 additions & 30 deletions src/devices/cpu/cpudevice.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ namespace fastllm {
return true;
}

#ifdef __AVX__

#ifdef __AVX2__
int DotU8U8(uint8_t *a, uint8_t *b, int n) {
__m256i acc = _mm256_setzero_si256();
Expand Down Expand Up @@ -107,32 +107,31 @@ namespace fastllm {

return ans + I32sum(acc);
};
#else
// Dot product of two unsigned 8-bit vectors for builds that have AVX but
// not AVX2 (this is the fallback branch of the __AVX2__ check).
//
// 256-bit integer arithmetic (_mm256_cvtepu8_epi16, _mm256_madd_epi16,
// _mm256_add_epi32) is AVX2-only, so all integer math here is performed on
// 128-bit registers with SSE2 instructions; only the final recombination
// into a __m256i uses AVX (cast + insertf128).
//
// a, b : input buffers; the first n bytes of each are read.
// n    : number of elements.
// Returns the sum over i in [0, n) of a[i] * b[i] as an int.
int DotU8U8(uint8_t *a, uint8_t *b, int n) {
    const __m128i zero = _mm_setzero_si128();
    // Two independent 32-bit accumulators; they become the low and high
    // 128-bit halves of the vector handed to I32sum.
    __m128i accLo = _mm_setzero_si128();
    __m128i accHi = _mm_setzero_si128();

    int i = 0;
    int ans = 0;
    // Main loop: 32 bytes per iteration, processed as two 16-byte halves.
    for (; i + 31 < n; i += 32) {
        for (int half = 0; half < 32; half += 16) {
            const __m128i va = _mm_loadu_si128((const __m128i *) (a + i + half));
            const __m128i vb = _mm_loadu_si128((const __m128i *) (b + i + half));

            // Interleaving with zero widens each unsigned byte to a 16-bit lane.
            const __m128i a0 = _mm_unpacklo_epi8(va, zero);
            const __m128i a1 = _mm_unpackhi_epi8(va, zero);
            const __m128i b0 = _mm_unpacklo_epi8(vb, zero);
            const __m128i b1 = _mm_unpackhi_epi8(vb, zero);

            // madd multiplies 16-bit lanes and adds adjacent products into
            // 32-bit lanes; each product is at most 255 * 255, so a pair
            // cannot overflow 32 bits.
            accLo = _mm_add_epi32(accLo, _mm_madd_epi16(a0, b0));
            accHi = _mm_add_epi32(accHi, _mm_madd_epi16(a1, b1));
        }
    }
    // Scalar tail for the remaining (n % 32) elements.
    for (; i < n; i++) {
        ans += a[i] * b[i];
    }

    // Pack both accumulators into one 256-bit value using AVX-only ops so the
    // existing I32sum helper can reduce it.
    const __m256i acc = _mm256_insertf128_si256(_mm256_castsi128_si256(accLo), accHi, 0x1);
    return ans + I32sum(acc);
}
#endif
//#else
// int DotU8U8(uint8_t *a, uint8_t *b, int n) {
// __m256i acc = _mm256_setzero_si256();

// int i = 0;
// int ans = 0;
// for (; i + 31 < n; i += 32) {
// __m256i bx = _mm256_loadu_si256((const __m256i *) (a + i));
// __m256i by = _mm256_loadu_si256((const __m256i *) (b + i));

// __m256i mx0 = _mm256_cvtepu8_epi16(_mm256_extractf128_si256(bx, 0));
// __m256i mx1 = _mm256_cvtepu8_epi16(_mm256_extractf128_si256(bx, 1));

// __m256i my0 = _mm256_cvtepu8_epi16(_mm256_extractf128_si256(by, 0));
// __m256i my1 = _mm256_cvtepu8_epi16(_mm256_extractf128_si256(by, 1));

// acc = _mm256_add_epi32(acc, _mm256_madd_epi16(mx0, my0));
// //acc = _mm256_add_epi32(acc, _mm256_madd_epi16(mx1, my1));
// }
// for (; i < n; i++) {
// ans += a[i] * b[i];
// }

// return ans + I32sum(acc);
// };
int DotU4U8(uint8_t *a, uint8_t *b, int n) {
__m256i acc = _mm256_setzero_si256();

Expand Down Expand Up @@ -920,7 +919,7 @@ namespace fastllm {
c[block * kstride + i] = value;
}
}
#elif defined(__AVX__)
#elif defined(__AVX2__)
int block = 0;
for (; block < n; block++) {
uint8_t *weightWalk = b;
Expand Down Expand Up @@ -994,7 +993,7 @@ namespace fastllm {
sum0 = vpadalq_u16(sum0, vmull_u8(vb, in.val[0]));
}
value += sum0[0] + sum0[1] + sum0[2] + sum0[3];
#elif defined(__AVX__)
#elif defined(__AVX2__)
value += DotU4U8(weightWalk + i * m / 2, inputWalk, m);
j += m;
#endif
Expand Down Expand Up @@ -1065,7 +1064,7 @@ namespace fastllm {
sum0 = vpadalq_u16(sum0, vmull_u8(vb, in.val[0]));
}
value += sum0[0] + sum0[1] + sum0[2] + sum0[3];
#elif defined(__AVX__)
#elif defined(__AVX2__)
value += DotU4U8(weightWalk + i * m / 2, inputWalk, m);
j += m;
#endif
Expand Down
4 changes: 2 additions & 2 deletions src/fastllm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -562,7 +562,7 @@ namespace fastllm {
weightSum.resize(n);
for (int i = 0; i < n; i++) {
int j = 0;
#ifdef __AVX__
#ifdef __AVX2__
__m256i acc = _mm256_setzero_si256();
const __m256i ones = _mm256_set1_epi16(1);
for (; j + 31 < m; j += 32) {
Expand Down Expand Up @@ -608,7 +608,7 @@ namespace fastllm {
}
weightSum[i] += sum0[0] + sum0[1] + sum0[2] + sum0[3];
#endif
#ifdef __AVX__
#ifdef __AVX2__
__m256i acc = _mm256_setzero_si256();
const __m256i lowMask = _mm256_set1_epi8(0xf);
const __m256i ones = _mm256_set1_epi16(1);
Expand Down

0 comments on commit 73c15c0

Please sign in to comment.