From e3f1db029ba5284b6d5801a377fbf6c856ffd14c Mon Sep 17 00:00:00 2001 From: cgli Date: Sat, 26 Aug 2023 16:24:48 +0800 Subject: [PATCH] =?UTF-8?q?=E4=B8=B4=E6=97=B6=E4=BF=AE=E5=A4=8D=E4=BB=85?= =?UTF-8?q?=E6=9C=89AVX=E8=80=8C=E6=97=A0AVX2=E6=97=B6=E7=9A=84=E7=BC=96?= =?UTF-8?q?=E8=AF=91=E9=94=99=E8=AF=AF(#64=20#194)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 需要进一步优化DotU8U8等计算逻辑 --- src/devices/cpu/cpudevice.cpp | 59 +++++++++++++++++------------------ src/fastllm.cpp | 4 +-- 2 files changed, 31 insertions(+), 32 deletions(-) diff --git a/src/devices/cpu/cpudevice.cpp b/src/devices/cpu/cpudevice.cpp index 7e8d6f6a..e03e103d 100644 --- a/src/devices/cpu/cpudevice.cpp +++ b/src/devices/cpu/cpudevice.cpp @@ -79,7 +79,7 @@ namespace fastllm { return true; } -#ifdef __AVX__ + #ifdef __AVX2__ int DotU8U8(uint8_t *a, uint8_t *b, int n) { __m256i acc = _mm256_setzero_si256(); @@ -107,32 +107,31 @@ namespace fastllm { return ans + I32sum(acc); }; -#else - int DotU8U8(uint8_t *a, uint8_t *b, int n) { - __m256i acc = _mm256_setzero_si256(); - - int i = 0; - int ans = 0; - for (; i + 31 < n; i += 32) { - __m256i bx = _mm256_loadu_si256((const __m256i *) (a + i)); - __m256i by = _mm256_loadu_si256((const __m256i *) (b + i)); - - __m256i mx0 = _mm256_cvtepu8_epi16(_mm256_extractf128_si256(bx, 0)); - __m256i mx1 = _mm256_cvtepu8_epi16(_mm256_extractf128_si256(bx, 1)); - - __m256i my0 = _mm256_cvtepu8_epi16(_mm256_extractf128_si256(by, 0)); - __m256i my1 = _mm256_cvtepu8_epi16(_mm256_extractf128_si256(by, 1)); - - acc = _mm256_add_epi32(acc, _mm256_madd_epi16(mx0, my0)); - acc = _mm256_add_epi32(acc, _mm256_madd_epi16(mx1, my1)); - } - for (; i < n; i++) { - ans += a[i] * b[i]; - } - - return ans + I32sum(acc); - }; -#endif +//#else +// int DotU8U8(uint8_t *a, uint8_t *b, int n) { +// __m256i acc = _mm256_setzero_si256(); + +// int i = 0; +// int ans = 0; +// for (; i + 31 < n; i += 32) { +// __m256i bx = _mm256_loadu_si256((const __m256i *) (a + i)); +// __m256i by = _mm256_loadu_si256((const __m256i *) (b + i)); + +// __m256i mx0 = _mm256_cvtepu8_epi16(_mm256_extractf128_si256(bx, 0)); +// __m256i mx1 = _mm256_cvtepu8_epi16(_mm256_extractf128_si256(bx, 1)); + +// __m256i my0 = _mm256_cvtepu8_epi16(_mm256_extractf128_si256(by, 0)); +// __m256i my1 = _mm256_cvtepu8_epi16(_mm256_extractf128_si256(by, 1)); + +// acc = _mm256_add_epi32(acc, _mm256_madd_epi16(mx0, my0)); +// //acc = _mm256_add_epi32(acc, _mm256_madd_epi16(mx1, my1)); +// } +// for (; i < n; i++) { +// ans += a[i] * b[i]; +// } + +// return ans + I32sum(acc); +// }; int DotU4U8(uint8_t *a, uint8_t *b, int n) { __m256i acc = _mm256_setzero_si256(); @@ -920,7 +919,7 @@ namespace fastllm { c[block * kstride + i] = value; } } -#elif defined(__AVX__) +#elif defined(__AVX2__) int block = 0; for (; block < n; block++) { uint8_t *weightWalk = b; @@ -994,7 +993,7 @@ namespace fastllm { sum0 = vpadalq_u16(sum0, vmull_u8(vb, in.val[0])); } value += sum0[0] + sum0[1] + sum0[2] + sum0[3]; -#elif defined(__AVX__) +#elif defined(__AVX2__) value += DotU4U8(weightWalk + i * m / 2, inputWalk, m); j += m; #endif @@ -1065,7 +1064,7 @@ namespace fastllm { sum0 = vpadalq_u16(sum0, vmull_u8(vb, in.val[0])); } value += sum0[0] + sum0[1] + sum0[2] + sum0[3]; -#elif defined(__AVX__) +#elif defined(__AVX2__) value += DotU4U8(weightWalk + i * m / 2, inputWalk, m); j += m; #endif diff --git a/src/fastllm.cpp b/src/fastllm.cpp index a5eade2b..a205e693 100644 --- a/src/fastllm.cpp +++ b/src/fastllm.cpp @@ -561,7 +561,7 @@ namespace fastllm { weightSum.resize(n); for (int i = 0; i < n; i++) { int j = 0; -#ifdef __AVX__ +#ifdef __AVX2__ __m256i acc = _mm256_setzero_si256(); const __m256i ones = _mm256_set1_epi16(1); for (; j + 31 < m; j += 32) { @@ -607,7 +607,7 @@ namespace fastllm { } weightSum[i] += sum0[0] + sum0[1] + sum0[2] + sum0[3]; #endif -#ifdef __AVX__ +#ifdef __AVX2__ __m256i acc = _mm256_setzero_si256(); const __m256i lowMask = _mm256_set1_epi8(0xf); const __m256i ones = _mm256_set1_epi16(1);