diff --git a/include/utils/utils.h b/include/utils/utils.h
index a8bd2919..0803c791 100644
--- a/include/utils/utils.h
+++ b/include/utils/utils.h
@@ -21,6 +21,12 @@
 
 #ifdef __AVX__
 #include "immintrin.h"
+#ifdef __GNUC__
+#if __GNUC__ < 8
+#define _mm256_set_m128i(/* __m128i */ hi, /* __m128i */ lo) \
+    _mm256_insertf128_si256(_mm256_castsi128_si256(lo), (hi), 0x1)
+#endif
+#endif
 #endif
 
 namespace fastllm {
diff --git a/src/devices/cpu/cpudevice.cpp b/src/devices/cpu/cpudevice.cpp
index 7e8d6f6a..e03e103d 100644
--- a/src/devices/cpu/cpudevice.cpp
+++ b/src/devices/cpu/cpudevice.cpp
@@ -79,7 +79,7 @@ namespace fastllm {
         return true;
     }
 
-#ifdef __AVX__
+#ifdef __AVX2__
     int DotU8U8(uint8_t *a, uint8_t *b, int n) {
         __m256i acc = _mm256_setzero_si256();
 
@@ -107,32 +107,31 @@ namespace fastllm {
         return ans + I32sum(acc);
     };
-#else
-    int DotU8U8(uint8_t *a, uint8_t *b, int n) {
-        __m256i acc = _mm256_setzero_si256();
-
-        int i = 0;
-        int ans = 0;
-        for (; i + 31 < n; i += 32) {
-            __m256i bx = _mm256_loadu_si256((const __m256i *) (a + i));
-            __m256i by = _mm256_loadu_si256((const __m256i *) (b + i));
-
-            __m256i mx0 = _mm256_cvtepu8_epi16(_mm256_extractf128_si256(bx, 0));
-            __m256i mx1 = _mm256_cvtepu8_epi16(_mm256_extractf128_si256(bx, 1));
-
-            __m256i my0 = _mm256_cvtepu8_epi16(_mm256_extractf128_si256(by, 0));
-            __m256i my1 = _mm256_cvtepu8_epi16(_mm256_extractf128_si256(by, 1));
-
-            acc = _mm256_add_epi32(acc, _mm256_madd_epi16(mx0, my0));
-            acc = _mm256_add_epi32(acc, _mm256_madd_epi16(mx1, my1));
-        }
-        for (; i < n; i++) {
-            ans += a[i] * b[i];
-        }
-
-        return ans + I32sum(acc);
-    };
-#endif
+//#else
+//    int DotU8U8(uint8_t *a, uint8_t *b, int n) {
+//        __m256i acc = _mm256_setzero_si256();
+
+//        int i = 0;
+//        int ans = 0;
+//        for (; i + 31 < n; i += 32) {
+//            __m256i bx = _mm256_loadu_si256((const __m256i *) (a + i));
+//            __m256i by = _mm256_loadu_si256((const __m256i *) (b + i));
+
+//            __m256i mx0 = _mm256_cvtepu8_epi16(_mm256_extractf128_si256(bx, 0));
+//            __m256i mx1 = _mm256_cvtepu8_epi16(_mm256_extractf128_si256(bx, 1));
+
+//            __m256i my0 = _mm256_cvtepu8_epi16(_mm256_extractf128_si256(by, 0));
+//            __m256i my1 = _mm256_cvtepu8_epi16(_mm256_extractf128_si256(by, 1));
+
+//            acc = _mm256_add_epi32(acc, _mm256_madd_epi16(mx0, my0));
+//            //acc = _mm256_add_epi32(acc, _mm256_madd_epi16(mx1, my1));
+//        }
+//        for (; i < n; i++) {
+//            ans += a[i] * b[i];
+//        }
+
+//        return ans + I32sum(acc);
+//    };
 
     int DotU4U8(uint8_t *a, uint8_t *b, int n) {
         __m256i acc = _mm256_setzero_si256();
 
@@ -920,7 +919,7 @@ namespace fastllm {
                 c[block * kstride + i] = value;
             }
         }
-#elif defined(__AVX__)
+#elif defined(__AVX2__)
         int block = 0;
         for (; block < n; block++) {
             uint8_t *weightWalk = b;
@@ -994,7 +993,7 @@ namespace fastllm {
                     sum0 = vpadalq_u16(sum0, vmull_u8(vb, in.val[0]));
                 }
                 value += sum0[0] + sum0[1] + sum0[2] + sum0[3];
-#elif defined(__AVX__)
+#elif defined(__AVX2__)
                 value += DotU4U8(weightWalk + i * m / 2, inputWalk, m);
                 j += m;
 #endif
@@ -1065,7 +1064,7 @@ namespace fastllm {
                    sum0 = vpadalq_u16(sum0, vmull_u8(vb, in.val[0]));
                 }
                 value += sum0[0] + sum0[1] + sum0[2] + sum0[3];
-#elif defined(__AVX__)
+#elif defined(__AVX2__)
                 value += DotU4U8(weightWalk + i * m / 2, inputWalk, m);
                 j += m;
 #endif
diff --git a/src/fastllm.cpp b/src/fastllm.cpp
index 9a8421c3..f9c97f0b 100644
--- a/src/fastllm.cpp
+++ b/src/fastllm.cpp
@@ -562,7 +562,7 @@ namespace fastllm {
         weightSum.resize(n);
         for (int i = 0; i < n; i++) {
             int j = 0;
-#ifdef __AVX__
+#ifdef __AVX2__
             __m256i acc = _mm256_setzero_si256();
             const __m256i ones = _mm256_set1_epi16(1);
             for (; j + 31 < m; j += 32) {
@@ -608,7 +608,7 @@ namespace fastllm {
                 }
                 weightSum[i] += sum0[0] + sum0[1] + sum0[2] + sum0[3];
 #endif
-#ifdef __AVX__
+#ifdef __AVX2__
             __m256i acc = _mm256_setzero_si256();
             const __m256i lowMask = _mm256_set1_epi8(0xf);
             const __m256i ones = _mm256_set1_epi16(1);
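
Note on the `include/utils/utils.h` hunk: `_mm256_set_m128i` only ships with GCC 8 and later, so for older GCC the patch reconstructs it from two AVX intrinsics, casting `lo` into the low 128-bit lane and inserting `hi` into the high lane. Below is a minimal standalone sanity check of that equivalence; it is not part of the patch, and assumes an AVX-capable build (`g++ -mavx`):

```cpp
#include <immintrin.h>
#include <cstdio>

// Same fallback the patch defines for GCC < 8: widen lo to 256 bits,
// then insert hi into the upper 128-bit lane.
#if defined(__GNUC__) && __GNUC__ < 8
#define _mm256_set_m128i(hi, lo) \
    _mm256_insertf128_si256(_mm256_castsi128_si256(lo), (hi), 0x1)
#endif

int main() {
    __m128i lo = _mm_set1_epi32(0x11111111);
    __m128i hi = _mm_set1_epi32(0x22222222);
    __m256i v  = _mm256_set_m128i(hi, lo);

    int out[8];
    _mm256_storeu_si256((__m256i *) out, v);
    // Expect the four lo lanes first (low half), then the four hi lanes.
    for (int i = 0; i < 8; i++)
        printf("out[%d] = 0x%08x\n", i, out[i]);
    return 0;
}
```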
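Note on the `__AVX__` → `__AVX2__` guard changes (my reading, not stated in the patch): the fallback `DotU8U8` that the patch comments out already relied on `_mm256_cvtepu8_epi16`, `_mm256_madd_epi16`, and `_mm256_add_epi32`, which are AVX2 instructions, so guarding it with plain `__AVX__` could not build on AVX-only targets; gating the whole path behind `__AVX2__` resolves that. A self-contained sketch of the same widen-and-madd dot product, checked against a scalar loop — `HorizontalSumI32` is a hypothetical stand-in for fastllm's `I32sum` (build with `g++ -mavx2`):

```cpp
#include <immintrin.h>
#include <cstdint>
#include <cstdio>

// Horizontal sum of eight 32-bit lanes; stand-in for fastllm's I32sum.
static int HorizontalSumI32(__m256i v) {
    __m128i lo = _mm256_castsi256_si128(v);
    __m128i hi = _mm256_extractf128_si256(v, 1);
    __m128i s  = _mm_add_epi32(lo, hi);
    s = _mm_add_epi32(s, _mm_shuffle_epi32(s, 0x4E)); // swap 64-bit halves
    s = _mm_add_epi32(s, _mm_shuffle_epi32(s, 0xB1)); // swap 32-bit pairs
    return _mm_cvtsi128_si32(s);
}

// Same scheme as the commented-out fallback: zero-extend u8 -> s16,
// multiply-add adjacent pairs into 32-bit lanes, then reduce.
static int DotU8U8(const uint8_t *a, const uint8_t *b, int n) {
    __m256i acc = _mm256_setzero_si256();
    int i = 0, ans = 0;
    for (; i + 31 < n; i += 32) {
        __m256i bx = _mm256_loadu_si256((const __m256i *) (a + i));
        __m256i by = _mm256_loadu_si256((const __m256i *) (b + i));
        __m256i mx0 = _mm256_cvtepu8_epi16(_mm256_extractf128_si256(bx, 0));
        __m256i mx1 = _mm256_cvtepu8_epi16(_mm256_extractf128_si256(bx, 1));
        __m256i my0 = _mm256_cvtepu8_epi16(_mm256_extractf128_si256(by, 0));
        __m256i my1 = _mm256_cvtepu8_epi16(_mm256_extractf128_si256(by, 1));
        acc = _mm256_add_epi32(acc, _mm256_madd_epi16(mx0, my0));
        acc = _mm256_add_epi32(acc, _mm256_madd_epi16(mx1, my1));
    }
    for (; i < n; i++)  // scalar tail for the last n % 32 elements
        ans += a[i] * b[i];
    return ans + HorizontalSumI32(acc);
}

int main() {
    uint8_t a[100], b[100];
    int ref = 0;
    for (int i = 0; i < 100; i++) {
        a[i] = (uint8_t) (i * 37);
        b[i] = (uint8_t) (255 - i * 11);
        ref += a[i] * b[i];
    }
    int simd = DotU8U8(a, b, 100);
    printf("simd=%d ref=%d %s\n", simd, ref, simd == ref ? "OK" : "MISMATCH");
    return simd == ref ? 0 : 1;
}
```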