diff --git a/src/amalgam/gen/avx512skx.c b/src/amalgam/gen/avx512skx.c index 23c44656554..c473078acf7 100644 --- a/src/amalgam/gen/avx512skx.c +++ b/src/amalgam/gen/avx512skx.c @@ -3679,7 +3679,6 @@ void xnn_qs8_gemm_minmax_fp32_ukernel_1x16c8__avx512skx_prfm( const __m512 voutput_max_less_zero_point = _mm512_load_ps(params->fp32_avx512.output_max_less_zero_point); const __m256i voutput_zero_point = _mm256_load_si256((const __m256i*) params->fp32_avx512.output_zero_point); const __m128i voutput_min = _mm_load_si128((const __m128i*) params->fp32_avx512.output_min); - const __m128i vshuffle_control_mask = _mm_set_epi8(15, 14, 13, 12, 7, 6, 5, 4, 11, 10, 9, 8, 3, 2, 1, 0); do { __m512i vacc0x0123 = _mm512_maskz_expandloadu_epi32(vbias_mask, w); @@ -3730,22 +3729,20 @@ void xnn_qs8_gemm_minmax_fp32_ukernel_1x16c8__avx512skx_prfm( vacc0x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled0x0123456789ABCDEF); - __m256i vacc0x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); + __m256i vacc0x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); - vacc0x0123456789AB4567CDEF = _mm256_adds_epi16(vacc0x0123456789AB4567CDEF, voutput_zero_point); + vacc0x012389AB4567CDEF = _mm256_adds_epi16(vacc0x012389AB4567CDEF, voutput_zero_point); - const __m128i vout0x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc0x0123456789AB4567CDEF, 1)); + const __m128i vout0x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x012389AB4567CDEF, 1)); - __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi8(vout0x0123456789AB4567CDEF, vshuffle_control_mask); + __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi32(vout0x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); vout0x0123456789ABCDEF = _mm_max_epi8(vout0x0123456789ABCDEF, voutput_min); if (nc >= 16) { _mm_storeu_si128((__m128i*) c0, vout0x0123456789ABCDEF); - - a0 = (const int8_t*) ((uintptr_t) a0 - kc); - c0 = (int8_t*) ((uintptr_t) c0 + cn_stride); + a0 = (const int8_t*) ((uintptr_t) a0 - kc); nc -= 16; } else { @@ -3825,7 +3822,6 @@ void xnn_qs8_gemm_minmax_fp32_ukernel_7x16c8__avx512skx_prfm( const __m512 voutput_max_less_zero_point = _mm512_load_ps(params->fp32_avx512.output_max_less_zero_point); const __m256i voutput_zero_point = _mm256_load_si256((const __m256i*) params->fp32_avx512.output_zero_point); const __m128i voutput_min = _mm_load_si128((const __m128i*) params->fp32_avx512.output_min); - const __m128i vshuffle_control_mask = _mm_set_epi8(15, 14, 13, 12, 7, 6, 5, 4, 11, 10, 9, 8, 3, 2, 1, 0); do { __m512i vacc0x0123 = _mm512_maskz_expandloadu_epi32(vbias_mask, w); @@ -3984,37 +3980,37 @@ void xnn_qs8_gemm_minmax_fp32_ukernel_7x16c8__avx512skx_prfm( vacc5x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled5x0123456789ABCDEF); vacc6x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled6x0123456789ABCDEF); - __m256i vacc0x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); - __m256i vacc1x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); - __m256i vacc2x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc2x0123456789ABCDEF, 1)); - __m256i vacc3x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc3x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc3x0123456789ABCDEF, 1)); - __m256i vacc4x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc4x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc4x0123456789ABCDEF, 1)); - __m256i vacc5x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc5x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc5x0123456789ABCDEF, 1)); - __m256i vacc6x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc6x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc6x0123456789ABCDEF, 1)); - - vacc0x0123456789AB4567CDEF = _mm256_adds_epi16(vacc0x0123456789AB4567CDEF, voutput_zero_point); - vacc1x0123456789AB4567CDEF = _mm256_adds_epi16(vacc1x0123456789AB4567CDEF, voutput_zero_point); - vacc2x0123456789AB4567CDEF = _mm256_adds_epi16(vacc2x0123456789AB4567CDEF, voutput_zero_point); - vacc3x0123456789AB4567CDEF = _mm256_adds_epi16(vacc3x0123456789AB4567CDEF, voutput_zero_point); - vacc4x0123456789AB4567CDEF = _mm256_adds_epi16(vacc4x0123456789AB4567CDEF, voutput_zero_point); - vacc5x0123456789AB4567CDEF = _mm256_adds_epi16(vacc5x0123456789AB4567CDEF, voutput_zero_point); - vacc6x0123456789AB4567CDEF = _mm256_adds_epi16(vacc6x0123456789AB4567CDEF, voutput_zero_point); - - const __m128i vout0x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc0x0123456789AB4567CDEF, 1)); - const __m128i vout1x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc1x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc1x0123456789AB4567CDEF, 1)); - const __m128i vout2x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc2x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc2x0123456789AB4567CDEF, 1)); - const __m128i vout3x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc3x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc3x0123456789AB4567CDEF, 1)); - const __m128i vout4x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc4x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc4x0123456789AB4567CDEF, 1)); - const __m128i vout5x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc5x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc5x0123456789AB4567CDEF, 1)); - const __m128i vout6x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc6x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc6x0123456789AB4567CDEF, 1)); - - __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi8(vout0x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi8(vout1x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi8(vout2x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout3x0123456789ABCDEF = _mm_shuffle_epi8(vout3x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout4x0123456789ABCDEF = _mm_shuffle_epi8(vout4x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout5x0123456789ABCDEF = _mm_shuffle_epi8(vout5x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout6x0123456789ABCDEF = _mm_shuffle_epi8(vout6x0123456789AB4567CDEF, vshuffle_control_mask); + __m256i vacc0x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); + __m256i vacc1x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); + __m256i vacc2x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc2x0123456789ABCDEF, 1)); + __m256i vacc3x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc3x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc3x0123456789ABCDEF, 1)); + __m256i vacc4x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc4x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc4x0123456789ABCDEF, 1)); + __m256i vacc5x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc5x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc5x0123456789ABCDEF, 1)); + __m256i vacc6x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc6x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc6x0123456789ABCDEF, 1)); + + vacc0x012389AB4567CDEF = _mm256_adds_epi16(vacc0x012389AB4567CDEF, voutput_zero_point); + vacc1x012389AB4567CDEF = _mm256_adds_epi16(vacc1x012389AB4567CDEF, voutput_zero_point); + vacc2x012389AB4567CDEF = _mm256_adds_epi16(vacc2x012389AB4567CDEF, voutput_zero_point); + vacc3x012389AB4567CDEF = _mm256_adds_epi16(vacc3x012389AB4567CDEF, voutput_zero_point); + vacc4x012389AB4567CDEF = _mm256_adds_epi16(vacc4x012389AB4567CDEF, voutput_zero_point); + vacc5x012389AB4567CDEF = _mm256_adds_epi16(vacc5x012389AB4567CDEF, voutput_zero_point); + vacc6x012389AB4567CDEF = _mm256_adds_epi16(vacc6x012389AB4567CDEF, voutput_zero_point); + + const __m128i vout0x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x012389AB4567CDEF, 1)); + const __m128i vout1x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc1x012389AB4567CDEF), _mm256_extracti128_si256(vacc1x012389AB4567CDEF, 1)); + const __m128i vout2x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc2x012389AB4567CDEF), _mm256_extracti128_si256(vacc2x012389AB4567CDEF, 1)); + const __m128i vout3x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc3x012389AB4567CDEF), _mm256_extracti128_si256(vacc3x012389AB4567CDEF, 1)); + const __m128i vout4x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc4x012389AB4567CDEF), _mm256_extracti128_si256(vacc4x012389AB4567CDEF, 1)); + const __m128i vout5x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc5x012389AB4567CDEF), _mm256_extracti128_si256(vacc5x012389AB4567CDEF, 1)); + const __m128i vout6x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc6x012389AB4567CDEF), _mm256_extracti128_si256(vacc6x012389AB4567CDEF, 1)); + + __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi32(vout0x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi32(vout1x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi32(vout2x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout3x0123456789ABCDEF = _mm_shuffle_epi32(vout3x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout4x0123456789ABCDEF = _mm_shuffle_epi32(vout4x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout5x0123456789ABCDEF = _mm_shuffle_epi32(vout5x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout6x0123456789ABCDEF = _mm_shuffle_epi32(vout6x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); vout0x0123456789ABCDEF = _mm_max_epi8(vout0x0123456789ABCDEF, voutput_min); vout1x0123456789ABCDEF = _mm_max_epi8(vout1x0123456789ABCDEF, voutput_min); @@ -4026,28 +4022,26 @@ void xnn_qs8_gemm_minmax_fp32_ukernel_7x16c8__avx512skx_prfm( if (nc >= 16) { _mm_storeu_si128((__m128i*) c0, vout0x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c1, vout1x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c2, vout2x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c3, vout3x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c4, vout4x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c5, vout5x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c6, vout6x0123456789ABCDEF); - - a0 = (const int8_t*) ((uintptr_t) a0 - kc); - a1 = (const int8_t*) ((uintptr_t) a1 - kc); - a2 = (const int8_t*) ((uintptr_t) a2 - kc); - a3 = (const int8_t*) ((uintptr_t) a3 - kc); - a4 = (const int8_t*) ((uintptr_t) a4 - kc); - a5 = (const int8_t*) ((uintptr_t) a5 - kc); - a6 = (const int8_t*) ((uintptr_t) a6 - kc); - c0 = (int8_t*) ((uintptr_t) c0 + cn_stride); + a0 = (const int8_t*) ((uintptr_t) a0 - kc); + _mm_storeu_si128((__m128i*) c1, vout1x0123456789ABCDEF); c1 = (int8_t*) ((uintptr_t) c1 + cn_stride); + a1 = (const int8_t*) ((uintptr_t) a1 - kc); + _mm_storeu_si128((__m128i*) c2, vout2x0123456789ABCDEF); c2 = (int8_t*) ((uintptr_t) c2 + cn_stride); + a2 = (const int8_t*) ((uintptr_t) a2 - kc); + _mm_storeu_si128((__m128i*) c3, vout3x0123456789ABCDEF); c3 = (int8_t*) ((uintptr_t) c3 + cn_stride); + a3 = (const int8_t*) ((uintptr_t) a3 - kc); + _mm_storeu_si128((__m128i*) c4, vout4x0123456789ABCDEF); c4 = (int8_t*) ((uintptr_t) c4 + cn_stride); + a4 = (const int8_t*) ((uintptr_t) a4 - kc); + _mm_storeu_si128((__m128i*) c5, vout5x0123456789ABCDEF); c5 = (int8_t*) ((uintptr_t) c5 + cn_stride); + a5 = (const int8_t*) ((uintptr_t) a5 - kc); + _mm_storeu_si128((__m128i*) c6, vout6x0123456789ABCDEF); c6 = (int8_t*) ((uintptr_t) c6 + cn_stride); + a6 = (const int8_t*) ((uintptr_t) a6 - kc); nc -= 16; } else { @@ -4155,13 +4149,13 @@ void xnn_qs8_igemm_minmax_fp32_ukernel_1x16c8__avx512skx_prfm( vacc0x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled0x0123456789ABCDEF); - __m256i vacc0x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); + __m256i vacc0x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); - vacc0x0123456789AB4567CDEF = _mm256_adds_epi16(vacc0x0123456789AB4567CDEF, voutput_zero_point); + vacc0x012389AB4567CDEF = _mm256_adds_epi16(vacc0x012389AB4567CDEF, voutput_zero_point); - const __m128i vout0x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc0x0123456789AB4567CDEF, 1)); + const __m128i vout0x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x012389AB4567CDEF, 1)); - __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi32(vout0x0123456789AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi32(vout0x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); vout0x0123456789ABCDEF = _mm_max_epi8(vout0x0123456789ABCDEF, voutput_min); @@ -4423,37 +4417,37 @@ void xnn_qs8_igemm_minmax_fp32_ukernel_7x16c8__avx512skx_prfm( vacc5x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled5x0123456789ABCDEF); vacc6x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled6x0123456789ABCDEF); - __m256i vacc0x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); - __m256i vacc1x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); - __m256i vacc2x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc2x0123456789ABCDEF, 1)); - __m256i vacc3x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc3x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc3x0123456789ABCDEF, 1)); - __m256i vacc4x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc4x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc4x0123456789ABCDEF, 1)); - __m256i vacc5x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc5x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc5x0123456789ABCDEF, 1)); - __m256i vacc6x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc6x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc6x0123456789ABCDEF, 1)); - - vacc0x0123456789AB4567CDEF = _mm256_adds_epi16(vacc0x0123456789AB4567CDEF, voutput_zero_point); - vacc1x0123456789AB4567CDEF = _mm256_adds_epi16(vacc1x0123456789AB4567CDEF, voutput_zero_point); - vacc2x0123456789AB4567CDEF = _mm256_adds_epi16(vacc2x0123456789AB4567CDEF, voutput_zero_point); - vacc3x0123456789AB4567CDEF = _mm256_adds_epi16(vacc3x0123456789AB4567CDEF, voutput_zero_point); - vacc4x0123456789AB4567CDEF = _mm256_adds_epi16(vacc4x0123456789AB4567CDEF, voutput_zero_point); - vacc5x0123456789AB4567CDEF = _mm256_adds_epi16(vacc5x0123456789AB4567CDEF, voutput_zero_point); - vacc6x0123456789AB4567CDEF = _mm256_adds_epi16(vacc6x0123456789AB4567CDEF, voutput_zero_point); - - const __m128i vout0x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc0x0123456789AB4567CDEF, 1)); - const __m128i vout1x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc1x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc1x0123456789AB4567CDEF, 1)); - const __m128i vout2x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc2x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc2x0123456789AB4567CDEF, 1)); - const __m128i vout3x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc3x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc3x0123456789AB4567CDEF, 1)); - const __m128i vout4x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc4x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc4x0123456789AB4567CDEF, 1)); - const __m128i vout5x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc5x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc5x0123456789AB4567CDEF, 1)); - const __m128i vout6x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc6x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc6x0123456789AB4567CDEF, 1)); - - __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi32(vout0x0123456789AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); - __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi32(vout1x0123456789AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); - __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi32(vout2x0123456789AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); - __m128i vout3x0123456789ABCDEF = _mm_shuffle_epi32(vout3x0123456789AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); - __m128i vout4x0123456789ABCDEF = _mm_shuffle_epi32(vout4x0123456789AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); - __m128i vout5x0123456789ABCDEF = _mm_shuffle_epi32(vout5x0123456789AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); - __m128i vout6x0123456789ABCDEF = _mm_shuffle_epi32(vout6x0123456789AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m256i vacc0x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); + __m256i vacc1x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); + __m256i vacc2x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc2x0123456789ABCDEF, 1)); + __m256i vacc3x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc3x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc3x0123456789ABCDEF, 1)); + __m256i vacc4x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc4x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc4x0123456789ABCDEF, 1)); + __m256i vacc5x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc5x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc5x0123456789ABCDEF, 1)); + __m256i vacc6x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc6x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc6x0123456789ABCDEF, 1)); + + vacc0x012389AB4567CDEF = _mm256_adds_epi16(vacc0x012389AB4567CDEF, voutput_zero_point); + vacc1x012389AB4567CDEF = _mm256_adds_epi16(vacc1x012389AB4567CDEF, voutput_zero_point); + vacc2x012389AB4567CDEF = _mm256_adds_epi16(vacc2x012389AB4567CDEF, voutput_zero_point); + vacc3x012389AB4567CDEF = _mm256_adds_epi16(vacc3x012389AB4567CDEF, voutput_zero_point); + vacc4x012389AB4567CDEF = _mm256_adds_epi16(vacc4x012389AB4567CDEF, voutput_zero_point); + vacc5x012389AB4567CDEF = _mm256_adds_epi16(vacc5x012389AB4567CDEF, voutput_zero_point); + vacc6x012389AB4567CDEF = _mm256_adds_epi16(vacc6x012389AB4567CDEF, voutput_zero_point); + + const __m128i vout0x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x012389AB4567CDEF, 1)); + const __m128i vout1x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc1x012389AB4567CDEF), _mm256_extracti128_si256(vacc1x012389AB4567CDEF, 1)); + const __m128i vout2x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc2x012389AB4567CDEF), _mm256_extracti128_si256(vacc2x012389AB4567CDEF, 1)); + const __m128i vout3x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc3x012389AB4567CDEF), _mm256_extracti128_si256(vacc3x012389AB4567CDEF, 1)); + const __m128i vout4x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc4x012389AB4567CDEF), _mm256_extracti128_si256(vacc4x012389AB4567CDEF, 1)); + const __m128i vout5x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc5x012389AB4567CDEF), _mm256_extracti128_si256(vacc5x012389AB4567CDEF, 1)); + const __m128i vout6x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc6x012389AB4567CDEF), _mm256_extracti128_si256(vacc6x012389AB4567CDEF, 1)); + + __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi32(vout0x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi32(vout1x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi32(vout2x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout3x0123456789ABCDEF = _mm_shuffle_epi32(vout3x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout4x0123456789ABCDEF = _mm_shuffle_epi32(vout4x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout5x0123456789ABCDEF = _mm_shuffle_epi32(vout5x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout6x0123456789ABCDEF = _mm_shuffle_epi32(vout6x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); vout0x0123456789ABCDEF = _mm_max_epi8(vout0x0123456789ABCDEF, voutput_min); vout1x0123456789ABCDEF = _mm_max_epi8(vout1x0123456789ABCDEF, voutput_min); @@ -5582,7 +5576,6 @@ void xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_1x16c8__avx512skx_prfm( const __m512 voutput_max_less_zero_point = _mm512_load_ps(params->fp32_avx512.output_max_less_zero_point); const __m256i voutput_zero_point = _mm256_load_si256((const __m256i*) params->fp32_avx512.output_zero_point); const __m128i voutput_min = _mm_load_si128((const __m128i*) params->fp32_avx512.output_min); - const __m128i vshuffle_control_mask = _mm_set_epi8(15, 14, 13, 12, 7, 6, 5, 4, 11, 10, 9, 8, 3, 2, 1, 0); do { __m512i vacc0x0123 = _mm512_maskz_expandloadu_epi32(vbias_mask, w); @@ -5635,22 +5628,20 @@ void xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_1x16c8__avx512skx_prfm( vacc0x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled0x0123456789ABCDEF); - __m256i vacc0x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); + __m256i vacc0x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); - vacc0x0123456789AB4567CDEF = _mm256_adds_epi16(vacc0x0123456789AB4567CDEF, voutput_zero_point); + vacc0x012389AB4567CDEF = _mm256_adds_epi16(vacc0x012389AB4567CDEF, voutput_zero_point); - const __m128i vout0x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc0x0123456789AB4567CDEF, 1)); + const __m128i vout0x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x012389AB4567CDEF, 1)); - __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi8(vout0x0123456789AB4567CDEF, vshuffle_control_mask); + __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi32(vout0x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); vout0x0123456789ABCDEF = _mm_max_epi8(vout0x0123456789ABCDEF, voutput_min); if (nc >= 16) { _mm_storeu_si128((__m128i*) c0, vout0x0123456789ABCDEF); - - a0 = (const int8_t*) ((uintptr_t) a0 - kc); - c0 = (int8_t*) ((uintptr_t) c0 + cn_stride); + a0 = (const int8_t*) ((uintptr_t) a0 - kc); nc -= 16; } else { @@ -5729,7 +5720,6 @@ void xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_7x16c8__avx512skx_prfm( const __m512 voutput_max_less_zero_point = _mm512_load_ps(params->fp32_avx512.output_max_less_zero_point); const __m256i voutput_zero_point = _mm256_load_si256((const __m256i*) params->fp32_avx512.output_zero_point); const __m128i voutput_min = _mm_load_si128((const __m128i*) params->fp32_avx512.output_min); - const __m128i vshuffle_control_mask = _mm_set_epi8(15, 14, 13, 12, 7, 6, 5, 4, 11, 10, 9, 8, 3, 2, 1, 0); do { __m512i vacc0x0123 = _mm512_maskz_expandloadu_epi32(vbias_mask, w); @@ -5890,37 +5880,37 @@ void xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_7x16c8__avx512skx_prfm( vacc5x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled5x0123456789ABCDEF); vacc6x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled6x0123456789ABCDEF); - __m256i vacc0x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); - __m256i vacc1x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); - __m256i vacc2x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc2x0123456789ABCDEF, 1)); - __m256i vacc3x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc3x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc3x0123456789ABCDEF, 1)); - __m256i vacc4x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc4x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc4x0123456789ABCDEF, 1)); - __m256i vacc5x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc5x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc5x0123456789ABCDEF, 1)); - __m256i vacc6x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc6x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc6x0123456789ABCDEF, 1)); - - vacc0x0123456789AB4567CDEF = _mm256_adds_epi16(vacc0x0123456789AB4567CDEF, voutput_zero_point); - vacc1x0123456789AB4567CDEF = _mm256_adds_epi16(vacc1x0123456789AB4567CDEF, voutput_zero_point); - vacc2x0123456789AB4567CDEF = _mm256_adds_epi16(vacc2x0123456789AB4567CDEF, voutput_zero_point); - vacc3x0123456789AB4567CDEF = _mm256_adds_epi16(vacc3x0123456789AB4567CDEF, voutput_zero_point); - vacc4x0123456789AB4567CDEF = _mm256_adds_epi16(vacc4x0123456789AB4567CDEF, voutput_zero_point); - vacc5x0123456789AB4567CDEF = _mm256_adds_epi16(vacc5x0123456789AB4567CDEF, voutput_zero_point); - vacc6x0123456789AB4567CDEF = _mm256_adds_epi16(vacc6x0123456789AB4567CDEF, voutput_zero_point); - - const __m128i vout0x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc0x0123456789AB4567CDEF, 1)); - const __m128i vout1x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc1x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc1x0123456789AB4567CDEF, 1)); - const __m128i vout2x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc2x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc2x0123456789AB4567CDEF, 1)); - const __m128i vout3x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc3x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc3x0123456789AB4567CDEF, 1)); - const __m128i vout4x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc4x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc4x0123456789AB4567CDEF, 1)); - const __m128i vout5x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc5x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc5x0123456789AB4567CDEF, 1)); - const __m128i vout6x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc6x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc6x0123456789AB4567CDEF, 1)); - - __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi8(vout0x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi8(vout1x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi8(vout2x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout3x0123456789ABCDEF = _mm_shuffle_epi8(vout3x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout4x0123456789ABCDEF = _mm_shuffle_epi8(vout4x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout5x0123456789ABCDEF = _mm_shuffle_epi8(vout5x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout6x0123456789ABCDEF = _mm_shuffle_epi8(vout6x0123456789AB4567CDEF, vshuffle_control_mask); + __m256i vacc0x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); + __m256i vacc1x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); + __m256i vacc2x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc2x0123456789ABCDEF, 1)); + __m256i vacc3x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc3x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc3x0123456789ABCDEF, 1)); + __m256i vacc4x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc4x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc4x0123456789ABCDEF, 1)); + __m256i vacc5x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc5x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc5x0123456789ABCDEF, 1)); + __m256i vacc6x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc6x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc6x0123456789ABCDEF, 1)); + + vacc0x012389AB4567CDEF = _mm256_adds_epi16(vacc0x012389AB4567CDEF, voutput_zero_point); + vacc1x012389AB4567CDEF = _mm256_adds_epi16(vacc1x012389AB4567CDEF, voutput_zero_point); + vacc2x012389AB4567CDEF = _mm256_adds_epi16(vacc2x012389AB4567CDEF, voutput_zero_point); + vacc3x012389AB4567CDEF = _mm256_adds_epi16(vacc3x012389AB4567CDEF, voutput_zero_point); + vacc4x012389AB4567CDEF = _mm256_adds_epi16(vacc4x012389AB4567CDEF, voutput_zero_point); + vacc5x012389AB4567CDEF = _mm256_adds_epi16(vacc5x012389AB4567CDEF, voutput_zero_point); + vacc6x012389AB4567CDEF = _mm256_adds_epi16(vacc6x012389AB4567CDEF, voutput_zero_point); + + const __m128i vout0x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x012389AB4567CDEF, 1)); + const __m128i vout1x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc1x012389AB4567CDEF), _mm256_extracti128_si256(vacc1x012389AB4567CDEF, 1)); + const __m128i vout2x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc2x012389AB4567CDEF), _mm256_extracti128_si256(vacc2x012389AB4567CDEF, 1)); + const __m128i vout3x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc3x012389AB4567CDEF), _mm256_extracti128_si256(vacc3x012389AB4567CDEF, 1)); + const __m128i vout4x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc4x012389AB4567CDEF), _mm256_extracti128_si256(vacc4x012389AB4567CDEF, 1)); + const __m128i vout5x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc5x012389AB4567CDEF), _mm256_extracti128_si256(vacc5x012389AB4567CDEF, 1)); + const __m128i vout6x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc6x012389AB4567CDEF), _mm256_extracti128_si256(vacc6x012389AB4567CDEF, 1)); + + __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi32(vout0x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi32(vout1x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi32(vout2x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout3x0123456789ABCDEF = _mm_shuffle_epi32(vout3x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout4x0123456789ABCDEF = _mm_shuffle_epi32(vout4x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout5x0123456789ABCDEF = _mm_shuffle_epi32(vout5x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout6x0123456789ABCDEF = _mm_shuffle_epi32(vout6x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); vout0x0123456789ABCDEF = _mm_max_epi8(vout0x0123456789ABCDEF, voutput_min); vout1x0123456789ABCDEF = _mm_max_epi8(vout1x0123456789ABCDEF, voutput_min); @@ -5932,28 +5922,26 @@ void xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_7x16c8__avx512skx_prfm( if (nc >= 16) { _mm_storeu_si128((__m128i*) c0, vout0x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c1, vout1x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c2, vout2x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c3, vout3x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c4, vout4x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c5, vout5x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c6, vout6x0123456789ABCDEF); - - a0 = (const int8_t*) ((uintptr_t) a0 - kc); - a1 = (const int8_t*) ((uintptr_t) a1 - kc); - a2 = (const int8_t*) ((uintptr_t) a2 - kc); - a3 = (const int8_t*) ((uintptr_t) a3 - kc); - a4 = (const int8_t*) ((uintptr_t) a4 - kc); - a5 = (const int8_t*) ((uintptr_t) a5 - kc); - a6 = (const int8_t*) ((uintptr_t) a6 - kc); - c0 = (int8_t*) ((uintptr_t) c0 + cn_stride); + a0 = (const int8_t*) ((uintptr_t) a0 - kc); + _mm_storeu_si128((__m128i*) c1, vout1x0123456789ABCDEF); c1 = (int8_t*) ((uintptr_t) c1 + cn_stride); + a1 = (const int8_t*) ((uintptr_t) a1 - kc); + _mm_storeu_si128((__m128i*) c2, vout2x0123456789ABCDEF); c2 = (int8_t*) ((uintptr_t) c2 + cn_stride); + a2 = (const int8_t*) ((uintptr_t) a2 - kc); + _mm_storeu_si128((__m128i*) c3, vout3x0123456789ABCDEF); c3 = (int8_t*) ((uintptr_t) c3 + cn_stride); + a3 = (const int8_t*) ((uintptr_t) a3 - kc); + _mm_storeu_si128((__m128i*) c4, vout4x0123456789ABCDEF); c4 = (int8_t*) ((uintptr_t) c4 + cn_stride); + a4 = (const int8_t*) ((uintptr_t) a4 - kc); + _mm_storeu_si128((__m128i*) c5, vout5x0123456789ABCDEF); c5 = (int8_t*) ((uintptr_t) c5 + cn_stride); + a5 = (const int8_t*) ((uintptr_t) a5 - kc); + _mm_storeu_si128((__m128i*) c6, vout6x0123456789ABCDEF); c6 = (int8_t*) ((uintptr_t) c6 + cn_stride); + a6 = (const int8_t*) ((uintptr_t) a6 - kc); nc -= 16; } else { @@ -6062,13 +6050,13 @@ void xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_1x16c8__avx512skx_prfm( vacc0x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled0x0123456789ABCDEF); - __m256i vacc0x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); + __m256i vacc0x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); - vacc0x0123456789AB4567CDEF = _mm256_adds_epi16(vacc0x0123456789AB4567CDEF, voutput_zero_point); + vacc0x012389AB4567CDEF = _mm256_adds_epi16(vacc0x012389AB4567CDEF, voutput_zero_point); - const __m128i vout0x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc0x0123456789AB4567CDEF, 1)); + const __m128i vout0x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x012389AB4567CDEF, 1)); - __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi32(vout0x0123456789AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi32(vout0x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); vout0x0123456789ABCDEF = _mm_max_epi8(vout0x0123456789ABCDEF, voutput_min); @@ -6331,37 +6319,37 @@ void xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_7x16c8__avx512skx_prfm( vacc5x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled5x0123456789ABCDEF); vacc6x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled6x0123456789ABCDEF); - __m256i vacc0x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); - __m256i vacc1x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); - __m256i vacc2x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc2x0123456789ABCDEF, 1)); - __m256i vacc3x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc3x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc3x0123456789ABCDEF, 1)); - __m256i vacc4x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc4x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc4x0123456789ABCDEF, 1)); - __m256i vacc5x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc5x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc5x0123456789ABCDEF, 1)); - __m256i vacc6x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc6x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc6x0123456789ABCDEF, 1)); - - vacc0x0123456789AB4567CDEF = _mm256_adds_epi16(vacc0x0123456789AB4567CDEF, voutput_zero_point); - vacc1x0123456789AB4567CDEF = _mm256_adds_epi16(vacc1x0123456789AB4567CDEF, voutput_zero_point); - vacc2x0123456789AB4567CDEF = _mm256_adds_epi16(vacc2x0123456789AB4567CDEF, voutput_zero_point); - vacc3x0123456789AB4567CDEF = _mm256_adds_epi16(vacc3x0123456789AB4567CDEF, voutput_zero_point); - vacc4x0123456789AB4567CDEF = _mm256_adds_epi16(vacc4x0123456789AB4567CDEF, voutput_zero_point); - vacc5x0123456789AB4567CDEF = _mm256_adds_epi16(vacc5x0123456789AB4567CDEF, voutput_zero_point); - vacc6x0123456789AB4567CDEF = _mm256_adds_epi16(vacc6x0123456789AB4567CDEF, voutput_zero_point); - - const __m128i vout0x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc0x0123456789AB4567CDEF, 1)); - const __m128i vout1x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc1x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc1x0123456789AB4567CDEF, 1)); - const __m128i vout2x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc2x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc2x0123456789AB4567CDEF, 1)); - const __m128i vout3x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc3x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc3x0123456789AB4567CDEF, 1)); - const __m128i vout4x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc4x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc4x0123456789AB4567CDEF, 1)); - const __m128i vout5x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc5x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc5x0123456789AB4567CDEF, 1)); - const __m128i vout6x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc6x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc6x0123456789AB4567CDEF, 1)); - - __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi32(vout0x0123456789AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); - __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi32(vout1x0123456789AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); - __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi32(vout2x0123456789AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); - __m128i vout3x0123456789ABCDEF = _mm_shuffle_epi32(vout3x0123456789AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); - __m128i vout4x0123456789ABCDEF = _mm_shuffle_epi32(vout4x0123456789AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); - __m128i vout5x0123456789ABCDEF = _mm_shuffle_epi32(vout5x0123456789AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); - __m128i vout6x0123456789ABCDEF = _mm_shuffle_epi32(vout6x0123456789AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m256i vacc0x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); + __m256i vacc1x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); + __m256i vacc2x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc2x0123456789ABCDEF, 1)); + __m256i vacc3x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc3x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc3x0123456789ABCDEF, 1)); + __m256i vacc4x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc4x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc4x0123456789ABCDEF, 1)); + __m256i vacc5x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc5x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc5x0123456789ABCDEF, 1)); + __m256i vacc6x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc6x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc6x0123456789ABCDEF, 1)); + + vacc0x012389AB4567CDEF = _mm256_adds_epi16(vacc0x012389AB4567CDEF, voutput_zero_point); + vacc1x012389AB4567CDEF = _mm256_adds_epi16(vacc1x012389AB4567CDEF, voutput_zero_point); + vacc2x012389AB4567CDEF = _mm256_adds_epi16(vacc2x012389AB4567CDEF, voutput_zero_point); + vacc3x012389AB4567CDEF = _mm256_adds_epi16(vacc3x012389AB4567CDEF, voutput_zero_point); + vacc4x012389AB4567CDEF = _mm256_adds_epi16(vacc4x012389AB4567CDEF, voutput_zero_point); + vacc5x012389AB4567CDEF = _mm256_adds_epi16(vacc5x012389AB4567CDEF, voutput_zero_point); + vacc6x012389AB4567CDEF = _mm256_adds_epi16(vacc6x012389AB4567CDEF, voutput_zero_point); + + const __m128i vout0x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x012389AB4567CDEF, 1)); + const __m128i vout1x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc1x012389AB4567CDEF), _mm256_extracti128_si256(vacc1x012389AB4567CDEF, 1)); + const __m128i vout2x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc2x012389AB4567CDEF), _mm256_extracti128_si256(vacc2x012389AB4567CDEF, 1)); + const __m128i vout3x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc3x012389AB4567CDEF), _mm256_extracti128_si256(vacc3x012389AB4567CDEF, 1)); + const __m128i vout4x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc4x012389AB4567CDEF), _mm256_extracti128_si256(vacc4x012389AB4567CDEF, 1)); + const __m128i vout5x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc5x012389AB4567CDEF), _mm256_extracti128_si256(vacc5x012389AB4567CDEF, 1)); + const __m128i vout6x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc6x012389AB4567CDEF), _mm256_extracti128_si256(vacc6x012389AB4567CDEF, 1)); + + __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi32(vout0x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi32(vout1x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi32(vout2x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout3x0123456789ABCDEF = _mm_shuffle_epi32(vout3x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout4x0123456789ABCDEF = _mm_shuffle_epi32(vout4x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout5x0123456789ABCDEF = _mm_shuffle_epi32(vout5x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout6x0123456789ABCDEF = _mm_shuffle_epi32(vout6x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); vout0x0123456789ABCDEF = _mm_max_epi8(vout0x0123456789ABCDEF, voutput_min); vout1x0123456789ABCDEF = _mm_max_epi8(vout1x0123456789ABCDEF, voutput_min); @@ -7507,7 +7495,6 @@ void xnn_qu8_gemm_minmax_fp32_ukernel_1x16c8__avx512skx_prfm( const __m256i voutput_zero_point = _mm256_load_si256((const __m256i*) params->fp32_avx512.output_zero_point); const __m128i voutput_min = _mm_load_si128((const __m128i*) params->fp32_avx512.output_min); const __m512i vb_zero_point = _mm512_load_si512(params->fp32_avx512.kernel_zero_point); - const __m128i vshuffle_control_mask = _mm_set_epi8(15, 14, 13, 12, 7, 6, 5, 4, 11, 10, 9, 8, 3, 2, 1, 0); do { __m512i vacc0x0123 = _mm512_maskz_expandloadu_epi32(vbias_mask, w); @@ -7558,22 +7545,20 @@ void xnn_qu8_gemm_minmax_fp32_ukernel_1x16c8__avx512skx_prfm( vacc0x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled0x0123456789ABCDEF); - __m256i vacc0x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); + __m256i vacc0x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); - vacc0x0123456789AB4567CDEF = _mm256_adds_epi16(vacc0x0123456789AB4567CDEF, voutput_zero_point); + vacc0x012389AB4567CDEF = _mm256_adds_epi16(vacc0x012389AB4567CDEF, voutput_zero_point); - const __m128i vout0x0123456789AB4567CDEF = _mm_packus_epi16(_mm256_castsi256_si128(vacc0x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc0x0123456789AB4567CDEF, 1)); + const __m128i vout0x012389AB4567CDEF = _mm_packus_epi16(_mm256_castsi256_si128(vacc0x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x012389AB4567CDEF, 1)); - __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi8(vout0x0123456789AB4567CDEF, vshuffle_control_mask); + __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi32(vout0x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); vout0x0123456789ABCDEF = _mm_max_epu8(vout0x0123456789ABCDEF, voutput_min); if (nc >= 16) { _mm_storeu_si128((__m128i*) c0, vout0x0123456789ABCDEF); - - a0 = (const uint8_t*) ((uintptr_t) a0 - kc); - c0 = (uint8_t*) ((uintptr_t) c0 + cn_stride); + a0 = (const uint8_t*) ((uintptr_t) a0 - kc); nc -= 16; } else { @@ -7654,7 +7639,6 @@ void xnn_qu8_gemm_minmax_fp32_ukernel_7x16c8__avx512skx_prfm( const __m256i voutput_zero_point = _mm256_load_si256((const __m256i*) params->fp32_avx512.output_zero_point); const __m128i voutput_min = _mm_load_si128((const __m128i*) params->fp32_avx512.output_min); const __m512i vb_zero_point = _mm512_load_si512(params->fp32_avx512.kernel_zero_point); - const __m128i vshuffle_control_mask = _mm_set_epi8(15, 14, 13, 12, 7, 6, 5, 4, 11, 10, 9, 8, 3, 2, 1, 0); do { __m512i vacc0x0123 = _mm512_maskz_expandloadu_epi32(vbias_mask, w); @@ -7813,37 +7797,37 @@ void xnn_qu8_gemm_minmax_fp32_ukernel_7x16c8__avx512skx_prfm( vacc5x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled5x0123456789ABCDEF); vacc6x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled6x0123456789ABCDEF); - __m256i vacc0x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); - __m256i vacc1x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); - __m256i vacc2x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc2x0123456789ABCDEF, 1)); - __m256i vacc3x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc3x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc3x0123456789ABCDEF, 1)); - __m256i vacc4x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc4x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc4x0123456789ABCDEF, 1)); - __m256i vacc5x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc5x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc5x0123456789ABCDEF, 1)); - __m256i vacc6x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc6x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc6x0123456789ABCDEF, 1)); - - vacc0x0123456789AB4567CDEF = _mm256_adds_epi16(vacc0x0123456789AB4567CDEF, voutput_zero_point); - vacc1x0123456789AB4567CDEF = _mm256_adds_epi16(vacc1x0123456789AB4567CDEF, voutput_zero_point); - vacc2x0123456789AB4567CDEF = _mm256_adds_epi16(vacc2x0123456789AB4567CDEF, voutput_zero_point); - vacc3x0123456789AB4567CDEF = _mm256_adds_epi16(vacc3x0123456789AB4567CDEF, voutput_zero_point); - vacc4x0123456789AB4567CDEF = _mm256_adds_epi16(vacc4x0123456789AB4567CDEF, voutput_zero_point); - vacc5x0123456789AB4567CDEF = _mm256_adds_epi16(vacc5x0123456789AB4567CDEF, voutput_zero_point); - vacc6x0123456789AB4567CDEF = _mm256_adds_epi16(vacc6x0123456789AB4567CDEF, voutput_zero_point); - - const __m128i vout0x0123456789AB4567CDEF = _mm_packus_epi16(_mm256_castsi256_si128(vacc0x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc0x0123456789AB4567CDEF, 1)); - const __m128i vout1x0123456789AB4567CDEF = _mm_packus_epi16(_mm256_castsi256_si128(vacc1x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc1x0123456789AB4567CDEF, 1)); - const __m128i vout2x0123456789AB4567CDEF = _mm_packus_epi16(_mm256_castsi256_si128(vacc2x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc2x0123456789AB4567CDEF, 1)); - const __m128i vout3x0123456789AB4567CDEF = _mm_packus_epi16(_mm256_castsi256_si128(vacc3x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc3x0123456789AB4567CDEF, 1)); - const __m128i vout4x0123456789AB4567CDEF = _mm_packus_epi16(_mm256_castsi256_si128(vacc4x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc4x0123456789AB4567CDEF, 1)); - const __m128i vout5x0123456789AB4567CDEF = _mm_packus_epi16(_mm256_castsi256_si128(vacc5x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc5x0123456789AB4567CDEF, 1)); - const __m128i vout6x0123456789AB4567CDEF = _mm_packus_epi16(_mm256_castsi256_si128(vacc6x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc6x0123456789AB4567CDEF, 1)); - - __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi8(vout0x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi8(vout1x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi8(vout2x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout3x0123456789ABCDEF = _mm_shuffle_epi8(vout3x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout4x0123456789ABCDEF = _mm_shuffle_epi8(vout4x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout5x0123456789ABCDEF = _mm_shuffle_epi8(vout5x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout6x0123456789ABCDEF = _mm_shuffle_epi8(vout6x0123456789AB4567CDEF, vshuffle_control_mask); + __m256i vacc0x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); + __m256i vacc1x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); + __m256i vacc2x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc2x0123456789ABCDEF, 1)); + __m256i vacc3x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc3x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc3x0123456789ABCDEF, 1)); + __m256i vacc4x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc4x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc4x0123456789ABCDEF, 1)); + __m256i vacc5x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc5x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc5x0123456789ABCDEF, 1)); + __m256i vacc6x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc6x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc6x0123456789ABCDEF, 1)); + + vacc0x012389AB4567CDEF = _mm256_adds_epi16(vacc0x012389AB4567CDEF, voutput_zero_point); + vacc1x012389AB4567CDEF = _mm256_adds_epi16(vacc1x012389AB4567CDEF, voutput_zero_point); + vacc2x012389AB4567CDEF = _mm256_adds_epi16(vacc2x012389AB4567CDEF, voutput_zero_point); + vacc3x012389AB4567CDEF = _mm256_adds_epi16(vacc3x012389AB4567CDEF, voutput_zero_point); + vacc4x012389AB4567CDEF = _mm256_adds_epi16(vacc4x012389AB4567CDEF, voutput_zero_point); + vacc5x012389AB4567CDEF = _mm256_adds_epi16(vacc5x012389AB4567CDEF, voutput_zero_point); + vacc6x012389AB4567CDEF = _mm256_adds_epi16(vacc6x012389AB4567CDEF, voutput_zero_point); + + const __m128i vout0x012389AB4567CDEF = _mm_packus_epi16(_mm256_castsi256_si128(vacc0x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x012389AB4567CDEF, 1)); + const __m128i vout1x012389AB4567CDEF = _mm_packus_epi16(_mm256_castsi256_si128(vacc1x012389AB4567CDEF), _mm256_extracti128_si256(vacc1x012389AB4567CDEF, 1)); + const __m128i vout2x012389AB4567CDEF = _mm_packus_epi16(_mm256_castsi256_si128(vacc2x012389AB4567CDEF), _mm256_extracti128_si256(vacc2x012389AB4567CDEF, 1)); + const __m128i vout3x012389AB4567CDEF = _mm_packus_epi16(_mm256_castsi256_si128(vacc3x012389AB4567CDEF), _mm256_extracti128_si256(vacc3x012389AB4567CDEF, 1)); + const __m128i vout4x012389AB4567CDEF = _mm_packus_epi16(_mm256_castsi256_si128(vacc4x012389AB4567CDEF), _mm256_extracti128_si256(vacc4x012389AB4567CDEF, 1)); + const __m128i vout5x012389AB4567CDEF = _mm_packus_epi16(_mm256_castsi256_si128(vacc5x012389AB4567CDEF), _mm256_extracti128_si256(vacc5x012389AB4567CDEF, 1)); + const __m128i vout6x012389AB4567CDEF = _mm_packus_epi16(_mm256_castsi256_si128(vacc6x012389AB4567CDEF), _mm256_extracti128_si256(vacc6x012389AB4567CDEF, 1)); + + __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi32(vout0x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi32(vout1x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi32(vout2x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout3x0123456789ABCDEF = _mm_shuffle_epi32(vout3x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout4x0123456789ABCDEF = _mm_shuffle_epi32(vout4x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout5x0123456789ABCDEF = _mm_shuffle_epi32(vout5x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout6x0123456789ABCDEF = _mm_shuffle_epi32(vout6x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); vout0x0123456789ABCDEF = _mm_max_epu8(vout0x0123456789ABCDEF, voutput_min); vout1x0123456789ABCDEF = _mm_max_epu8(vout1x0123456789ABCDEF, voutput_min); @@ -7855,28 +7839,26 @@ void xnn_qu8_gemm_minmax_fp32_ukernel_7x16c8__avx512skx_prfm( if (nc >= 16) { _mm_storeu_si128((__m128i*) c0, vout0x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c1, vout1x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c2, vout2x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c3, vout3x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c4, vout4x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c5, vout5x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c6, vout6x0123456789ABCDEF); - - a0 = (const uint8_t*) ((uintptr_t) a0 - kc); - a1 = (const uint8_t*) ((uintptr_t) a1 - kc); - a2 = (const uint8_t*) ((uintptr_t) a2 - kc); - a3 = (const uint8_t*) ((uintptr_t) a3 - kc); - a4 = (const uint8_t*) ((uintptr_t) a4 - kc); - a5 = (const uint8_t*) ((uintptr_t) a5 - kc); - a6 = (const uint8_t*) ((uintptr_t) a6 - kc); - c0 = (uint8_t*) ((uintptr_t) c0 + cn_stride); + a0 = (const uint8_t*) ((uintptr_t) a0 - kc); + _mm_storeu_si128((__m128i*) c1, vout1x0123456789ABCDEF); c1 = (uint8_t*) ((uintptr_t) c1 + cn_stride); + a1 = (const uint8_t*) ((uintptr_t) a1 - kc); + _mm_storeu_si128((__m128i*) c2, vout2x0123456789ABCDEF); c2 = (uint8_t*) ((uintptr_t) c2 + cn_stride); + a2 = (const uint8_t*) ((uintptr_t) a2 - kc); + _mm_storeu_si128((__m128i*) c3, vout3x0123456789ABCDEF); c3 = (uint8_t*) ((uintptr_t) c3 + cn_stride); + a3 = (const uint8_t*) ((uintptr_t) a3 - kc); + _mm_storeu_si128((__m128i*) c4, vout4x0123456789ABCDEF); c4 = (uint8_t*) ((uintptr_t) c4 + cn_stride); + a4 = (const uint8_t*) ((uintptr_t) a4 - kc); + _mm_storeu_si128((__m128i*) c5, vout5x0123456789ABCDEF); c5 = (uint8_t*) ((uintptr_t) c5 + cn_stride); + a5 = (const uint8_t*) ((uintptr_t) a5 - kc); + _mm_storeu_si128((__m128i*) c6, vout6x0123456789ABCDEF); c6 = (uint8_t*) ((uintptr_t) c6 + cn_stride); + a6 = (const uint8_t*) ((uintptr_t) a6 - kc); nc -= 16; } else { @@ -7985,13 +7967,13 @@ void xnn_qu8_igemm_minmax_fp32_ukernel_1x16c8__avx512skx_prfm( vacc0x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled0x0123456789ABCDEF); - __m256i vacc0x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); + __m256i vacc0x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); - vacc0x0123456789AB4567CDEF = _mm256_adds_epi16(vacc0x0123456789AB4567CDEF, voutput_zero_point); + vacc0x012389AB4567CDEF = _mm256_adds_epi16(vacc0x012389AB4567CDEF, voutput_zero_point); - const __m128i vout0x0123456789AB4567CDEF = _mm_packus_epi16(_mm256_castsi256_si128(vacc0x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc0x0123456789AB4567CDEF, 1)); + const __m128i vout0x012389AB4567CDEF = _mm_packus_epi16(_mm256_castsi256_si128(vacc0x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x012389AB4567CDEF, 1)); - __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi32(vout0x0123456789AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi32(vout0x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); vout0x0123456789ABCDEF = _mm_max_epu8(vout0x0123456789ABCDEF, voutput_min); @@ -8254,37 +8236,37 @@ void xnn_qu8_igemm_minmax_fp32_ukernel_7x16c8__avx512skx_prfm( vacc5x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled5x0123456789ABCDEF); vacc6x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled6x0123456789ABCDEF); - __m256i vacc0x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); - __m256i vacc1x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); - __m256i vacc2x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc2x0123456789ABCDEF, 1)); - __m256i vacc3x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc3x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc3x0123456789ABCDEF, 1)); - __m256i vacc4x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc4x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc4x0123456789ABCDEF, 1)); - __m256i vacc5x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc5x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc5x0123456789ABCDEF, 1)); - __m256i vacc6x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc6x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc6x0123456789ABCDEF, 1)); - - vacc0x0123456789AB4567CDEF = _mm256_adds_epi16(vacc0x0123456789AB4567CDEF, voutput_zero_point); - vacc1x0123456789AB4567CDEF = _mm256_adds_epi16(vacc1x0123456789AB4567CDEF, voutput_zero_point); - vacc2x0123456789AB4567CDEF = _mm256_adds_epi16(vacc2x0123456789AB4567CDEF, voutput_zero_point); - vacc3x0123456789AB4567CDEF = _mm256_adds_epi16(vacc3x0123456789AB4567CDEF, voutput_zero_point); - vacc4x0123456789AB4567CDEF = _mm256_adds_epi16(vacc4x0123456789AB4567CDEF, voutput_zero_point); - vacc5x0123456789AB4567CDEF = _mm256_adds_epi16(vacc5x0123456789AB4567CDEF, voutput_zero_point); - vacc6x0123456789AB4567CDEF = _mm256_adds_epi16(vacc6x0123456789AB4567CDEF, voutput_zero_point); - - const __m128i vout0x0123456789AB4567CDEF = _mm_packus_epi16(_mm256_castsi256_si128(vacc0x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc0x0123456789AB4567CDEF, 1)); - const __m128i vout1x0123456789AB4567CDEF = _mm_packus_epi16(_mm256_castsi256_si128(vacc1x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc1x0123456789AB4567CDEF, 1)); - const __m128i vout2x0123456789AB4567CDEF = _mm_packus_epi16(_mm256_castsi256_si128(vacc2x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc2x0123456789AB4567CDEF, 1)); - const __m128i vout3x0123456789AB4567CDEF = _mm_packus_epi16(_mm256_castsi256_si128(vacc3x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc3x0123456789AB4567CDEF, 1)); - const __m128i vout4x0123456789AB4567CDEF = _mm_packus_epi16(_mm256_castsi256_si128(vacc4x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc4x0123456789AB4567CDEF, 1)); - const __m128i vout5x0123456789AB4567CDEF = _mm_packus_epi16(_mm256_castsi256_si128(vacc5x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc5x0123456789AB4567CDEF, 1)); - const __m128i vout6x0123456789AB4567CDEF = _mm_packus_epi16(_mm256_castsi256_si128(vacc6x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc6x0123456789AB4567CDEF, 1)); - - __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi32(vout0x0123456789AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); - __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi32(vout1x0123456789AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); - __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi32(vout2x0123456789AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); - __m128i vout3x0123456789ABCDEF = _mm_shuffle_epi32(vout3x0123456789AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); - __m128i vout4x0123456789ABCDEF = _mm_shuffle_epi32(vout4x0123456789AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); - __m128i vout5x0123456789ABCDEF = _mm_shuffle_epi32(vout5x0123456789AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); - __m128i vout6x0123456789ABCDEF = _mm_shuffle_epi32(vout6x0123456789AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m256i vacc0x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); + __m256i vacc1x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); + __m256i vacc2x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc2x0123456789ABCDEF, 1)); + __m256i vacc3x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc3x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc3x0123456789ABCDEF, 1)); + __m256i vacc4x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc4x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc4x0123456789ABCDEF, 1)); + __m256i vacc5x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc5x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc5x0123456789ABCDEF, 1)); + __m256i vacc6x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc6x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc6x0123456789ABCDEF, 1)); + + vacc0x012389AB4567CDEF = _mm256_adds_epi16(vacc0x012389AB4567CDEF, voutput_zero_point); + vacc1x012389AB4567CDEF = _mm256_adds_epi16(vacc1x012389AB4567CDEF, voutput_zero_point); + vacc2x012389AB4567CDEF = _mm256_adds_epi16(vacc2x012389AB4567CDEF, voutput_zero_point); + vacc3x012389AB4567CDEF = _mm256_adds_epi16(vacc3x012389AB4567CDEF, voutput_zero_point); + vacc4x012389AB4567CDEF = _mm256_adds_epi16(vacc4x012389AB4567CDEF, voutput_zero_point); + vacc5x012389AB4567CDEF = _mm256_adds_epi16(vacc5x012389AB4567CDEF, voutput_zero_point); + vacc6x012389AB4567CDEF = _mm256_adds_epi16(vacc6x012389AB4567CDEF, voutput_zero_point); + + const __m128i vout0x012389AB4567CDEF = _mm_packus_epi16(_mm256_castsi256_si128(vacc0x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x012389AB4567CDEF, 1)); + const __m128i vout1x012389AB4567CDEF = _mm_packus_epi16(_mm256_castsi256_si128(vacc1x012389AB4567CDEF), _mm256_extracti128_si256(vacc1x012389AB4567CDEF, 1)); + const __m128i vout2x012389AB4567CDEF = _mm_packus_epi16(_mm256_castsi256_si128(vacc2x012389AB4567CDEF), _mm256_extracti128_si256(vacc2x012389AB4567CDEF, 1)); + const __m128i vout3x012389AB4567CDEF = _mm_packus_epi16(_mm256_castsi256_si128(vacc3x012389AB4567CDEF), _mm256_extracti128_si256(vacc3x012389AB4567CDEF, 1)); + const __m128i vout4x012389AB4567CDEF = _mm_packus_epi16(_mm256_castsi256_si128(vacc4x012389AB4567CDEF), _mm256_extracti128_si256(vacc4x012389AB4567CDEF, 1)); + const __m128i vout5x012389AB4567CDEF = _mm_packus_epi16(_mm256_castsi256_si128(vacc5x012389AB4567CDEF), _mm256_extracti128_si256(vacc5x012389AB4567CDEF, 1)); + const __m128i vout6x012389AB4567CDEF = _mm_packus_epi16(_mm256_castsi256_si128(vacc6x012389AB4567CDEF), _mm256_extracti128_si256(vacc6x012389AB4567CDEF, 1)); + + __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi32(vout0x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi32(vout1x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi32(vout2x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout3x0123456789ABCDEF = _mm_shuffle_epi32(vout3x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout4x0123456789ABCDEF = _mm_shuffle_epi32(vout4x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout5x0123456789ABCDEF = _mm_shuffle_epi32(vout5x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout6x0123456789ABCDEF = _mm_shuffle_epi32(vout6x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); vout0x0123456789ABCDEF = _mm_max_epu8(vout0x0123456789ABCDEF, voutput_min); vout1x0123456789ABCDEF = _mm_max_epu8(vout1x0123456789ABCDEF, voutput_min); diff --git a/src/amalgam/gen/avx512vnni.c b/src/amalgam/gen/avx512vnni.c index 485ebb62656..56a1206e484 100644 --- a/src/amalgam/gen/avx512vnni.c +++ b/src/amalgam/gen/avx512vnni.c @@ -1376,7 +1376,6 @@ void xnn_qs8_gemm_minmax_fp32_ukernel_1x16c8__avx512vnni_prfm( const __m512 voutput_max_less_zero_point = _mm512_load_ps(params->fp32_avx512vnni.output_max_less_zero_point); const __m256i voutput_zero_point = _mm256_load_si256((const __m256i*) params->fp32_avx512vnni.output_zero_point); const __m128i voutput_min = _mm_load_si128((const __m128i*) params->fp32_avx512vnni.output_min); - const __m128i vshuffle_control_mask = _mm_loadu_si128((const __m128i*) params->fp32_avx512vnni.shuffle_control_mask); do { __m512i vacc0x01234567 = _mm512_cvtepu32_epi64(_mm256_load_si256((const __m256i*) w)); __m512i vacc0x89ABCDEF = _mm512_cvtepu32_epi64(_mm256_load_si256((const __m256i*) ((const int32_t*) w + 8))); @@ -1436,22 +1435,20 @@ void xnn_qs8_gemm_minmax_fp32_ukernel_1x16c8__avx512vnni_prfm( vacc0x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled0x0123456789ABCDEF); - __m256i vacc0x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); + __m256i vacc0x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); - vacc0x0123456789AB4567CDEF = _mm256_adds_epi16(vacc0x0123456789AB4567CDEF, voutput_zero_point); + vacc0x012389AB4567CDEF = _mm256_adds_epi16(vacc0x012389AB4567CDEF, voutput_zero_point); - const __m128i vout0x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc0x0123456789AB4567CDEF, 1)); + const __m128i vout0x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x012389AB4567CDEF, 1)); - __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi8(vout0x0123456789AB4567CDEF, vshuffle_control_mask); + __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi32(vout0x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); vout0x0123456789ABCDEF = _mm_max_epi8(vout0x0123456789ABCDEF, voutput_min); if (nc >= 16) { _mm_storeu_si128((__m128i*) c0, vout0x0123456789ABCDEF); - - a0 = (const int8_t*) ((uintptr_t) a0 - kc); - c0 = (int8_t*) ((uintptr_t) c0 + cn_stride); + a0 = (const int8_t*) ((uintptr_t) a0 - kc); nc -= 16; } else { @@ -1531,7 +1528,6 @@ void xnn_qs8_gemm_minmax_fp32_ukernel_7x16c8__avx512vnni_prfm( const __m512 voutput_max_less_zero_point = _mm512_load_ps(params->fp32_avx512vnni.output_max_less_zero_point); const __m256i voutput_zero_point = _mm256_load_si256((const __m256i*) params->fp32_avx512vnni.output_zero_point); const __m128i voutput_min = _mm_load_si128((const __m128i*) params->fp32_avx512vnni.output_min); - const __m128i vshuffle_control_mask = _mm_loadu_si128((const __m128i*) params->fp32_avx512vnni.shuffle_control_mask); do { __m512i vacc0x01234567 = _mm512_cvtepu32_epi64(_mm256_load_si256((const __m256i*) w)); __m512i vacc0x89ABCDEF = _mm512_cvtepu32_epi64(_mm256_load_si256((const __m256i*) ((const int32_t*) w + 8))); @@ -1711,37 +1707,37 @@ void xnn_qs8_gemm_minmax_fp32_ukernel_7x16c8__avx512vnni_prfm( vacc5x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled5x0123456789ABCDEF); vacc6x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled6x0123456789ABCDEF); - __m256i vacc0x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); - __m256i vacc1x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); - __m256i vacc2x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc2x0123456789ABCDEF, 1)); - __m256i vacc3x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc3x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc3x0123456789ABCDEF, 1)); - __m256i vacc4x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc4x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc4x0123456789ABCDEF, 1)); - __m256i vacc5x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc5x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc5x0123456789ABCDEF, 1)); - __m256i vacc6x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc6x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc6x0123456789ABCDEF, 1)); - - vacc0x0123456789AB4567CDEF = _mm256_adds_epi16(vacc0x0123456789AB4567CDEF, voutput_zero_point); - vacc1x0123456789AB4567CDEF = _mm256_adds_epi16(vacc1x0123456789AB4567CDEF, voutput_zero_point); - vacc2x0123456789AB4567CDEF = _mm256_adds_epi16(vacc2x0123456789AB4567CDEF, voutput_zero_point); - vacc3x0123456789AB4567CDEF = _mm256_adds_epi16(vacc3x0123456789AB4567CDEF, voutput_zero_point); - vacc4x0123456789AB4567CDEF = _mm256_adds_epi16(vacc4x0123456789AB4567CDEF, voutput_zero_point); - vacc5x0123456789AB4567CDEF = _mm256_adds_epi16(vacc5x0123456789AB4567CDEF, voutput_zero_point); - vacc6x0123456789AB4567CDEF = _mm256_adds_epi16(vacc6x0123456789AB4567CDEF, voutput_zero_point); - - const __m128i vout0x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc0x0123456789AB4567CDEF, 1)); - const __m128i vout1x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc1x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc1x0123456789AB4567CDEF, 1)); - const __m128i vout2x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc2x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc2x0123456789AB4567CDEF, 1)); - const __m128i vout3x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc3x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc3x0123456789AB4567CDEF, 1)); - const __m128i vout4x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc4x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc4x0123456789AB4567CDEF, 1)); - const __m128i vout5x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc5x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc5x0123456789AB4567CDEF, 1)); - const __m128i vout6x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc6x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc6x0123456789AB4567CDEF, 1)); - - __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi8(vout0x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi8(vout1x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi8(vout2x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout3x0123456789ABCDEF = _mm_shuffle_epi8(vout3x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout4x0123456789ABCDEF = _mm_shuffle_epi8(vout4x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout5x0123456789ABCDEF = _mm_shuffle_epi8(vout5x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout6x0123456789ABCDEF = _mm_shuffle_epi8(vout6x0123456789AB4567CDEF, vshuffle_control_mask); + __m256i vacc0x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); + __m256i vacc1x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); + __m256i vacc2x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc2x0123456789ABCDEF, 1)); + __m256i vacc3x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc3x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc3x0123456789ABCDEF, 1)); + __m256i vacc4x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc4x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc4x0123456789ABCDEF, 1)); + __m256i vacc5x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc5x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc5x0123456789ABCDEF, 1)); + __m256i vacc6x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc6x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc6x0123456789ABCDEF, 1)); + + vacc0x012389AB4567CDEF = _mm256_adds_epi16(vacc0x012389AB4567CDEF, voutput_zero_point); + vacc1x012389AB4567CDEF = _mm256_adds_epi16(vacc1x012389AB4567CDEF, voutput_zero_point); + vacc2x012389AB4567CDEF = _mm256_adds_epi16(vacc2x012389AB4567CDEF, voutput_zero_point); + vacc3x012389AB4567CDEF = _mm256_adds_epi16(vacc3x012389AB4567CDEF, voutput_zero_point); + vacc4x012389AB4567CDEF = _mm256_adds_epi16(vacc4x012389AB4567CDEF, voutput_zero_point); + vacc5x012389AB4567CDEF = _mm256_adds_epi16(vacc5x012389AB4567CDEF, voutput_zero_point); + vacc6x012389AB4567CDEF = _mm256_adds_epi16(vacc6x012389AB4567CDEF, voutput_zero_point); + + const __m128i vout0x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x012389AB4567CDEF, 1)); + const __m128i vout1x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc1x012389AB4567CDEF), _mm256_extracti128_si256(vacc1x012389AB4567CDEF, 1)); + const __m128i vout2x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc2x012389AB4567CDEF), _mm256_extracti128_si256(vacc2x012389AB4567CDEF, 1)); + const __m128i vout3x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc3x012389AB4567CDEF), _mm256_extracti128_si256(vacc3x012389AB4567CDEF, 1)); + const __m128i vout4x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc4x012389AB4567CDEF), _mm256_extracti128_si256(vacc4x012389AB4567CDEF, 1)); + const __m128i vout5x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc5x012389AB4567CDEF), _mm256_extracti128_si256(vacc5x012389AB4567CDEF, 1)); + const __m128i vout6x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc6x012389AB4567CDEF), _mm256_extracti128_si256(vacc6x012389AB4567CDEF, 1)); + + __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi32(vout0x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi32(vout1x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi32(vout2x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout3x0123456789ABCDEF = _mm_shuffle_epi32(vout3x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout4x0123456789ABCDEF = _mm_shuffle_epi32(vout4x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout5x0123456789ABCDEF = _mm_shuffle_epi32(vout5x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout6x0123456789ABCDEF = _mm_shuffle_epi32(vout6x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); vout0x0123456789ABCDEF = _mm_max_epi8(vout0x0123456789ABCDEF, voutput_min); vout1x0123456789ABCDEF = _mm_max_epi8(vout1x0123456789ABCDEF, voutput_min); @@ -1753,28 +1749,26 @@ void xnn_qs8_gemm_minmax_fp32_ukernel_7x16c8__avx512vnni_prfm( if (nc >= 16) { _mm_storeu_si128((__m128i*) c0, vout0x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c1, vout1x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c2, vout2x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c3, vout3x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c4, vout4x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c5, vout5x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c6, vout6x0123456789ABCDEF); - - a0 = (const int8_t*) ((uintptr_t) a0 - kc); - a1 = (const int8_t*) ((uintptr_t) a1 - kc); - a2 = (const int8_t*) ((uintptr_t) a2 - kc); - a3 = (const int8_t*) ((uintptr_t) a3 - kc); - a4 = (const int8_t*) ((uintptr_t) a4 - kc); - a5 = (const int8_t*) ((uintptr_t) a5 - kc); - a6 = (const int8_t*) ((uintptr_t) a6 - kc); - c0 = (int8_t*) ((uintptr_t) c0 + cn_stride); + a0 = (const int8_t*) ((uintptr_t) a0 - kc); + _mm_storeu_si128((__m128i*) c1, vout1x0123456789ABCDEF); c1 = (int8_t*) ((uintptr_t) c1 + cn_stride); + a1 = (const int8_t*) ((uintptr_t) a1 - kc); + _mm_storeu_si128((__m128i*) c2, vout2x0123456789ABCDEF); c2 = (int8_t*) ((uintptr_t) c2 + cn_stride); + a2 = (const int8_t*) ((uintptr_t) a2 - kc); + _mm_storeu_si128((__m128i*) c3, vout3x0123456789ABCDEF); c3 = (int8_t*) ((uintptr_t) c3 + cn_stride); + a3 = (const int8_t*) ((uintptr_t) a3 - kc); + _mm_storeu_si128((__m128i*) c4, vout4x0123456789ABCDEF); c4 = (int8_t*) ((uintptr_t) c4 + cn_stride); + a4 = (const int8_t*) ((uintptr_t) a4 - kc); + _mm_storeu_si128((__m128i*) c5, vout5x0123456789ABCDEF); c5 = (int8_t*) ((uintptr_t) c5 + cn_stride); + a5 = (const int8_t*) ((uintptr_t) a5 - kc); + _mm_storeu_si128((__m128i*) c6, vout6x0123456789ABCDEF); c6 = (int8_t*) ((uintptr_t) c6 + cn_stride); + a6 = (const int8_t*) ((uintptr_t) a6 - kc); nc -= 16; } else { @@ -1825,7 +1819,6 @@ void xnn_qs8_igemm_minmax_fp32_ukernel_1x16c8__avx512vnni_prfm( const __m512 voutput_max_less_zero_point = _mm512_load_ps(params->fp32_avx512vnni.output_max_less_zero_point); const __m256i voutput_zero_point = _mm256_load_si256((const __m256i*) params->fp32_avx512vnni.output_zero_point); const __m128i voutput_min = _mm_load_si128((const __m128i*) params->fp32_avx512vnni.output_min); - const __m128i vshuffle_control_mask = _mm_loadu_si128((const __m128i*) params->fp32_avx512vnni.shuffle_control_mask); do { __m512i vacc0x01234567 = _mm512_cvtepu32_epi64(_mm256_load_si256((const __m256i*) w)); __m512i vacc0x89ABCDEF = _mm512_cvtepu32_epi64(_mm256_load_si256((const __m256i*) ((const int32_t*) w + 8))); @@ -1896,13 +1889,13 @@ void xnn_qs8_igemm_minmax_fp32_ukernel_1x16c8__avx512vnni_prfm( vacc0x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled0x0123456789ABCDEF); - __m256i vacc0x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); + __m256i vacc0x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); - vacc0x0123456789AB4567CDEF = _mm256_adds_epi16(vacc0x0123456789AB4567CDEF, voutput_zero_point); + vacc0x012389AB4567CDEF = _mm256_adds_epi16(vacc0x012389AB4567CDEF, voutput_zero_point); - const __m128i vout0x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc0x0123456789AB4567CDEF, 1)); + const __m128i vout0x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x012389AB4567CDEF, 1)); - __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi8(vout0x0123456789AB4567CDEF, vshuffle_control_mask); + __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi32(vout0x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); vout0x0123456789ABCDEF = _mm_max_epi8(vout0x0123456789ABCDEF, voutput_min); @@ -1975,7 +1968,6 @@ void xnn_qs8_igemm_minmax_fp32_ukernel_7x16c8__avx512vnni_prfm( const __m512 voutput_max_less_zero_point = _mm512_load_ps(params->fp32_avx512vnni.output_max_less_zero_point); const __m256i voutput_zero_point = _mm256_load_si256((const __m256i*) params->fp32_avx512vnni.output_zero_point); const __m128i voutput_min = _mm_load_si128((const __m128i*) params->fp32_avx512vnni.output_min); - const __m128i vshuffle_control_mask = _mm_loadu_si128((const __m128i*) params->fp32_avx512vnni.shuffle_control_mask); do { __m512i vacc0x01234567 = _mm512_cvtepu32_epi64(_mm256_load_si256((const __m256i*) w)); __m512i vacc0x89ABCDEF = _mm512_cvtepu32_epi64(_mm256_load_si256((const __m256i*) ((const int32_t*) w + 8))); @@ -2190,37 +2182,37 @@ void xnn_qs8_igemm_minmax_fp32_ukernel_7x16c8__avx512vnni_prfm( vacc5x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled5x0123456789ABCDEF); vacc6x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled6x0123456789ABCDEF); - __m256i vacc0x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); - __m256i vacc1x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); - __m256i vacc2x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc2x0123456789ABCDEF, 1)); - __m256i vacc3x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc3x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc3x0123456789ABCDEF, 1)); - __m256i vacc4x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc4x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc4x0123456789ABCDEF, 1)); - __m256i vacc5x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc5x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc5x0123456789ABCDEF, 1)); - __m256i vacc6x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc6x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc6x0123456789ABCDEF, 1)); - - vacc0x0123456789AB4567CDEF = _mm256_adds_epi16(vacc0x0123456789AB4567CDEF, voutput_zero_point); - vacc1x0123456789AB4567CDEF = _mm256_adds_epi16(vacc1x0123456789AB4567CDEF, voutput_zero_point); - vacc2x0123456789AB4567CDEF = _mm256_adds_epi16(vacc2x0123456789AB4567CDEF, voutput_zero_point); - vacc3x0123456789AB4567CDEF = _mm256_adds_epi16(vacc3x0123456789AB4567CDEF, voutput_zero_point); - vacc4x0123456789AB4567CDEF = _mm256_adds_epi16(vacc4x0123456789AB4567CDEF, voutput_zero_point); - vacc5x0123456789AB4567CDEF = _mm256_adds_epi16(vacc5x0123456789AB4567CDEF, voutput_zero_point); - vacc6x0123456789AB4567CDEF = _mm256_adds_epi16(vacc6x0123456789AB4567CDEF, voutput_zero_point); - - const __m128i vout0x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc0x0123456789AB4567CDEF, 1)); - const __m128i vout1x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc1x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc1x0123456789AB4567CDEF, 1)); - const __m128i vout2x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc2x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc2x0123456789AB4567CDEF, 1)); - const __m128i vout3x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc3x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc3x0123456789AB4567CDEF, 1)); - const __m128i vout4x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc4x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc4x0123456789AB4567CDEF, 1)); - const __m128i vout5x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc5x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc5x0123456789AB4567CDEF, 1)); - const __m128i vout6x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc6x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc6x0123456789AB4567CDEF, 1)); - - __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi8(vout0x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi8(vout1x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi8(vout2x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout3x0123456789ABCDEF = _mm_shuffle_epi8(vout3x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout4x0123456789ABCDEF = _mm_shuffle_epi8(vout4x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout5x0123456789ABCDEF = _mm_shuffle_epi8(vout5x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout6x0123456789ABCDEF = _mm_shuffle_epi8(vout6x0123456789AB4567CDEF, vshuffle_control_mask); + __m256i vacc0x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); + __m256i vacc1x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); + __m256i vacc2x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc2x0123456789ABCDEF, 1)); + __m256i vacc3x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc3x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc3x0123456789ABCDEF, 1)); + __m256i vacc4x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc4x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc4x0123456789ABCDEF, 1)); + __m256i vacc5x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc5x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc5x0123456789ABCDEF, 1)); + __m256i vacc6x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc6x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc6x0123456789ABCDEF, 1)); + + vacc0x012389AB4567CDEF = _mm256_adds_epi16(vacc0x012389AB4567CDEF, voutput_zero_point); + vacc1x012389AB4567CDEF = _mm256_adds_epi16(vacc1x012389AB4567CDEF, voutput_zero_point); + vacc2x012389AB4567CDEF = _mm256_adds_epi16(vacc2x012389AB4567CDEF, voutput_zero_point); + vacc3x012389AB4567CDEF = _mm256_adds_epi16(vacc3x012389AB4567CDEF, voutput_zero_point); + vacc4x012389AB4567CDEF = _mm256_adds_epi16(vacc4x012389AB4567CDEF, voutput_zero_point); + vacc5x012389AB4567CDEF = _mm256_adds_epi16(vacc5x012389AB4567CDEF, voutput_zero_point); + vacc6x012389AB4567CDEF = _mm256_adds_epi16(vacc6x012389AB4567CDEF, voutput_zero_point); + + const __m128i vout0x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x012389AB4567CDEF, 1)); + const __m128i vout1x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc1x012389AB4567CDEF), _mm256_extracti128_si256(vacc1x012389AB4567CDEF, 1)); + const __m128i vout2x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc2x012389AB4567CDEF), _mm256_extracti128_si256(vacc2x012389AB4567CDEF, 1)); + const __m128i vout3x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc3x012389AB4567CDEF), _mm256_extracti128_si256(vacc3x012389AB4567CDEF, 1)); + const __m128i vout4x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc4x012389AB4567CDEF), _mm256_extracti128_si256(vacc4x012389AB4567CDEF, 1)); + const __m128i vout5x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc5x012389AB4567CDEF), _mm256_extracti128_si256(vacc5x012389AB4567CDEF, 1)); + const __m128i vout6x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc6x012389AB4567CDEF), _mm256_extracti128_si256(vacc6x012389AB4567CDEF, 1)); + + __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi32(vout0x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi32(vout1x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi32(vout2x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout3x0123456789ABCDEF = _mm_shuffle_epi32(vout3x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout4x0123456789ABCDEF = _mm_shuffle_epi32(vout4x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout5x0123456789ABCDEF = _mm_shuffle_epi32(vout5x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout6x0123456789ABCDEF = _mm_shuffle_epi32(vout6x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); vout0x0123456789ABCDEF = _mm_max_epi8(vout0x0123456789ABCDEF, voutput_min); vout1x0123456789ABCDEF = _mm_max_epi8(vout1x0123456789ABCDEF, voutput_min); @@ -2291,7 +2283,6 @@ void xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_1x16c8__avx512vnni_prfm( const __m512 voutput_max_less_zero_point = _mm512_load_ps(params->fp32_avx512vnni.output_max_less_zero_point); const __m256i voutput_zero_point = _mm256_load_si256((const __m256i*) params->fp32_avx512vnni.output_zero_point); const __m128i voutput_min = _mm_load_si128((const __m128i*) params->fp32_avx512vnni.output_min); - const __m128i vshuffle_control_mask = _mm_loadu_si128((const __m128i*) params->fp32_avx512vnni.shuffle_control_mask); do { __m512i vacc0x01234567 = _mm512_cvtepu32_epi64(_mm256_load_si256((const __m256i*) w)); __m512i vacc0x89ABCDEF = _mm512_cvtepu32_epi64(_mm256_load_si256((const __m256i*) ((const int32_t*) w + 8))); @@ -2353,22 +2344,20 @@ void xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_1x16c8__avx512vnni_prfm( vacc0x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled0x0123456789ABCDEF); - __m256i vacc0x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); + __m256i vacc0x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); - vacc0x0123456789AB4567CDEF = _mm256_adds_epi16(vacc0x0123456789AB4567CDEF, voutput_zero_point); + vacc0x012389AB4567CDEF = _mm256_adds_epi16(vacc0x012389AB4567CDEF, voutput_zero_point); - const __m128i vout0x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc0x0123456789AB4567CDEF, 1)); + const __m128i vout0x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x012389AB4567CDEF, 1)); - __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi8(vout0x0123456789AB4567CDEF, vshuffle_control_mask); + __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi32(vout0x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); vout0x0123456789ABCDEF = _mm_max_epi8(vout0x0123456789ABCDEF, voutput_min); if (nc >= 16) { _mm_storeu_si128((__m128i*) c0, vout0x0123456789ABCDEF); - - a0 = (const int8_t*) ((uintptr_t) a0 - kc); - c0 = (int8_t*) ((uintptr_t) c0 + cn_stride); + a0 = (const int8_t*) ((uintptr_t) a0 - kc); nc -= 16; } else { @@ -2447,7 +2436,6 @@ void xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_7x16c8__avx512vnni_prfm( const __m512 voutput_max_less_zero_point = _mm512_load_ps(params->fp32_avx512vnni.output_max_less_zero_point); const __m256i voutput_zero_point = _mm256_load_si256((const __m256i*) params->fp32_avx512vnni.output_zero_point); const __m128i voutput_min = _mm_load_si128((const __m128i*) params->fp32_avx512vnni.output_min); - const __m128i vshuffle_control_mask = _mm_loadu_si128((const __m128i*) params->fp32_avx512vnni.shuffle_control_mask); do { __m512i vacc0x01234567 = _mm512_cvtepu32_epi64(_mm256_load_si256((const __m256i*) w)); __m512i vacc0x89ABCDEF = _mm512_cvtepu32_epi64(_mm256_load_si256((const __m256i*) ((const int32_t*) w + 8))); @@ -2629,37 +2617,37 @@ void xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_7x16c8__avx512vnni_prfm( vacc5x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled5x0123456789ABCDEF); vacc6x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled6x0123456789ABCDEF); - __m256i vacc0x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); - __m256i vacc1x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); - __m256i vacc2x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc2x0123456789ABCDEF, 1)); - __m256i vacc3x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc3x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc3x0123456789ABCDEF, 1)); - __m256i vacc4x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc4x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc4x0123456789ABCDEF, 1)); - __m256i vacc5x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc5x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc5x0123456789ABCDEF, 1)); - __m256i vacc6x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc6x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc6x0123456789ABCDEF, 1)); - - vacc0x0123456789AB4567CDEF = _mm256_adds_epi16(vacc0x0123456789AB4567CDEF, voutput_zero_point); - vacc1x0123456789AB4567CDEF = _mm256_adds_epi16(vacc1x0123456789AB4567CDEF, voutput_zero_point); - vacc2x0123456789AB4567CDEF = _mm256_adds_epi16(vacc2x0123456789AB4567CDEF, voutput_zero_point); - vacc3x0123456789AB4567CDEF = _mm256_adds_epi16(vacc3x0123456789AB4567CDEF, voutput_zero_point); - vacc4x0123456789AB4567CDEF = _mm256_adds_epi16(vacc4x0123456789AB4567CDEF, voutput_zero_point); - vacc5x0123456789AB4567CDEF = _mm256_adds_epi16(vacc5x0123456789AB4567CDEF, voutput_zero_point); - vacc6x0123456789AB4567CDEF = _mm256_adds_epi16(vacc6x0123456789AB4567CDEF, voutput_zero_point); - - const __m128i vout0x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc0x0123456789AB4567CDEF, 1)); - const __m128i vout1x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc1x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc1x0123456789AB4567CDEF, 1)); - const __m128i vout2x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc2x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc2x0123456789AB4567CDEF, 1)); - const __m128i vout3x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc3x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc3x0123456789AB4567CDEF, 1)); - const __m128i vout4x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc4x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc4x0123456789AB4567CDEF, 1)); - const __m128i vout5x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc5x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc5x0123456789AB4567CDEF, 1)); - const __m128i vout6x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc6x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc6x0123456789AB4567CDEF, 1)); - - __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi8(vout0x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi8(vout1x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi8(vout2x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout3x0123456789ABCDEF = _mm_shuffle_epi8(vout3x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout4x0123456789ABCDEF = _mm_shuffle_epi8(vout4x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout5x0123456789ABCDEF = _mm_shuffle_epi8(vout5x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout6x0123456789ABCDEF = _mm_shuffle_epi8(vout6x0123456789AB4567CDEF, vshuffle_control_mask); + __m256i vacc0x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); + __m256i vacc1x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); + __m256i vacc2x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc2x0123456789ABCDEF, 1)); + __m256i vacc3x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc3x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc3x0123456789ABCDEF, 1)); + __m256i vacc4x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc4x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc4x0123456789ABCDEF, 1)); + __m256i vacc5x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc5x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc5x0123456789ABCDEF, 1)); + __m256i vacc6x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc6x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc6x0123456789ABCDEF, 1)); + + vacc0x012389AB4567CDEF = _mm256_adds_epi16(vacc0x012389AB4567CDEF, voutput_zero_point); + vacc1x012389AB4567CDEF = _mm256_adds_epi16(vacc1x012389AB4567CDEF, voutput_zero_point); + vacc2x012389AB4567CDEF = _mm256_adds_epi16(vacc2x012389AB4567CDEF, voutput_zero_point); + vacc3x012389AB4567CDEF = _mm256_adds_epi16(vacc3x012389AB4567CDEF, voutput_zero_point); + vacc4x012389AB4567CDEF = _mm256_adds_epi16(vacc4x012389AB4567CDEF, voutput_zero_point); + vacc5x012389AB4567CDEF = _mm256_adds_epi16(vacc5x012389AB4567CDEF, voutput_zero_point); + vacc6x012389AB4567CDEF = _mm256_adds_epi16(vacc6x012389AB4567CDEF, voutput_zero_point); + + const __m128i vout0x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x012389AB4567CDEF, 1)); + const __m128i vout1x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc1x012389AB4567CDEF), _mm256_extracti128_si256(vacc1x012389AB4567CDEF, 1)); + const __m128i vout2x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc2x012389AB4567CDEF), _mm256_extracti128_si256(vacc2x012389AB4567CDEF, 1)); + const __m128i vout3x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc3x012389AB4567CDEF), _mm256_extracti128_si256(vacc3x012389AB4567CDEF, 1)); + const __m128i vout4x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc4x012389AB4567CDEF), _mm256_extracti128_si256(vacc4x012389AB4567CDEF, 1)); + const __m128i vout5x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc5x012389AB4567CDEF), _mm256_extracti128_si256(vacc5x012389AB4567CDEF, 1)); + const __m128i vout6x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc6x012389AB4567CDEF), _mm256_extracti128_si256(vacc6x012389AB4567CDEF, 1)); + + __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi32(vout0x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi32(vout1x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi32(vout2x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout3x0123456789ABCDEF = _mm_shuffle_epi32(vout3x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout4x0123456789ABCDEF = _mm_shuffle_epi32(vout4x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout5x0123456789ABCDEF = _mm_shuffle_epi32(vout5x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout6x0123456789ABCDEF = _mm_shuffle_epi32(vout6x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); vout0x0123456789ABCDEF = _mm_max_epi8(vout0x0123456789ABCDEF, voutput_min); vout1x0123456789ABCDEF = _mm_max_epi8(vout1x0123456789ABCDEF, voutput_min); @@ -2671,28 +2659,26 @@ void xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_7x16c8__avx512vnni_prfm( if (nc >= 16) { _mm_storeu_si128((__m128i*) c0, vout0x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c1, vout1x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c2, vout2x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c3, vout3x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c4, vout4x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c5, vout5x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c6, vout6x0123456789ABCDEF); - - a0 = (const int8_t*) ((uintptr_t) a0 - kc); - a1 = (const int8_t*) ((uintptr_t) a1 - kc); - a2 = (const int8_t*) ((uintptr_t) a2 - kc); - a3 = (const int8_t*) ((uintptr_t) a3 - kc); - a4 = (const int8_t*) ((uintptr_t) a4 - kc); - a5 = (const int8_t*) ((uintptr_t) a5 - kc); - a6 = (const int8_t*) ((uintptr_t) a6 - kc); - c0 = (int8_t*) ((uintptr_t) c0 + cn_stride); + a0 = (const int8_t*) ((uintptr_t) a0 - kc); + _mm_storeu_si128((__m128i*) c1, vout1x0123456789ABCDEF); c1 = (int8_t*) ((uintptr_t) c1 + cn_stride); + a1 = (const int8_t*) ((uintptr_t) a1 - kc); + _mm_storeu_si128((__m128i*) c2, vout2x0123456789ABCDEF); c2 = (int8_t*) ((uintptr_t) c2 + cn_stride); + a2 = (const int8_t*) ((uintptr_t) a2 - kc); + _mm_storeu_si128((__m128i*) c3, vout3x0123456789ABCDEF); c3 = (int8_t*) ((uintptr_t) c3 + cn_stride); + a3 = (const int8_t*) ((uintptr_t) a3 - kc); + _mm_storeu_si128((__m128i*) c4, vout4x0123456789ABCDEF); c4 = (int8_t*) ((uintptr_t) c4 + cn_stride); + a4 = (const int8_t*) ((uintptr_t) a4 - kc); + _mm_storeu_si128((__m128i*) c5, vout5x0123456789ABCDEF); c5 = (int8_t*) ((uintptr_t) c5 + cn_stride); + a5 = (const int8_t*) ((uintptr_t) a5 - kc); + _mm_storeu_si128((__m128i*) c6, vout6x0123456789ABCDEF); c6 = (int8_t*) ((uintptr_t) c6 + cn_stride); + a6 = (const int8_t*) ((uintptr_t) a6 - kc); nc -= 16; } else { @@ -2742,7 +2728,6 @@ void xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_1x16c8__avx512vnni_prfm( const __m512 voutput_max_less_zero_point = _mm512_load_ps(params->fp32_avx512vnni.output_max_less_zero_point); const __m256i voutput_zero_point = _mm256_load_si256((const __m256i*) params->fp32_avx512vnni.output_zero_point); const __m128i voutput_min = _mm_load_si128((const __m128i*) params->fp32_avx512vnni.output_min); - const __m128i vshuffle_control_mask = _mm_loadu_si128((const __m128i*) params->fp32_avx512vnni.shuffle_control_mask); do { __m512i vacc0x01234567 = _mm512_cvtepu32_epi64(_mm256_load_si256((const __m256i*) w)); __m512i vacc0x89ABCDEF = _mm512_cvtepu32_epi64(_mm256_load_si256((const __m256i*) ((const int32_t*) w + 8))); @@ -2815,13 +2800,13 @@ void xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_1x16c8__avx512vnni_prfm( vacc0x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled0x0123456789ABCDEF); - __m256i vacc0x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); + __m256i vacc0x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); - vacc0x0123456789AB4567CDEF = _mm256_adds_epi16(vacc0x0123456789AB4567CDEF, voutput_zero_point); + vacc0x012389AB4567CDEF = _mm256_adds_epi16(vacc0x012389AB4567CDEF, voutput_zero_point); - const __m128i vout0x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc0x0123456789AB4567CDEF, 1)); + const __m128i vout0x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x012389AB4567CDEF, 1)); - __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi8(vout0x0123456789AB4567CDEF, vshuffle_control_mask); + __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi32(vout0x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); vout0x0123456789ABCDEF = _mm_max_epi8(vout0x0123456789ABCDEF, voutput_min); @@ -2893,7 +2878,6 @@ void xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_7x16c8__avx512vnni_prfm( const __m512 voutput_max_less_zero_point = _mm512_load_ps(params->fp32_avx512vnni.output_max_less_zero_point); const __m256i voutput_zero_point = _mm256_load_si256((const __m256i*) params->fp32_avx512vnni.output_zero_point); const __m128i voutput_min = _mm_load_si128((const __m128i*) params->fp32_avx512vnni.output_min); - const __m128i vshuffle_control_mask = _mm_loadu_si128((const __m128i*) params->fp32_avx512vnni.shuffle_control_mask); do { __m512i vacc0x01234567 = _mm512_cvtepu32_epi64(_mm256_load_si256((const __m256i*) w)); __m512i vacc0x89ABCDEF = _mm512_cvtepu32_epi64(_mm256_load_si256((const __m256i*) ((const int32_t*) w + 8))); @@ -3110,37 +3094,37 @@ void xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_7x16c8__avx512vnni_prfm( vacc5x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled5x0123456789ABCDEF); vacc6x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled6x0123456789ABCDEF); - __m256i vacc0x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); - __m256i vacc1x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); - __m256i vacc2x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc2x0123456789ABCDEF, 1)); - __m256i vacc3x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc3x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc3x0123456789ABCDEF, 1)); - __m256i vacc4x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc4x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc4x0123456789ABCDEF, 1)); - __m256i vacc5x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc5x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc5x0123456789ABCDEF, 1)); - __m256i vacc6x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc6x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc6x0123456789ABCDEF, 1)); - - vacc0x0123456789AB4567CDEF = _mm256_adds_epi16(vacc0x0123456789AB4567CDEF, voutput_zero_point); - vacc1x0123456789AB4567CDEF = _mm256_adds_epi16(vacc1x0123456789AB4567CDEF, voutput_zero_point); - vacc2x0123456789AB4567CDEF = _mm256_adds_epi16(vacc2x0123456789AB4567CDEF, voutput_zero_point); - vacc3x0123456789AB4567CDEF = _mm256_adds_epi16(vacc3x0123456789AB4567CDEF, voutput_zero_point); - vacc4x0123456789AB4567CDEF = _mm256_adds_epi16(vacc4x0123456789AB4567CDEF, voutput_zero_point); - vacc5x0123456789AB4567CDEF = _mm256_adds_epi16(vacc5x0123456789AB4567CDEF, voutput_zero_point); - vacc6x0123456789AB4567CDEF = _mm256_adds_epi16(vacc6x0123456789AB4567CDEF, voutput_zero_point); - - const __m128i vout0x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc0x0123456789AB4567CDEF, 1)); - const __m128i vout1x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc1x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc1x0123456789AB4567CDEF, 1)); - const __m128i vout2x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc2x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc2x0123456789AB4567CDEF, 1)); - const __m128i vout3x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc3x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc3x0123456789AB4567CDEF, 1)); - const __m128i vout4x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc4x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc4x0123456789AB4567CDEF, 1)); - const __m128i vout5x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc5x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc5x0123456789AB4567CDEF, 1)); - const __m128i vout6x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc6x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc6x0123456789AB4567CDEF, 1)); - - __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi8(vout0x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi8(vout1x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi8(vout2x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout3x0123456789ABCDEF = _mm_shuffle_epi8(vout3x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout4x0123456789ABCDEF = _mm_shuffle_epi8(vout4x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout5x0123456789ABCDEF = _mm_shuffle_epi8(vout5x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout6x0123456789ABCDEF = _mm_shuffle_epi8(vout6x0123456789AB4567CDEF, vshuffle_control_mask); + __m256i vacc0x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); + __m256i vacc1x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); + __m256i vacc2x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc2x0123456789ABCDEF, 1)); + __m256i vacc3x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc3x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc3x0123456789ABCDEF, 1)); + __m256i vacc4x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc4x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc4x0123456789ABCDEF, 1)); + __m256i vacc5x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc5x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc5x0123456789ABCDEF, 1)); + __m256i vacc6x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc6x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc6x0123456789ABCDEF, 1)); + + vacc0x012389AB4567CDEF = _mm256_adds_epi16(vacc0x012389AB4567CDEF, voutput_zero_point); + vacc1x012389AB4567CDEF = _mm256_adds_epi16(vacc1x012389AB4567CDEF, voutput_zero_point); + vacc2x012389AB4567CDEF = _mm256_adds_epi16(vacc2x012389AB4567CDEF, voutput_zero_point); + vacc3x012389AB4567CDEF = _mm256_adds_epi16(vacc3x012389AB4567CDEF, voutput_zero_point); + vacc4x012389AB4567CDEF = _mm256_adds_epi16(vacc4x012389AB4567CDEF, voutput_zero_point); + vacc5x012389AB4567CDEF = _mm256_adds_epi16(vacc5x012389AB4567CDEF, voutput_zero_point); + vacc6x012389AB4567CDEF = _mm256_adds_epi16(vacc6x012389AB4567CDEF, voutput_zero_point); + + const __m128i vout0x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x012389AB4567CDEF, 1)); + const __m128i vout1x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc1x012389AB4567CDEF), _mm256_extracti128_si256(vacc1x012389AB4567CDEF, 1)); + const __m128i vout2x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc2x012389AB4567CDEF), _mm256_extracti128_si256(vacc2x012389AB4567CDEF, 1)); + const __m128i vout3x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc3x012389AB4567CDEF), _mm256_extracti128_si256(vacc3x012389AB4567CDEF, 1)); + const __m128i vout4x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc4x012389AB4567CDEF), _mm256_extracti128_si256(vacc4x012389AB4567CDEF, 1)); + const __m128i vout5x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc5x012389AB4567CDEF), _mm256_extracti128_si256(vacc5x012389AB4567CDEF, 1)); + const __m128i vout6x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc6x012389AB4567CDEF), _mm256_extracti128_si256(vacc6x012389AB4567CDEF, 1)); + + __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi32(vout0x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi32(vout1x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi32(vout2x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout3x0123456789ABCDEF = _mm_shuffle_epi32(vout3x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout4x0123456789ABCDEF = _mm_shuffle_epi32(vout4x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout5x0123456789ABCDEF = _mm_shuffle_epi32(vout5x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout6x0123456789ABCDEF = _mm_shuffle_epi32(vout6x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); vout0x0123456789ABCDEF = _mm_max_epi8(vout0x0123456789ABCDEF, voutput_min); vout1x0123456789ABCDEF = _mm_max_epi8(vout1x0123456789ABCDEF, voutput_min); diff --git a/src/microparams-init.c b/src/microparams-init.c index 609ddb0168d..72600870bdf 100644 --- a/src/microparams-init.c +++ b/src/microparams-init.c @@ -395,10 +395,6 @@ size_t xnn_init_qs8_qc8w_conv_minmax_fp32_avx512vnni_params( for (uint32_t i = 0; i < 16; i++) { params->fp32_avx512vnni.output_zero_point[i] = output_zero_point; } - const int8_t control_mask[16] = {0, 1, 2, 3, 8, 9, 10, 11, 4, 5, 6, 7, 12, 13, 14, 15}; - for (uint32_t i = 0; i < 16; i++) { - params->fp32_avx512vnni.shuffle_control_mask[i] = control_mask[i]; - } for (uint32_t i = 0; i < 16; i++) { params->fp32_avx512vnni.output_min[i] = output_min; } @@ -425,10 +421,6 @@ size_t xnn_init_qs8_conv_minmax_fp32_avx512vnni_params( for (uint32_t i = 0; i < 16; i++) { params->fp32_avx512vnni.output_zero_point[i] = output_zero_point; } - const int8_t control_mask[16] = {0, 1, 2, 3, 8, 9, 10, 11, 4, 5, 6, 7, 12, 13, 14, 15}; - for (uint32_t i = 0; i < 16; i++) { - params->fp32_avx512vnni.shuffle_control_mask[i] = control_mask[i]; - } for (uint32_t i = 0; i < 16; i++) { params->fp32_avx512vnni.output_min[i] = output_min; } diff --git a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-1x16c4-minmax-avx512vnni-prfm.c b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-1x16c4-minmax-avx512vnni-prfm.c index a36f723d6ee..deb3eaa1478 100644 --- a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-1x16c4-minmax-avx512vnni-prfm.c +++ b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-1x16c4-minmax-avx512vnni-prfm.c @@ -48,7 +48,7 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_1x16c4__avx512vnni_prfm( const __m512 voutput_min = _mm512_set1_ps(params->avx512vnni.min); const __m512 voutput_max = _mm512_set1_ps(params->avx512vnni.max); const __m512i vsign_mask = _mm512_set1_epi8(params->avx512vnni.sign_mask); // 0x80 - const __m512i vvalue_mask = _mm512_set1_epi8(params->avx512vnni.mask); // 0xF0 + const __m512i vmask = _mm512_set1_epi8(params->avx512vnni.mask); // 0xF0 do { const __m512i vksum0123456789ABCDEF = _mm512_load_epi32(w); __m512i vacc0x0x0123456789ABCDEF = _mm512_mullo_epi32(vksum0123456789ABCDEF, vinput_zero_point0); @@ -63,8 +63,8 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_1x16c4__avx512vnni_prfm( const __m512i vbb0123456789ABCDEFx01234567 = _mm512_load_si512(w); const __m512i vbs0123456789ABCDEFx0123 = _mm512_slli_epi32(vbb0123456789ABCDEFx01234567, 4); - const __m512i vb0123456789ABCDEFx4567 = _mm512_and_si512(vbb0123456789ABCDEFx01234567, vvalue_mask); - const __m512i vb0123456789ABCDEFx0123 = _mm512_and_si512(vbs0123456789ABCDEFx0123, vvalue_mask); + const __m512i vb0123456789ABCDEFx4567 = _mm512_and_si512(vbb0123456789ABCDEFx01234567, vmask); + const __m512i vb0123456789ABCDEFx0123 = _mm512_and_si512(vbs0123456789ABCDEFx0123, vmask); vacc0x0x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x0x0123456789ABCDEF, va0x0123, vb0123456789ABCDEFx0123); xnn_prefetch_to_l1((const int8_t*) w + 960); @@ -73,7 +73,7 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_1x16c4__avx512vnni_prfm( w = (const int8_t*) w + 64; k -= 8 * sizeof(int8_t); } - vacc0x0x0123456789ABCDEF = _mm512_add_epi32(vacc0x0x0123456789ABCDEF, vacc1x0x0123456789ABCDEF); + __m512i vacc0x0123456789ABCDEF = _mm512_add_epi32(vacc0x0x0123456789ABCDEF, vacc1x0x0123456789ABCDEF); if (k != 0) { const __m512i va0x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a0)), vsign_mask); @@ -81,17 +81,17 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_1x16c4__avx512vnni_prfm( const __m512i vbb0123456789ABCDEF = _mm512_load_si512(w); const __m512i vbs0123456789ABCDEF = _mm512_slli_epi32(vbb0123456789ABCDEF, 4); - const __m512i vb0123456789ABCDEF = _mm512_and_si512(vbs0123456789ABCDEF, vvalue_mask); + const __m512i vb0123456789ABCDEF = _mm512_and_si512(vbs0123456789ABCDEF, vmask); xnn_prefetch_to_l1((const int8_t*) w + 960); - vacc0x0x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x0x0123456789ABCDEF, va0x0123, vb0123456789ABCDEF); + vacc0x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x0123456789ABCDEF, va0x0123, vb0123456789ABCDEF); w = (const int8_t*) w + 64; k -= 4 * sizeof(int8_t); } - vacc0x0x0123456789ABCDEF = _mm512_srai_epi32(vacc0x0x0123456789ABCDEF, 4); - __m512 vscaled0x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x0x0123456789ABCDEF); + vacc0x0123456789ABCDEF = _mm512_srai_epi32(vacc0x0123456789ABCDEF, 4); + __m512 vscaled0x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x0123456789ABCDEF); vscaled0x0123456789ABCDEF = _mm512_mul_ps(vscaled0x0123456789ABCDEF, _mm512_set1_ps(quantization_params[0].inv_scale)); diff --git a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-1x16c4-minmax-avx512vnni.c b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-1x16c4-minmax-avx512vnni.c index 3d6fb997fb4..7a018c84b30 100644 --- a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-1x16c4-minmax-avx512vnni.c +++ b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-1x16c4-minmax-avx512vnni.c @@ -47,7 +47,7 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_1x16c4__avx512vnni( const __m512 voutput_min = _mm512_set1_ps(params->avx512vnni.min); const __m512 voutput_max = _mm512_set1_ps(params->avx512vnni.max); const __m512i vsign_mask = _mm512_set1_epi8(params->avx512vnni.sign_mask); // 0x80 - const __m512i vvalue_mask = _mm512_set1_epi8(params->avx512vnni.mask); // 0xF0 + const __m512i vmask = _mm512_set1_epi8(params->avx512vnni.mask); // 0xF0 do { const __m512i vksum0123456789ABCDEF = _mm512_load_epi32(w); __m512i vacc0x0x0123456789ABCDEF = _mm512_mullo_epi32(vksum0123456789ABCDEF, vinput_zero_point0); @@ -62,8 +62,8 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_1x16c4__avx512vnni( const __m512i vbb0123456789ABCDEFx01234567 = _mm512_load_si512(w); const __m512i vbs0123456789ABCDEFx0123 = _mm512_slli_epi32(vbb0123456789ABCDEFx01234567, 4); - const __m512i vb0123456789ABCDEFx4567 = _mm512_and_si512(vbb0123456789ABCDEFx01234567, vvalue_mask); - const __m512i vb0123456789ABCDEFx0123 = _mm512_and_si512(vbs0123456789ABCDEFx0123, vvalue_mask); + const __m512i vb0123456789ABCDEFx4567 = _mm512_and_si512(vbb0123456789ABCDEFx01234567, vmask); + const __m512i vb0123456789ABCDEFx0123 = _mm512_and_si512(vbs0123456789ABCDEFx0123, vmask); vacc0x0x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x0x0123456789ABCDEF, va0x0123, vb0123456789ABCDEFx0123); vacc1x0x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc1x0x0123456789ABCDEF, va0x4567, vb0123456789ABCDEFx4567); @@ -71,7 +71,7 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_1x16c4__avx512vnni( w = (const int8_t*) w + 64; k -= 8 * sizeof(int8_t); } - vacc0x0x0123456789ABCDEF = _mm512_add_epi32(vacc0x0x0123456789ABCDEF, vacc1x0x0123456789ABCDEF); + __m512i vacc0x0123456789ABCDEF = _mm512_add_epi32(vacc0x0x0123456789ABCDEF, vacc1x0x0123456789ABCDEF); if (k != 0) { const __m512i va0x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a0)), vsign_mask); @@ -79,16 +79,16 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_1x16c4__avx512vnni( const __m512i vbb0123456789ABCDEF = _mm512_load_si512(w); const __m512i vbs0123456789ABCDEF = _mm512_slli_epi32(vbb0123456789ABCDEF, 4); - const __m512i vb0123456789ABCDEF = _mm512_and_si512(vbs0123456789ABCDEF, vvalue_mask); + const __m512i vb0123456789ABCDEF = _mm512_and_si512(vbs0123456789ABCDEF, vmask); - vacc0x0x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x0x0123456789ABCDEF, va0x0123, vb0123456789ABCDEF); + vacc0x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x0123456789ABCDEF, va0x0123, vb0123456789ABCDEF); w = (const int8_t*) w + 64; k -= 4 * sizeof(int8_t); } - vacc0x0x0123456789ABCDEF = _mm512_srai_epi32(vacc0x0x0123456789ABCDEF, 4); - __m512 vscaled0x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x0x0123456789ABCDEF); + vacc0x0123456789ABCDEF = _mm512_srai_epi32(vacc0x0123456789ABCDEF, 4); + __m512 vscaled0x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x0123456789ABCDEF); vscaled0x0123456789ABCDEF = _mm512_mul_ps(vscaled0x0123456789ABCDEF, _mm512_set1_ps(quantization_params[0].inv_scale)); diff --git a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-2x16c4-minmax-avx512vnni-prfm.c b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-2x16c4-minmax-avx512vnni-prfm.c index 296ec3d9f0f..af3ea06e102 100644 --- a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-2x16c4-minmax-avx512vnni-prfm.c +++ b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-2x16c4-minmax-avx512vnni-prfm.c @@ -55,7 +55,7 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_2x16c4__avx512vnni_prfm( const __m512 voutput_min = _mm512_set1_ps(params->avx512vnni.min); const __m512 voutput_max = _mm512_set1_ps(params->avx512vnni.max); const __m512i vsign_mask = _mm512_set1_epi8(params->avx512vnni.sign_mask); // 0x80 - const __m512i vvalue_mask = _mm512_set1_epi8(params->avx512vnni.mask); // 0xF0 + const __m512i vmask = _mm512_set1_epi8(params->avx512vnni.mask); // 0xF0 do { const __m512i vksum0123456789ABCDEF = _mm512_load_epi32(w); __m512i vacc0x0x0123456789ABCDEF = _mm512_mullo_epi32(vksum0123456789ABCDEF, vinput_zero_point0); @@ -75,8 +75,8 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_2x16c4__avx512vnni_prfm( const __m512i vbb0123456789ABCDEFx01234567 = _mm512_load_si512(w); const __m512i vbs0123456789ABCDEFx0123 = _mm512_slli_epi32(vbb0123456789ABCDEFx01234567, 4); - const __m512i vb0123456789ABCDEFx4567 = _mm512_and_si512(vbb0123456789ABCDEFx01234567, vvalue_mask); - const __m512i vb0123456789ABCDEFx0123 = _mm512_and_si512(vbs0123456789ABCDEFx0123, vvalue_mask); + const __m512i vb0123456789ABCDEFx4567 = _mm512_and_si512(vbb0123456789ABCDEFx01234567, vmask); + const __m512i vb0123456789ABCDEFx0123 = _mm512_and_si512(vbs0123456789ABCDEFx0123, vmask); vacc0x0x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x0x0123456789ABCDEF, va0x0123, vb0123456789ABCDEFx0123); vacc0x1x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x1x0123456789ABCDEF, va1x0123, vb0123456789ABCDEFx0123); @@ -87,8 +87,8 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_2x16c4__avx512vnni_prfm( w = (const int8_t*) w + 64; k -= 8 * sizeof(int8_t); } - vacc0x0x0123456789ABCDEF = _mm512_add_epi32(vacc0x0x0123456789ABCDEF, vacc1x0x0123456789ABCDEF); - vacc0x1x0123456789ABCDEF = _mm512_add_epi32(vacc0x1x0123456789ABCDEF, vacc1x1x0123456789ABCDEF); + __m512i vacc0x0123456789ABCDEF = _mm512_add_epi32(vacc0x0x0123456789ABCDEF, vacc1x0x0123456789ABCDEF); + __m512i vacc1x0123456789ABCDEF = _mm512_add_epi32(vacc0x1x0123456789ABCDEF, vacc1x1x0123456789ABCDEF); if (k != 0) { const __m512i va0x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a0)), vsign_mask); @@ -98,20 +98,20 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_2x16c4__avx512vnni_prfm( const __m512i vbb0123456789ABCDEF = _mm512_load_si512(w); const __m512i vbs0123456789ABCDEF = _mm512_slli_epi32(vbb0123456789ABCDEF, 4); - const __m512i vb0123456789ABCDEF = _mm512_and_si512(vbs0123456789ABCDEF, vvalue_mask); + const __m512i vb0123456789ABCDEF = _mm512_and_si512(vbs0123456789ABCDEF, vmask); xnn_prefetch_to_l1((const int8_t*) w + 960); - vacc0x0x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x0x0123456789ABCDEF, va0x0123, vb0123456789ABCDEF); - vacc0x1x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x1x0123456789ABCDEF, va1x0123, vb0123456789ABCDEF); + vacc0x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x0123456789ABCDEF, va0x0123, vb0123456789ABCDEF); + vacc1x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc1x0123456789ABCDEF, va1x0123, vb0123456789ABCDEF); w = (const int8_t*) w + 64; k -= 4 * sizeof(int8_t); } - vacc0x0x0123456789ABCDEF = _mm512_srai_epi32(vacc0x0x0123456789ABCDEF, 4); - vacc0x1x0123456789ABCDEF = _mm512_srai_epi32(vacc0x1x0123456789ABCDEF, 4); - __m512 vscaled0x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x0x0123456789ABCDEF); - __m512 vscaled1x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x1x0123456789ABCDEF); + vacc0x0123456789ABCDEF = _mm512_srai_epi32(vacc0x0123456789ABCDEF, 4); + vacc1x0123456789ABCDEF = _mm512_srai_epi32(vacc1x0123456789ABCDEF, 4); + __m512 vscaled0x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x0123456789ABCDEF); + __m512 vscaled1x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc1x0123456789ABCDEF); vscaled0x0123456789ABCDEF = _mm512_mul_ps(vscaled0x0123456789ABCDEF, _mm512_set1_ps(quantization_params[0].inv_scale)); vscaled1x0123456789ABCDEF = _mm512_mul_ps(vscaled1x0123456789ABCDEF, _mm512_set1_ps(quantization_params[1].inv_scale)); diff --git a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-2x16c4-minmax-avx512vnni.c b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-2x16c4-minmax-avx512vnni.c index feaa34bc819..2c316f8b55e 100644 --- a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-2x16c4-minmax-avx512vnni.c +++ b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-2x16c4-minmax-avx512vnni.c @@ -54,7 +54,7 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_2x16c4__avx512vnni( const __m512 voutput_min = _mm512_set1_ps(params->avx512vnni.min); const __m512 voutput_max = _mm512_set1_ps(params->avx512vnni.max); const __m512i vsign_mask = _mm512_set1_epi8(params->avx512vnni.sign_mask); // 0x80 - const __m512i vvalue_mask = _mm512_set1_epi8(params->avx512vnni.mask); // 0xF0 + const __m512i vmask = _mm512_set1_epi8(params->avx512vnni.mask); // 0xF0 do { const __m512i vksum0123456789ABCDEF = _mm512_load_epi32(w); __m512i vacc0x0x0123456789ABCDEF = _mm512_mullo_epi32(vksum0123456789ABCDEF, vinput_zero_point0); @@ -74,8 +74,8 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_2x16c4__avx512vnni( const __m512i vbb0123456789ABCDEFx01234567 = _mm512_load_si512(w); const __m512i vbs0123456789ABCDEFx0123 = _mm512_slli_epi32(vbb0123456789ABCDEFx01234567, 4); - const __m512i vb0123456789ABCDEFx4567 = _mm512_and_si512(vbb0123456789ABCDEFx01234567, vvalue_mask); - const __m512i vb0123456789ABCDEFx0123 = _mm512_and_si512(vbs0123456789ABCDEFx0123, vvalue_mask); + const __m512i vb0123456789ABCDEFx4567 = _mm512_and_si512(vbb0123456789ABCDEFx01234567, vmask); + const __m512i vb0123456789ABCDEFx0123 = _mm512_and_si512(vbs0123456789ABCDEFx0123, vmask); vacc0x0x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x0x0123456789ABCDEF, va0x0123, vb0123456789ABCDEFx0123); vacc0x1x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x1x0123456789ABCDEF, va1x0123, vb0123456789ABCDEFx0123); @@ -85,8 +85,8 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_2x16c4__avx512vnni( w = (const int8_t*) w + 64; k -= 8 * sizeof(int8_t); } - vacc0x0x0123456789ABCDEF = _mm512_add_epi32(vacc0x0x0123456789ABCDEF, vacc1x0x0123456789ABCDEF); - vacc0x1x0123456789ABCDEF = _mm512_add_epi32(vacc0x1x0123456789ABCDEF, vacc1x1x0123456789ABCDEF); + __m512i vacc0x0123456789ABCDEF = _mm512_add_epi32(vacc0x0x0123456789ABCDEF, vacc1x0x0123456789ABCDEF); + __m512i vacc1x0123456789ABCDEF = _mm512_add_epi32(vacc0x1x0123456789ABCDEF, vacc1x1x0123456789ABCDEF); if (k != 0) { const __m512i va0x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a0)), vsign_mask); @@ -96,19 +96,19 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_2x16c4__avx512vnni( const __m512i vbb0123456789ABCDEF = _mm512_load_si512(w); const __m512i vbs0123456789ABCDEF = _mm512_slli_epi32(vbb0123456789ABCDEF, 4); - const __m512i vb0123456789ABCDEF = _mm512_and_si512(vbs0123456789ABCDEF, vvalue_mask); + const __m512i vb0123456789ABCDEF = _mm512_and_si512(vbs0123456789ABCDEF, vmask); - vacc0x0x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x0x0123456789ABCDEF, va0x0123, vb0123456789ABCDEF); - vacc0x1x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x1x0123456789ABCDEF, va1x0123, vb0123456789ABCDEF); + vacc0x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x0123456789ABCDEF, va0x0123, vb0123456789ABCDEF); + vacc1x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc1x0123456789ABCDEF, va1x0123, vb0123456789ABCDEF); w = (const int8_t*) w + 64; k -= 4 * sizeof(int8_t); } - vacc0x0x0123456789ABCDEF = _mm512_srai_epi32(vacc0x0x0123456789ABCDEF, 4); - vacc0x1x0123456789ABCDEF = _mm512_srai_epi32(vacc0x1x0123456789ABCDEF, 4); - __m512 vscaled0x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x0x0123456789ABCDEF); - __m512 vscaled1x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x1x0123456789ABCDEF); + vacc0x0123456789ABCDEF = _mm512_srai_epi32(vacc0x0123456789ABCDEF, 4); + vacc1x0123456789ABCDEF = _mm512_srai_epi32(vacc1x0123456789ABCDEF, 4); + __m512 vscaled0x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x0123456789ABCDEF); + __m512 vscaled1x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc1x0123456789ABCDEF); vscaled0x0123456789ABCDEF = _mm512_mul_ps(vscaled0x0123456789ABCDEF, _mm512_set1_ps(quantization_params[0].inv_scale)); vscaled1x0123456789ABCDEF = _mm512_mul_ps(vscaled1x0123456789ABCDEF, _mm512_set1_ps(quantization_params[1].inv_scale)); diff --git a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-3x16c4-minmax-avx512vnni-prfm.c b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-3x16c4-minmax-avx512vnni-prfm.c index 63986f6f0fa..f9f685333b3 100644 --- a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-3x16c4-minmax-avx512vnni-prfm.c +++ b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-3x16c4-minmax-avx512vnni-prfm.c @@ -62,7 +62,7 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_3x16c4__avx512vnni_prfm( const __m512 voutput_min = _mm512_set1_ps(params->avx512vnni.min); const __m512 voutput_max = _mm512_set1_ps(params->avx512vnni.max); const __m512i vsign_mask = _mm512_set1_epi8(params->avx512vnni.sign_mask); // 0x80 - const __m512i vvalue_mask = _mm512_set1_epi8(params->avx512vnni.mask); // 0xF0 + const __m512i vmask = _mm512_set1_epi8(params->avx512vnni.mask); // 0xF0 do { const __m512i vksum0123456789ABCDEF = _mm512_load_epi32(w); __m512i vacc0x0x0123456789ABCDEF = _mm512_mullo_epi32(vksum0123456789ABCDEF, vinput_zero_point0); @@ -87,8 +87,8 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_3x16c4__avx512vnni_prfm( const __m512i vbb0123456789ABCDEFx01234567 = _mm512_load_si512(w); const __m512i vbs0123456789ABCDEFx0123 = _mm512_slli_epi32(vbb0123456789ABCDEFx01234567, 4); - const __m512i vb0123456789ABCDEFx4567 = _mm512_and_si512(vbb0123456789ABCDEFx01234567, vvalue_mask); - const __m512i vb0123456789ABCDEFx0123 = _mm512_and_si512(vbs0123456789ABCDEFx0123, vvalue_mask); + const __m512i vb0123456789ABCDEFx4567 = _mm512_and_si512(vbb0123456789ABCDEFx01234567, vmask); + const __m512i vb0123456789ABCDEFx0123 = _mm512_and_si512(vbs0123456789ABCDEFx0123, vmask); vacc0x0x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x0x0123456789ABCDEF, va0x0123, vb0123456789ABCDEFx0123); vacc0x1x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x1x0123456789ABCDEF, va1x0123, vb0123456789ABCDEFx0123); @@ -101,9 +101,9 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_3x16c4__avx512vnni_prfm( w = (const int8_t*) w + 64; k -= 8 * sizeof(int8_t); } - vacc0x0x0123456789ABCDEF = _mm512_add_epi32(vacc0x0x0123456789ABCDEF, vacc1x0x0123456789ABCDEF); - vacc0x1x0123456789ABCDEF = _mm512_add_epi32(vacc0x1x0123456789ABCDEF, vacc1x1x0123456789ABCDEF); - vacc0x2x0123456789ABCDEF = _mm512_add_epi32(vacc0x2x0123456789ABCDEF, vacc1x2x0123456789ABCDEF); + __m512i vacc0x0123456789ABCDEF = _mm512_add_epi32(vacc0x0x0123456789ABCDEF, vacc1x0x0123456789ABCDEF); + __m512i vacc1x0123456789ABCDEF = _mm512_add_epi32(vacc0x1x0123456789ABCDEF, vacc1x1x0123456789ABCDEF); + __m512i vacc2x0123456789ABCDEF = _mm512_add_epi32(vacc0x2x0123456789ABCDEF, vacc1x2x0123456789ABCDEF); if (k != 0) { const __m512i va0x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a0)), vsign_mask); @@ -115,23 +115,23 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_3x16c4__avx512vnni_prfm( const __m512i vbb0123456789ABCDEF = _mm512_load_si512(w); const __m512i vbs0123456789ABCDEF = _mm512_slli_epi32(vbb0123456789ABCDEF, 4); - const __m512i vb0123456789ABCDEF = _mm512_and_si512(vbs0123456789ABCDEF, vvalue_mask); + const __m512i vb0123456789ABCDEF = _mm512_and_si512(vbs0123456789ABCDEF, vmask); xnn_prefetch_to_l1((const int8_t*) w + 960); - vacc0x0x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x0x0123456789ABCDEF, va0x0123, vb0123456789ABCDEF); - vacc0x1x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x1x0123456789ABCDEF, va1x0123, vb0123456789ABCDEF); - vacc0x2x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x2x0123456789ABCDEF, va2x0123, vb0123456789ABCDEF); + vacc0x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x0123456789ABCDEF, va0x0123, vb0123456789ABCDEF); + vacc1x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc1x0123456789ABCDEF, va1x0123, vb0123456789ABCDEF); + vacc2x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc2x0123456789ABCDEF, va2x0123, vb0123456789ABCDEF); w = (const int8_t*) w + 64; k -= 4 * sizeof(int8_t); } - vacc0x0x0123456789ABCDEF = _mm512_srai_epi32(vacc0x0x0123456789ABCDEF, 4); - vacc0x1x0123456789ABCDEF = _mm512_srai_epi32(vacc0x1x0123456789ABCDEF, 4); - vacc0x2x0123456789ABCDEF = _mm512_srai_epi32(vacc0x2x0123456789ABCDEF, 4); - __m512 vscaled0x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x0x0123456789ABCDEF); - __m512 vscaled1x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x1x0123456789ABCDEF); - __m512 vscaled2x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x2x0123456789ABCDEF); + vacc0x0123456789ABCDEF = _mm512_srai_epi32(vacc0x0123456789ABCDEF, 4); + vacc1x0123456789ABCDEF = _mm512_srai_epi32(vacc1x0123456789ABCDEF, 4); + vacc2x0123456789ABCDEF = _mm512_srai_epi32(vacc2x0123456789ABCDEF, 4); + __m512 vscaled0x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x0123456789ABCDEF); + __m512 vscaled1x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc1x0123456789ABCDEF); + __m512 vscaled2x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc2x0123456789ABCDEF); vscaled0x0123456789ABCDEF = _mm512_mul_ps(vscaled0x0123456789ABCDEF, _mm512_set1_ps(quantization_params[0].inv_scale)); vscaled1x0123456789ABCDEF = _mm512_mul_ps(vscaled1x0123456789ABCDEF, _mm512_set1_ps(quantization_params[1].inv_scale)); diff --git a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-3x16c4-minmax-avx512vnni.c b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-3x16c4-minmax-avx512vnni.c index f2396abf325..13bdf05771a 100644 --- a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-3x16c4-minmax-avx512vnni.c +++ b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-3x16c4-minmax-avx512vnni.c @@ -61,7 +61,7 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_3x16c4__avx512vnni( const __m512 voutput_min = _mm512_set1_ps(params->avx512vnni.min); const __m512 voutput_max = _mm512_set1_ps(params->avx512vnni.max); const __m512i vsign_mask = _mm512_set1_epi8(params->avx512vnni.sign_mask); // 0x80 - const __m512i vvalue_mask = _mm512_set1_epi8(params->avx512vnni.mask); // 0xF0 + const __m512i vmask = _mm512_set1_epi8(params->avx512vnni.mask); // 0xF0 do { const __m512i vksum0123456789ABCDEF = _mm512_load_epi32(w); __m512i vacc0x0x0123456789ABCDEF = _mm512_mullo_epi32(vksum0123456789ABCDEF, vinput_zero_point0); @@ -86,8 +86,8 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_3x16c4__avx512vnni( const __m512i vbb0123456789ABCDEFx01234567 = _mm512_load_si512(w); const __m512i vbs0123456789ABCDEFx0123 = _mm512_slli_epi32(vbb0123456789ABCDEFx01234567, 4); - const __m512i vb0123456789ABCDEFx4567 = _mm512_and_si512(vbb0123456789ABCDEFx01234567, vvalue_mask); - const __m512i vb0123456789ABCDEFx0123 = _mm512_and_si512(vbs0123456789ABCDEFx0123, vvalue_mask); + const __m512i vb0123456789ABCDEFx4567 = _mm512_and_si512(vbb0123456789ABCDEFx01234567, vmask); + const __m512i vb0123456789ABCDEFx0123 = _mm512_and_si512(vbs0123456789ABCDEFx0123, vmask); vacc0x0x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x0x0123456789ABCDEF, va0x0123, vb0123456789ABCDEFx0123); vacc0x1x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x1x0123456789ABCDEF, va1x0123, vb0123456789ABCDEFx0123); @@ -99,9 +99,9 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_3x16c4__avx512vnni( w = (const int8_t*) w + 64; k -= 8 * sizeof(int8_t); } - vacc0x0x0123456789ABCDEF = _mm512_add_epi32(vacc0x0x0123456789ABCDEF, vacc1x0x0123456789ABCDEF); - vacc0x1x0123456789ABCDEF = _mm512_add_epi32(vacc0x1x0123456789ABCDEF, vacc1x1x0123456789ABCDEF); - vacc0x2x0123456789ABCDEF = _mm512_add_epi32(vacc0x2x0123456789ABCDEF, vacc1x2x0123456789ABCDEF); + __m512i vacc0x0123456789ABCDEF = _mm512_add_epi32(vacc0x0x0123456789ABCDEF, vacc1x0x0123456789ABCDEF); + __m512i vacc1x0123456789ABCDEF = _mm512_add_epi32(vacc0x1x0123456789ABCDEF, vacc1x1x0123456789ABCDEF); + __m512i vacc2x0123456789ABCDEF = _mm512_add_epi32(vacc0x2x0123456789ABCDEF, vacc1x2x0123456789ABCDEF); if (k != 0) { const __m512i va0x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a0)), vsign_mask); @@ -113,22 +113,22 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_3x16c4__avx512vnni( const __m512i vbb0123456789ABCDEF = _mm512_load_si512(w); const __m512i vbs0123456789ABCDEF = _mm512_slli_epi32(vbb0123456789ABCDEF, 4); - const __m512i vb0123456789ABCDEF = _mm512_and_si512(vbs0123456789ABCDEF, vvalue_mask); + const __m512i vb0123456789ABCDEF = _mm512_and_si512(vbs0123456789ABCDEF, vmask); - vacc0x0x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x0x0123456789ABCDEF, va0x0123, vb0123456789ABCDEF); - vacc0x1x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x1x0123456789ABCDEF, va1x0123, vb0123456789ABCDEF); - vacc0x2x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x2x0123456789ABCDEF, va2x0123, vb0123456789ABCDEF); + vacc0x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x0123456789ABCDEF, va0x0123, vb0123456789ABCDEF); + vacc1x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc1x0123456789ABCDEF, va1x0123, vb0123456789ABCDEF); + vacc2x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc2x0123456789ABCDEF, va2x0123, vb0123456789ABCDEF); w = (const int8_t*) w + 64; k -= 4 * sizeof(int8_t); } - vacc0x0x0123456789ABCDEF = _mm512_srai_epi32(vacc0x0x0123456789ABCDEF, 4); - vacc0x1x0123456789ABCDEF = _mm512_srai_epi32(vacc0x1x0123456789ABCDEF, 4); - vacc0x2x0123456789ABCDEF = _mm512_srai_epi32(vacc0x2x0123456789ABCDEF, 4); - __m512 vscaled0x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x0x0123456789ABCDEF); - __m512 vscaled1x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x1x0123456789ABCDEF); - __m512 vscaled2x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x2x0123456789ABCDEF); + vacc0x0123456789ABCDEF = _mm512_srai_epi32(vacc0x0123456789ABCDEF, 4); + vacc1x0123456789ABCDEF = _mm512_srai_epi32(vacc1x0123456789ABCDEF, 4); + vacc2x0123456789ABCDEF = _mm512_srai_epi32(vacc2x0123456789ABCDEF, 4); + __m512 vscaled0x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x0123456789ABCDEF); + __m512 vscaled1x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc1x0123456789ABCDEF); + __m512 vscaled2x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc2x0123456789ABCDEF); vscaled0x0123456789ABCDEF = _mm512_mul_ps(vscaled0x0123456789ABCDEF, _mm512_set1_ps(quantization_params[0].inv_scale)); vscaled1x0123456789ABCDEF = _mm512_mul_ps(vscaled1x0123456789ABCDEF, _mm512_set1_ps(quantization_params[1].inv_scale)); diff --git a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-4x16c4-minmax-avx512vnni-prfm.c b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-4x16c4-minmax-avx512vnni-prfm.c index 8af518af02d..9d0e7adbea1 100644 --- a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-4x16c4-minmax-avx512vnni-prfm.c +++ b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-4x16c4-minmax-avx512vnni-prfm.c @@ -69,7 +69,7 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_4x16c4__avx512vnni_prfm( const __m512 voutput_min = _mm512_set1_ps(params->avx512vnni.min); const __m512 voutput_max = _mm512_set1_ps(params->avx512vnni.max); const __m512i vsign_mask = _mm512_set1_epi8(params->avx512vnni.sign_mask); // 0x80 - const __m512i vvalue_mask = _mm512_set1_epi8(params->avx512vnni.mask); // 0xF0 + const __m512i vmask = _mm512_set1_epi8(params->avx512vnni.mask); // 0xF0 do { const __m512i vksum0123456789ABCDEF = _mm512_load_epi32(w); __m512i vacc0x0x0123456789ABCDEF = _mm512_mullo_epi32(vksum0123456789ABCDEF, vinput_zero_point0); @@ -99,8 +99,8 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_4x16c4__avx512vnni_prfm( const __m512i vbb0123456789ABCDEFx01234567 = _mm512_load_si512(w); const __m512i vbs0123456789ABCDEFx0123 = _mm512_slli_epi32(vbb0123456789ABCDEFx01234567, 4); - const __m512i vb0123456789ABCDEFx4567 = _mm512_and_si512(vbb0123456789ABCDEFx01234567, vvalue_mask); - const __m512i vb0123456789ABCDEFx0123 = _mm512_and_si512(vbs0123456789ABCDEFx0123, vvalue_mask); + const __m512i vb0123456789ABCDEFx4567 = _mm512_and_si512(vbb0123456789ABCDEFx01234567, vmask); + const __m512i vb0123456789ABCDEFx0123 = _mm512_and_si512(vbs0123456789ABCDEFx0123, vmask); vacc0x0x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x0x0123456789ABCDEF, va0x0123, vb0123456789ABCDEFx0123); vacc0x1x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x1x0123456789ABCDEF, va1x0123, vb0123456789ABCDEFx0123); @@ -115,10 +115,10 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_4x16c4__avx512vnni_prfm( w = (const int8_t*) w + 64; k -= 8 * sizeof(int8_t); } - vacc0x0x0123456789ABCDEF = _mm512_add_epi32(vacc0x0x0123456789ABCDEF, vacc1x0x0123456789ABCDEF); - vacc0x1x0123456789ABCDEF = _mm512_add_epi32(vacc0x1x0123456789ABCDEF, vacc1x1x0123456789ABCDEF); - vacc0x2x0123456789ABCDEF = _mm512_add_epi32(vacc0x2x0123456789ABCDEF, vacc1x2x0123456789ABCDEF); - vacc0x3x0123456789ABCDEF = _mm512_add_epi32(vacc0x3x0123456789ABCDEF, vacc1x3x0123456789ABCDEF); + __m512i vacc0x0123456789ABCDEF = _mm512_add_epi32(vacc0x0x0123456789ABCDEF, vacc1x0x0123456789ABCDEF); + __m512i vacc1x0123456789ABCDEF = _mm512_add_epi32(vacc0x1x0123456789ABCDEF, vacc1x1x0123456789ABCDEF); + __m512i vacc2x0123456789ABCDEF = _mm512_add_epi32(vacc0x2x0123456789ABCDEF, vacc1x2x0123456789ABCDEF); + __m512i vacc3x0123456789ABCDEF = _mm512_add_epi32(vacc0x3x0123456789ABCDEF, vacc1x3x0123456789ABCDEF); if (k != 0) { const __m512i va0x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a0)), vsign_mask); @@ -132,26 +132,26 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_4x16c4__avx512vnni_prfm( const __m512i vbb0123456789ABCDEF = _mm512_load_si512(w); const __m512i vbs0123456789ABCDEF = _mm512_slli_epi32(vbb0123456789ABCDEF, 4); - const __m512i vb0123456789ABCDEF = _mm512_and_si512(vbs0123456789ABCDEF, vvalue_mask); + const __m512i vb0123456789ABCDEF = _mm512_and_si512(vbs0123456789ABCDEF, vmask); xnn_prefetch_to_l1((const int8_t*) w + 960); - vacc0x0x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x0x0123456789ABCDEF, va0x0123, vb0123456789ABCDEF); - vacc0x1x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x1x0123456789ABCDEF, va1x0123, vb0123456789ABCDEF); - vacc0x2x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x2x0123456789ABCDEF, va2x0123, vb0123456789ABCDEF); - vacc0x3x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x3x0123456789ABCDEF, va3x0123, vb0123456789ABCDEF); + vacc0x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x0123456789ABCDEF, va0x0123, vb0123456789ABCDEF); + vacc1x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc1x0123456789ABCDEF, va1x0123, vb0123456789ABCDEF); + vacc2x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc2x0123456789ABCDEF, va2x0123, vb0123456789ABCDEF); + vacc3x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc3x0123456789ABCDEF, va3x0123, vb0123456789ABCDEF); w = (const int8_t*) w + 64; k -= 4 * sizeof(int8_t); } - vacc0x0x0123456789ABCDEF = _mm512_srai_epi32(vacc0x0x0123456789ABCDEF, 4); - vacc0x1x0123456789ABCDEF = _mm512_srai_epi32(vacc0x1x0123456789ABCDEF, 4); - vacc0x2x0123456789ABCDEF = _mm512_srai_epi32(vacc0x2x0123456789ABCDEF, 4); - vacc0x3x0123456789ABCDEF = _mm512_srai_epi32(vacc0x3x0123456789ABCDEF, 4); - __m512 vscaled0x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x0x0123456789ABCDEF); - __m512 vscaled1x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x1x0123456789ABCDEF); - __m512 vscaled2x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x2x0123456789ABCDEF); - __m512 vscaled3x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x3x0123456789ABCDEF); + vacc0x0123456789ABCDEF = _mm512_srai_epi32(vacc0x0123456789ABCDEF, 4); + vacc1x0123456789ABCDEF = _mm512_srai_epi32(vacc1x0123456789ABCDEF, 4); + vacc2x0123456789ABCDEF = _mm512_srai_epi32(vacc2x0123456789ABCDEF, 4); + vacc3x0123456789ABCDEF = _mm512_srai_epi32(vacc3x0123456789ABCDEF, 4); + __m512 vscaled0x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x0123456789ABCDEF); + __m512 vscaled1x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc1x0123456789ABCDEF); + __m512 vscaled2x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc2x0123456789ABCDEF); + __m512 vscaled3x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc3x0123456789ABCDEF); vscaled0x0123456789ABCDEF = _mm512_mul_ps(vscaled0x0123456789ABCDEF, _mm512_set1_ps(quantization_params[0].inv_scale)); vscaled1x0123456789ABCDEF = _mm512_mul_ps(vscaled1x0123456789ABCDEF, _mm512_set1_ps(quantization_params[1].inv_scale)); diff --git a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-4x16c4-minmax-avx512vnni.c b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-4x16c4-minmax-avx512vnni.c index 5e4cedb8b3f..292dffc68b3 100644 --- a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-4x16c4-minmax-avx512vnni.c +++ b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-4x16c4-minmax-avx512vnni.c @@ -68,7 +68,7 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_4x16c4__avx512vnni( const __m512 voutput_min = _mm512_set1_ps(params->avx512vnni.min); const __m512 voutput_max = _mm512_set1_ps(params->avx512vnni.max); const __m512i vsign_mask = _mm512_set1_epi8(params->avx512vnni.sign_mask); // 0x80 - const __m512i vvalue_mask = _mm512_set1_epi8(params->avx512vnni.mask); // 0xF0 + const __m512i vmask = _mm512_set1_epi8(params->avx512vnni.mask); // 0xF0 do { const __m512i vksum0123456789ABCDEF = _mm512_load_epi32(w); __m512i vacc0x0x0123456789ABCDEF = _mm512_mullo_epi32(vksum0123456789ABCDEF, vinput_zero_point0); @@ -98,8 +98,8 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_4x16c4__avx512vnni( const __m512i vbb0123456789ABCDEFx01234567 = _mm512_load_si512(w); const __m512i vbs0123456789ABCDEFx0123 = _mm512_slli_epi32(vbb0123456789ABCDEFx01234567, 4); - const __m512i vb0123456789ABCDEFx4567 = _mm512_and_si512(vbb0123456789ABCDEFx01234567, vvalue_mask); - const __m512i vb0123456789ABCDEFx0123 = _mm512_and_si512(vbs0123456789ABCDEFx0123, vvalue_mask); + const __m512i vb0123456789ABCDEFx4567 = _mm512_and_si512(vbb0123456789ABCDEFx01234567, vmask); + const __m512i vb0123456789ABCDEFx0123 = _mm512_and_si512(vbs0123456789ABCDEFx0123, vmask); vacc0x0x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x0x0123456789ABCDEF, va0x0123, vb0123456789ABCDEFx0123); vacc0x1x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x1x0123456789ABCDEF, va1x0123, vb0123456789ABCDEFx0123); @@ -113,10 +113,10 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_4x16c4__avx512vnni( w = (const int8_t*) w + 64; k -= 8 * sizeof(int8_t); } - vacc0x0x0123456789ABCDEF = _mm512_add_epi32(vacc0x0x0123456789ABCDEF, vacc1x0x0123456789ABCDEF); - vacc0x1x0123456789ABCDEF = _mm512_add_epi32(vacc0x1x0123456789ABCDEF, vacc1x1x0123456789ABCDEF); - vacc0x2x0123456789ABCDEF = _mm512_add_epi32(vacc0x2x0123456789ABCDEF, vacc1x2x0123456789ABCDEF); - vacc0x3x0123456789ABCDEF = _mm512_add_epi32(vacc0x3x0123456789ABCDEF, vacc1x3x0123456789ABCDEF); + __m512i vacc0x0123456789ABCDEF = _mm512_add_epi32(vacc0x0x0123456789ABCDEF, vacc1x0x0123456789ABCDEF); + __m512i vacc1x0123456789ABCDEF = _mm512_add_epi32(vacc0x1x0123456789ABCDEF, vacc1x1x0123456789ABCDEF); + __m512i vacc2x0123456789ABCDEF = _mm512_add_epi32(vacc0x2x0123456789ABCDEF, vacc1x2x0123456789ABCDEF); + __m512i vacc3x0123456789ABCDEF = _mm512_add_epi32(vacc0x3x0123456789ABCDEF, vacc1x3x0123456789ABCDEF); if (k != 0) { const __m512i va0x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a0)), vsign_mask); @@ -130,25 +130,25 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_4x16c4__avx512vnni( const __m512i vbb0123456789ABCDEF = _mm512_load_si512(w); const __m512i vbs0123456789ABCDEF = _mm512_slli_epi32(vbb0123456789ABCDEF, 4); - const __m512i vb0123456789ABCDEF = _mm512_and_si512(vbs0123456789ABCDEF, vvalue_mask); + const __m512i vb0123456789ABCDEF = _mm512_and_si512(vbs0123456789ABCDEF, vmask); - vacc0x0x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x0x0123456789ABCDEF, va0x0123, vb0123456789ABCDEF); - vacc0x1x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x1x0123456789ABCDEF, va1x0123, vb0123456789ABCDEF); - vacc0x2x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x2x0123456789ABCDEF, va2x0123, vb0123456789ABCDEF); - vacc0x3x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x3x0123456789ABCDEF, va3x0123, vb0123456789ABCDEF); + vacc0x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x0123456789ABCDEF, va0x0123, vb0123456789ABCDEF); + vacc1x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc1x0123456789ABCDEF, va1x0123, vb0123456789ABCDEF); + vacc2x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc2x0123456789ABCDEF, va2x0123, vb0123456789ABCDEF); + vacc3x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc3x0123456789ABCDEF, va3x0123, vb0123456789ABCDEF); w = (const int8_t*) w + 64; k -= 4 * sizeof(int8_t); } - vacc0x0x0123456789ABCDEF = _mm512_srai_epi32(vacc0x0x0123456789ABCDEF, 4); - vacc0x1x0123456789ABCDEF = _mm512_srai_epi32(vacc0x1x0123456789ABCDEF, 4); - vacc0x2x0123456789ABCDEF = _mm512_srai_epi32(vacc0x2x0123456789ABCDEF, 4); - vacc0x3x0123456789ABCDEF = _mm512_srai_epi32(vacc0x3x0123456789ABCDEF, 4); - __m512 vscaled0x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x0x0123456789ABCDEF); - __m512 vscaled1x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x1x0123456789ABCDEF); - __m512 vscaled2x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x2x0123456789ABCDEF); - __m512 vscaled3x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x3x0123456789ABCDEF); + vacc0x0123456789ABCDEF = _mm512_srai_epi32(vacc0x0123456789ABCDEF, 4); + vacc1x0123456789ABCDEF = _mm512_srai_epi32(vacc1x0123456789ABCDEF, 4); + vacc2x0123456789ABCDEF = _mm512_srai_epi32(vacc2x0123456789ABCDEF, 4); + vacc3x0123456789ABCDEF = _mm512_srai_epi32(vacc3x0123456789ABCDEF, 4); + __m512 vscaled0x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x0123456789ABCDEF); + __m512 vscaled1x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc1x0123456789ABCDEF); + __m512 vscaled2x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc2x0123456789ABCDEF); + __m512 vscaled3x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc3x0123456789ABCDEF); vscaled0x0123456789ABCDEF = _mm512_mul_ps(vscaled0x0123456789ABCDEF, _mm512_set1_ps(quantization_params[0].inv_scale)); vscaled1x0123456789ABCDEF = _mm512_mul_ps(vscaled1x0123456789ABCDEF, _mm512_set1_ps(quantization_params[1].inv_scale)); diff --git a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-5x16c4-minmax-avx512vnni-prfm.c b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-5x16c4-minmax-avx512vnni-prfm.c index 2ec557bef4b..84fe4766c23 100644 --- a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-5x16c4-minmax-avx512vnni-prfm.c +++ b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-5x16c4-minmax-avx512vnni-prfm.c @@ -76,7 +76,7 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_5x16c4__avx512vnni_prfm( const __m512 voutput_min = _mm512_set1_ps(params->avx512vnni.min); const __m512 voutput_max = _mm512_set1_ps(params->avx512vnni.max); const __m512i vsign_mask = _mm512_set1_epi8(params->avx512vnni.sign_mask); // 0x80 - const __m512i vvalue_mask = _mm512_set1_epi8(params->avx512vnni.mask); // 0xF0 + const __m512i vmask = _mm512_set1_epi8(params->avx512vnni.mask); // 0xF0 do { const __m512i vksum0123456789ABCDEF = _mm512_load_epi32(w); __m512i vacc0x0x0123456789ABCDEF = _mm512_mullo_epi32(vksum0123456789ABCDEF, vinput_zero_point0); @@ -111,8 +111,8 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_5x16c4__avx512vnni_prfm( const __m512i vbb0123456789ABCDEFx01234567 = _mm512_load_si512(w); const __m512i vbs0123456789ABCDEFx0123 = _mm512_slli_epi32(vbb0123456789ABCDEFx01234567, 4); - const __m512i vb0123456789ABCDEFx4567 = _mm512_and_si512(vbb0123456789ABCDEFx01234567, vvalue_mask); - const __m512i vb0123456789ABCDEFx0123 = _mm512_and_si512(vbs0123456789ABCDEFx0123, vvalue_mask); + const __m512i vb0123456789ABCDEFx4567 = _mm512_and_si512(vbb0123456789ABCDEFx01234567, vmask); + const __m512i vb0123456789ABCDEFx0123 = _mm512_and_si512(vbs0123456789ABCDEFx0123, vmask); vacc0x0x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x0x0123456789ABCDEF, va0x0123, vb0123456789ABCDEFx0123); vacc0x1x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x1x0123456789ABCDEF, va1x0123, vb0123456789ABCDEFx0123); @@ -129,11 +129,11 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_5x16c4__avx512vnni_prfm( w = (const int8_t*) w + 64; k -= 8 * sizeof(int8_t); } - vacc0x0x0123456789ABCDEF = _mm512_add_epi32(vacc0x0x0123456789ABCDEF, vacc1x0x0123456789ABCDEF); - vacc0x1x0123456789ABCDEF = _mm512_add_epi32(vacc0x1x0123456789ABCDEF, vacc1x1x0123456789ABCDEF); - vacc0x2x0123456789ABCDEF = _mm512_add_epi32(vacc0x2x0123456789ABCDEF, vacc1x2x0123456789ABCDEF); - vacc0x3x0123456789ABCDEF = _mm512_add_epi32(vacc0x3x0123456789ABCDEF, vacc1x3x0123456789ABCDEF); - vacc0x4x0123456789ABCDEF = _mm512_add_epi32(vacc0x4x0123456789ABCDEF, vacc1x4x0123456789ABCDEF); + __m512i vacc0x0123456789ABCDEF = _mm512_add_epi32(vacc0x0x0123456789ABCDEF, vacc1x0x0123456789ABCDEF); + __m512i vacc1x0123456789ABCDEF = _mm512_add_epi32(vacc0x1x0123456789ABCDEF, vacc1x1x0123456789ABCDEF); + __m512i vacc2x0123456789ABCDEF = _mm512_add_epi32(vacc0x2x0123456789ABCDEF, vacc1x2x0123456789ABCDEF); + __m512i vacc3x0123456789ABCDEF = _mm512_add_epi32(vacc0x3x0123456789ABCDEF, vacc1x3x0123456789ABCDEF); + __m512i vacc4x0123456789ABCDEF = _mm512_add_epi32(vacc0x4x0123456789ABCDEF, vacc1x4x0123456789ABCDEF); if (k != 0) { const __m512i va0x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a0)), vsign_mask); @@ -149,29 +149,29 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_5x16c4__avx512vnni_prfm( const __m512i vbb0123456789ABCDEF = _mm512_load_si512(w); const __m512i vbs0123456789ABCDEF = _mm512_slli_epi32(vbb0123456789ABCDEF, 4); - const __m512i vb0123456789ABCDEF = _mm512_and_si512(vbs0123456789ABCDEF, vvalue_mask); + const __m512i vb0123456789ABCDEF = _mm512_and_si512(vbs0123456789ABCDEF, vmask); xnn_prefetch_to_l1((const int8_t*) w + 960); - vacc0x0x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x0x0123456789ABCDEF, va0x0123, vb0123456789ABCDEF); - vacc0x1x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x1x0123456789ABCDEF, va1x0123, vb0123456789ABCDEF); - vacc0x2x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x2x0123456789ABCDEF, va2x0123, vb0123456789ABCDEF); - vacc0x3x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x3x0123456789ABCDEF, va3x0123, vb0123456789ABCDEF); - vacc0x4x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x4x0123456789ABCDEF, va4x0123, vb0123456789ABCDEF); + vacc0x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x0123456789ABCDEF, va0x0123, vb0123456789ABCDEF); + vacc1x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc1x0123456789ABCDEF, va1x0123, vb0123456789ABCDEF); + vacc2x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc2x0123456789ABCDEF, va2x0123, vb0123456789ABCDEF); + vacc3x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc3x0123456789ABCDEF, va3x0123, vb0123456789ABCDEF); + vacc4x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc4x0123456789ABCDEF, va4x0123, vb0123456789ABCDEF); w = (const int8_t*) w + 64; k -= 4 * sizeof(int8_t); } - vacc0x0x0123456789ABCDEF = _mm512_srai_epi32(vacc0x0x0123456789ABCDEF, 4); - vacc0x1x0123456789ABCDEF = _mm512_srai_epi32(vacc0x1x0123456789ABCDEF, 4); - vacc0x2x0123456789ABCDEF = _mm512_srai_epi32(vacc0x2x0123456789ABCDEF, 4); - vacc0x3x0123456789ABCDEF = _mm512_srai_epi32(vacc0x3x0123456789ABCDEF, 4); - vacc0x4x0123456789ABCDEF = _mm512_srai_epi32(vacc0x4x0123456789ABCDEF, 4); - __m512 vscaled0x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x0x0123456789ABCDEF); - __m512 vscaled1x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x1x0123456789ABCDEF); - __m512 vscaled2x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x2x0123456789ABCDEF); - __m512 vscaled3x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x3x0123456789ABCDEF); - __m512 vscaled4x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x4x0123456789ABCDEF); + vacc0x0123456789ABCDEF = _mm512_srai_epi32(vacc0x0123456789ABCDEF, 4); + vacc1x0123456789ABCDEF = _mm512_srai_epi32(vacc1x0123456789ABCDEF, 4); + vacc2x0123456789ABCDEF = _mm512_srai_epi32(vacc2x0123456789ABCDEF, 4); + vacc3x0123456789ABCDEF = _mm512_srai_epi32(vacc3x0123456789ABCDEF, 4); + vacc4x0123456789ABCDEF = _mm512_srai_epi32(vacc4x0123456789ABCDEF, 4); + __m512 vscaled0x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x0123456789ABCDEF); + __m512 vscaled1x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc1x0123456789ABCDEF); + __m512 vscaled2x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc2x0123456789ABCDEF); + __m512 vscaled3x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc3x0123456789ABCDEF); + __m512 vscaled4x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc4x0123456789ABCDEF); vscaled0x0123456789ABCDEF = _mm512_mul_ps(vscaled0x0123456789ABCDEF, _mm512_set1_ps(quantization_params[0].inv_scale)); vscaled1x0123456789ABCDEF = _mm512_mul_ps(vscaled1x0123456789ABCDEF, _mm512_set1_ps(quantization_params[1].inv_scale)); diff --git a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-5x16c4-minmax-avx512vnni.c b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-5x16c4-minmax-avx512vnni.c index f93e75a105b..4faf495735e 100644 --- a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-5x16c4-minmax-avx512vnni.c +++ b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-5x16c4-minmax-avx512vnni.c @@ -75,7 +75,7 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_5x16c4__avx512vnni( const __m512 voutput_min = _mm512_set1_ps(params->avx512vnni.min); const __m512 voutput_max = _mm512_set1_ps(params->avx512vnni.max); const __m512i vsign_mask = _mm512_set1_epi8(params->avx512vnni.sign_mask); // 0x80 - const __m512i vvalue_mask = _mm512_set1_epi8(params->avx512vnni.mask); // 0xF0 + const __m512i vmask = _mm512_set1_epi8(params->avx512vnni.mask); // 0xF0 do { const __m512i vksum0123456789ABCDEF = _mm512_load_epi32(w); __m512i vacc0x0x0123456789ABCDEF = _mm512_mullo_epi32(vksum0123456789ABCDEF, vinput_zero_point0); @@ -110,8 +110,8 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_5x16c4__avx512vnni( const __m512i vbb0123456789ABCDEFx01234567 = _mm512_load_si512(w); const __m512i vbs0123456789ABCDEFx0123 = _mm512_slli_epi32(vbb0123456789ABCDEFx01234567, 4); - const __m512i vb0123456789ABCDEFx4567 = _mm512_and_si512(vbb0123456789ABCDEFx01234567, vvalue_mask); - const __m512i vb0123456789ABCDEFx0123 = _mm512_and_si512(vbs0123456789ABCDEFx0123, vvalue_mask); + const __m512i vb0123456789ABCDEFx4567 = _mm512_and_si512(vbb0123456789ABCDEFx01234567, vmask); + const __m512i vb0123456789ABCDEFx0123 = _mm512_and_si512(vbs0123456789ABCDEFx0123, vmask); vacc0x0x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x0x0123456789ABCDEF, va0x0123, vb0123456789ABCDEFx0123); vacc0x1x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x1x0123456789ABCDEF, va1x0123, vb0123456789ABCDEFx0123); @@ -127,11 +127,11 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_5x16c4__avx512vnni( w = (const int8_t*) w + 64; k -= 8 * sizeof(int8_t); } - vacc0x0x0123456789ABCDEF = _mm512_add_epi32(vacc0x0x0123456789ABCDEF, vacc1x0x0123456789ABCDEF); - vacc0x1x0123456789ABCDEF = _mm512_add_epi32(vacc0x1x0123456789ABCDEF, vacc1x1x0123456789ABCDEF); - vacc0x2x0123456789ABCDEF = _mm512_add_epi32(vacc0x2x0123456789ABCDEF, vacc1x2x0123456789ABCDEF); - vacc0x3x0123456789ABCDEF = _mm512_add_epi32(vacc0x3x0123456789ABCDEF, vacc1x3x0123456789ABCDEF); - vacc0x4x0123456789ABCDEF = _mm512_add_epi32(vacc0x4x0123456789ABCDEF, vacc1x4x0123456789ABCDEF); + __m512i vacc0x0123456789ABCDEF = _mm512_add_epi32(vacc0x0x0123456789ABCDEF, vacc1x0x0123456789ABCDEF); + __m512i vacc1x0123456789ABCDEF = _mm512_add_epi32(vacc0x1x0123456789ABCDEF, vacc1x1x0123456789ABCDEF); + __m512i vacc2x0123456789ABCDEF = _mm512_add_epi32(vacc0x2x0123456789ABCDEF, vacc1x2x0123456789ABCDEF); + __m512i vacc3x0123456789ABCDEF = _mm512_add_epi32(vacc0x3x0123456789ABCDEF, vacc1x3x0123456789ABCDEF); + __m512i vacc4x0123456789ABCDEF = _mm512_add_epi32(vacc0x4x0123456789ABCDEF, vacc1x4x0123456789ABCDEF); if (k != 0) { const __m512i va0x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a0)), vsign_mask); @@ -147,28 +147,28 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_5x16c4__avx512vnni( const __m512i vbb0123456789ABCDEF = _mm512_load_si512(w); const __m512i vbs0123456789ABCDEF = _mm512_slli_epi32(vbb0123456789ABCDEF, 4); - const __m512i vb0123456789ABCDEF = _mm512_and_si512(vbs0123456789ABCDEF, vvalue_mask); + const __m512i vb0123456789ABCDEF = _mm512_and_si512(vbs0123456789ABCDEF, vmask); - vacc0x0x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x0x0123456789ABCDEF, va0x0123, vb0123456789ABCDEF); - vacc0x1x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x1x0123456789ABCDEF, va1x0123, vb0123456789ABCDEF); - vacc0x2x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x2x0123456789ABCDEF, va2x0123, vb0123456789ABCDEF); - vacc0x3x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x3x0123456789ABCDEF, va3x0123, vb0123456789ABCDEF); - vacc0x4x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x4x0123456789ABCDEF, va4x0123, vb0123456789ABCDEF); + vacc0x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x0123456789ABCDEF, va0x0123, vb0123456789ABCDEF); + vacc1x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc1x0123456789ABCDEF, va1x0123, vb0123456789ABCDEF); + vacc2x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc2x0123456789ABCDEF, va2x0123, vb0123456789ABCDEF); + vacc3x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc3x0123456789ABCDEF, va3x0123, vb0123456789ABCDEF); + vacc4x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc4x0123456789ABCDEF, va4x0123, vb0123456789ABCDEF); w = (const int8_t*) w + 64; k -= 4 * sizeof(int8_t); } - vacc0x0x0123456789ABCDEF = _mm512_srai_epi32(vacc0x0x0123456789ABCDEF, 4); - vacc0x1x0123456789ABCDEF = _mm512_srai_epi32(vacc0x1x0123456789ABCDEF, 4); - vacc0x2x0123456789ABCDEF = _mm512_srai_epi32(vacc0x2x0123456789ABCDEF, 4); - vacc0x3x0123456789ABCDEF = _mm512_srai_epi32(vacc0x3x0123456789ABCDEF, 4); - vacc0x4x0123456789ABCDEF = _mm512_srai_epi32(vacc0x4x0123456789ABCDEF, 4); - __m512 vscaled0x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x0x0123456789ABCDEF); - __m512 vscaled1x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x1x0123456789ABCDEF); - __m512 vscaled2x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x2x0123456789ABCDEF); - __m512 vscaled3x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x3x0123456789ABCDEF); - __m512 vscaled4x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x4x0123456789ABCDEF); + vacc0x0123456789ABCDEF = _mm512_srai_epi32(vacc0x0123456789ABCDEF, 4); + vacc1x0123456789ABCDEF = _mm512_srai_epi32(vacc1x0123456789ABCDEF, 4); + vacc2x0123456789ABCDEF = _mm512_srai_epi32(vacc2x0123456789ABCDEF, 4); + vacc3x0123456789ABCDEF = _mm512_srai_epi32(vacc3x0123456789ABCDEF, 4); + vacc4x0123456789ABCDEF = _mm512_srai_epi32(vacc4x0123456789ABCDEF, 4); + __m512 vscaled0x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x0123456789ABCDEF); + __m512 vscaled1x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc1x0123456789ABCDEF); + __m512 vscaled2x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc2x0123456789ABCDEF); + __m512 vscaled3x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc3x0123456789ABCDEF); + __m512 vscaled4x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc4x0123456789ABCDEF); vscaled0x0123456789ABCDEF = _mm512_mul_ps(vscaled0x0123456789ABCDEF, _mm512_set1_ps(quantization_params[0].inv_scale)); vscaled1x0123456789ABCDEF = _mm512_mul_ps(vscaled1x0123456789ABCDEF, _mm512_set1_ps(quantization_params[1].inv_scale)); diff --git a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-6x16c4-minmax-avx512vnni-prfm.c b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-6x16c4-minmax-avx512vnni-prfm.c index e94de7dff2b..41463381bb6 100644 --- a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-6x16c4-minmax-avx512vnni-prfm.c +++ b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-6x16c4-minmax-avx512vnni-prfm.c @@ -83,7 +83,7 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_6x16c4__avx512vnni_prfm( const __m512 voutput_min = _mm512_set1_ps(params->avx512vnni.min); const __m512 voutput_max = _mm512_set1_ps(params->avx512vnni.max); const __m512i vsign_mask = _mm512_set1_epi8(params->avx512vnni.sign_mask); // 0x80 - const __m512i vvalue_mask = _mm512_set1_epi8(params->avx512vnni.mask); // 0xF0 + const __m512i vmask = _mm512_set1_epi8(params->avx512vnni.mask); // 0xF0 do { const __m512i vksum0123456789ABCDEF = _mm512_load_epi32(w); __m512i vacc0x0x0123456789ABCDEF = _mm512_mullo_epi32(vksum0123456789ABCDEF, vinput_zero_point0); @@ -123,8 +123,8 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_6x16c4__avx512vnni_prfm( const __m512i vbb0123456789ABCDEFx01234567 = _mm512_load_si512(w); const __m512i vbs0123456789ABCDEFx0123 = _mm512_slli_epi32(vbb0123456789ABCDEFx01234567, 4); - const __m512i vb0123456789ABCDEFx4567 = _mm512_and_si512(vbb0123456789ABCDEFx01234567, vvalue_mask); - const __m512i vb0123456789ABCDEFx0123 = _mm512_and_si512(vbs0123456789ABCDEFx0123, vvalue_mask); + const __m512i vb0123456789ABCDEFx4567 = _mm512_and_si512(vbb0123456789ABCDEFx01234567, vmask); + const __m512i vb0123456789ABCDEFx0123 = _mm512_and_si512(vbs0123456789ABCDEFx0123, vmask); vacc0x0x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x0x0123456789ABCDEF, va0x0123, vb0123456789ABCDEFx0123); vacc0x1x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x1x0123456789ABCDEF, va1x0123, vb0123456789ABCDEFx0123); @@ -143,12 +143,12 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_6x16c4__avx512vnni_prfm( w = (const int8_t*) w + 64; k -= 8 * sizeof(int8_t); } - vacc0x0x0123456789ABCDEF = _mm512_add_epi32(vacc0x0x0123456789ABCDEF, vacc1x0x0123456789ABCDEF); - vacc0x1x0123456789ABCDEF = _mm512_add_epi32(vacc0x1x0123456789ABCDEF, vacc1x1x0123456789ABCDEF); - vacc0x2x0123456789ABCDEF = _mm512_add_epi32(vacc0x2x0123456789ABCDEF, vacc1x2x0123456789ABCDEF); - vacc0x3x0123456789ABCDEF = _mm512_add_epi32(vacc0x3x0123456789ABCDEF, vacc1x3x0123456789ABCDEF); - vacc0x4x0123456789ABCDEF = _mm512_add_epi32(vacc0x4x0123456789ABCDEF, vacc1x4x0123456789ABCDEF); - vacc0x5x0123456789ABCDEF = _mm512_add_epi32(vacc0x5x0123456789ABCDEF, vacc1x5x0123456789ABCDEF); + __m512i vacc0x0123456789ABCDEF = _mm512_add_epi32(vacc0x0x0123456789ABCDEF, vacc1x0x0123456789ABCDEF); + __m512i vacc1x0123456789ABCDEF = _mm512_add_epi32(vacc0x1x0123456789ABCDEF, vacc1x1x0123456789ABCDEF); + __m512i vacc2x0123456789ABCDEF = _mm512_add_epi32(vacc0x2x0123456789ABCDEF, vacc1x2x0123456789ABCDEF); + __m512i vacc3x0123456789ABCDEF = _mm512_add_epi32(vacc0x3x0123456789ABCDEF, vacc1x3x0123456789ABCDEF); + __m512i vacc4x0123456789ABCDEF = _mm512_add_epi32(vacc0x4x0123456789ABCDEF, vacc1x4x0123456789ABCDEF); + __m512i vacc5x0123456789ABCDEF = _mm512_add_epi32(vacc0x5x0123456789ABCDEF, vacc1x5x0123456789ABCDEF); if (k != 0) { const __m512i va0x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a0)), vsign_mask); @@ -166,32 +166,32 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_6x16c4__avx512vnni_prfm( const __m512i vbb0123456789ABCDEF = _mm512_load_si512(w); const __m512i vbs0123456789ABCDEF = _mm512_slli_epi32(vbb0123456789ABCDEF, 4); - const __m512i vb0123456789ABCDEF = _mm512_and_si512(vbs0123456789ABCDEF, vvalue_mask); + const __m512i vb0123456789ABCDEF = _mm512_and_si512(vbs0123456789ABCDEF, vmask); xnn_prefetch_to_l1((const int8_t*) w + 960); - vacc0x0x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x0x0123456789ABCDEF, va0x0123, vb0123456789ABCDEF); - vacc0x1x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x1x0123456789ABCDEF, va1x0123, vb0123456789ABCDEF); - vacc0x2x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x2x0123456789ABCDEF, va2x0123, vb0123456789ABCDEF); - vacc0x3x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x3x0123456789ABCDEF, va3x0123, vb0123456789ABCDEF); - vacc0x4x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x4x0123456789ABCDEF, va4x0123, vb0123456789ABCDEF); - vacc0x5x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x5x0123456789ABCDEF, va5x0123, vb0123456789ABCDEF); + vacc0x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x0123456789ABCDEF, va0x0123, vb0123456789ABCDEF); + vacc1x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc1x0123456789ABCDEF, va1x0123, vb0123456789ABCDEF); + vacc2x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc2x0123456789ABCDEF, va2x0123, vb0123456789ABCDEF); + vacc3x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc3x0123456789ABCDEF, va3x0123, vb0123456789ABCDEF); + vacc4x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc4x0123456789ABCDEF, va4x0123, vb0123456789ABCDEF); + vacc5x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc5x0123456789ABCDEF, va5x0123, vb0123456789ABCDEF); w = (const int8_t*) w + 64; k -= 4 * sizeof(int8_t); } - vacc0x0x0123456789ABCDEF = _mm512_srai_epi32(vacc0x0x0123456789ABCDEF, 4); - vacc0x1x0123456789ABCDEF = _mm512_srai_epi32(vacc0x1x0123456789ABCDEF, 4); - vacc0x2x0123456789ABCDEF = _mm512_srai_epi32(vacc0x2x0123456789ABCDEF, 4); - vacc0x3x0123456789ABCDEF = _mm512_srai_epi32(vacc0x3x0123456789ABCDEF, 4); - vacc0x4x0123456789ABCDEF = _mm512_srai_epi32(vacc0x4x0123456789ABCDEF, 4); - vacc0x5x0123456789ABCDEF = _mm512_srai_epi32(vacc0x5x0123456789ABCDEF, 4); - __m512 vscaled0x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x0x0123456789ABCDEF); - __m512 vscaled1x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x1x0123456789ABCDEF); - __m512 vscaled2x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x2x0123456789ABCDEF); - __m512 vscaled3x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x3x0123456789ABCDEF); - __m512 vscaled4x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x4x0123456789ABCDEF); - __m512 vscaled5x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x5x0123456789ABCDEF); + vacc0x0123456789ABCDEF = _mm512_srai_epi32(vacc0x0123456789ABCDEF, 4); + vacc1x0123456789ABCDEF = _mm512_srai_epi32(vacc1x0123456789ABCDEF, 4); + vacc2x0123456789ABCDEF = _mm512_srai_epi32(vacc2x0123456789ABCDEF, 4); + vacc3x0123456789ABCDEF = _mm512_srai_epi32(vacc3x0123456789ABCDEF, 4); + vacc4x0123456789ABCDEF = _mm512_srai_epi32(vacc4x0123456789ABCDEF, 4); + vacc5x0123456789ABCDEF = _mm512_srai_epi32(vacc5x0123456789ABCDEF, 4); + __m512 vscaled0x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x0123456789ABCDEF); + __m512 vscaled1x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc1x0123456789ABCDEF); + __m512 vscaled2x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc2x0123456789ABCDEF); + __m512 vscaled3x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc3x0123456789ABCDEF); + __m512 vscaled4x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc4x0123456789ABCDEF); + __m512 vscaled5x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc5x0123456789ABCDEF); vscaled0x0123456789ABCDEF = _mm512_mul_ps(vscaled0x0123456789ABCDEF, _mm512_set1_ps(quantization_params[0].inv_scale)); vscaled1x0123456789ABCDEF = _mm512_mul_ps(vscaled1x0123456789ABCDEF, _mm512_set1_ps(quantization_params[1].inv_scale)); diff --git a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-6x16c4-minmax-avx512vnni.c b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-6x16c4-minmax-avx512vnni.c index beab61828ad..9ae50c41087 100644 --- a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-6x16c4-minmax-avx512vnni.c +++ b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-6x16c4-minmax-avx512vnni.c @@ -82,7 +82,7 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_6x16c4__avx512vnni( const __m512 voutput_min = _mm512_set1_ps(params->avx512vnni.min); const __m512 voutput_max = _mm512_set1_ps(params->avx512vnni.max); const __m512i vsign_mask = _mm512_set1_epi8(params->avx512vnni.sign_mask); // 0x80 - const __m512i vvalue_mask = _mm512_set1_epi8(params->avx512vnni.mask); // 0xF0 + const __m512i vmask = _mm512_set1_epi8(params->avx512vnni.mask); // 0xF0 do { const __m512i vksum0123456789ABCDEF = _mm512_load_epi32(w); __m512i vacc0x0x0123456789ABCDEF = _mm512_mullo_epi32(vksum0123456789ABCDEF, vinput_zero_point0); @@ -122,8 +122,8 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_6x16c4__avx512vnni( const __m512i vbb0123456789ABCDEFx01234567 = _mm512_load_si512(w); const __m512i vbs0123456789ABCDEFx0123 = _mm512_slli_epi32(vbb0123456789ABCDEFx01234567, 4); - const __m512i vb0123456789ABCDEFx4567 = _mm512_and_si512(vbb0123456789ABCDEFx01234567, vvalue_mask); - const __m512i vb0123456789ABCDEFx0123 = _mm512_and_si512(vbs0123456789ABCDEFx0123, vvalue_mask); + const __m512i vb0123456789ABCDEFx4567 = _mm512_and_si512(vbb0123456789ABCDEFx01234567, vmask); + const __m512i vb0123456789ABCDEFx0123 = _mm512_and_si512(vbs0123456789ABCDEFx0123, vmask); vacc0x0x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x0x0123456789ABCDEF, va0x0123, vb0123456789ABCDEFx0123); vacc0x1x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x1x0123456789ABCDEF, va1x0123, vb0123456789ABCDEFx0123); @@ -141,12 +141,12 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_6x16c4__avx512vnni( w = (const int8_t*) w + 64; k -= 8 * sizeof(int8_t); } - vacc0x0x0123456789ABCDEF = _mm512_add_epi32(vacc0x0x0123456789ABCDEF, vacc1x0x0123456789ABCDEF); - vacc0x1x0123456789ABCDEF = _mm512_add_epi32(vacc0x1x0123456789ABCDEF, vacc1x1x0123456789ABCDEF); - vacc0x2x0123456789ABCDEF = _mm512_add_epi32(vacc0x2x0123456789ABCDEF, vacc1x2x0123456789ABCDEF); - vacc0x3x0123456789ABCDEF = _mm512_add_epi32(vacc0x3x0123456789ABCDEF, vacc1x3x0123456789ABCDEF); - vacc0x4x0123456789ABCDEF = _mm512_add_epi32(vacc0x4x0123456789ABCDEF, vacc1x4x0123456789ABCDEF); - vacc0x5x0123456789ABCDEF = _mm512_add_epi32(vacc0x5x0123456789ABCDEF, vacc1x5x0123456789ABCDEF); + __m512i vacc0x0123456789ABCDEF = _mm512_add_epi32(vacc0x0x0123456789ABCDEF, vacc1x0x0123456789ABCDEF); + __m512i vacc1x0123456789ABCDEF = _mm512_add_epi32(vacc0x1x0123456789ABCDEF, vacc1x1x0123456789ABCDEF); + __m512i vacc2x0123456789ABCDEF = _mm512_add_epi32(vacc0x2x0123456789ABCDEF, vacc1x2x0123456789ABCDEF); + __m512i vacc3x0123456789ABCDEF = _mm512_add_epi32(vacc0x3x0123456789ABCDEF, vacc1x3x0123456789ABCDEF); + __m512i vacc4x0123456789ABCDEF = _mm512_add_epi32(vacc0x4x0123456789ABCDEF, vacc1x4x0123456789ABCDEF); + __m512i vacc5x0123456789ABCDEF = _mm512_add_epi32(vacc0x5x0123456789ABCDEF, vacc1x5x0123456789ABCDEF); if (k != 0) { const __m512i va0x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a0)), vsign_mask); @@ -164,31 +164,31 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_6x16c4__avx512vnni( const __m512i vbb0123456789ABCDEF = _mm512_load_si512(w); const __m512i vbs0123456789ABCDEF = _mm512_slli_epi32(vbb0123456789ABCDEF, 4); - const __m512i vb0123456789ABCDEF = _mm512_and_si512(vbs0123456789ABCDEF, vvalue_mask); + const __m512i vb0123456789ABCDEF = _mm512_and_si512(vbs0123456789ABCDEF, vmask); - vacc0x0x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x0x0123456789ABCDEF, va0x0123, vb0123456789ABCDEF); - vacc0x1x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x1x0123456789ABCDEF, va1x0123, vb0123456789ABCDEF); - vacc0x2x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x2x0123456789ABCDEF, va2x0123, vb0123456789ABCDEF); - vacc0x3x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x3x0123456789ABCDEF, va3x0123, vb0123456789ABCDEF); - vacc0x4x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x4x0123456789ABCDEF, va4x0123, vb0123456789ABCDEF); - vacc0x5x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x5x0123456789ABCDEF, va5x0123, vb0123456789ABCDEF); + vacc0x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x0123456789ABCDEF, va0x0123, vb0123456789ABCDEF); + vacc1x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc1x0123456789ABCDEF, va1x0123, vb0123456789ABCDEF); + vacc2x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc2x0123456789ABCDEF, va2x0123, vb0123456789ABCDEF); + vacc3x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc3x0123456789ABCDEF, va3x0123, vb0123456789ABCDEF); + vacc4x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc4x0123456789ABCDEF, va4x0123, vb0123456789ABCDEF); + vacc5x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc5x0123456789ABCDEF, va5x0123, vb0123456789ABCDEF); w = (const int8_t*) w + 64; k -= 4 * sizeof(int8_t); } - vacc0x0x0123456789ABCDEF = _mm512_srai_epi32(vacc0x0x0123456789ABCDEF, 4); - vacc0x1x0123456789ABCDEF = _mm512_srai_epi32(vacc0x1x0123456789ABCDEF, 4); - vacc0x2x0123456789ABCDEF = _mm512_srai_epi32(vacc0x2x0123456789ABCDEF, 4); - vacc0x3x0123456789ABCDEF = _mm512_srai_epi32(vacc0x3x0123456789ABCDEF, 4); - vacc0x4x0123456789ABCDEF = _mm512_srai_epi32(vacc0x4x0123456789ABCDEF, 4); - vacc0x5x0123456789ABCDEF = _mm512_srai_epi32(vacc0x5x0123456789ABCDEF, 4); - __m512 vscaled0x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x0x0123456789ABCDEF); - __m512 vscaled1x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x1x0123456789ABCDEF); - __m512 vscaled2x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x2x0123456789ABCDEF); - __m512 vscaled3x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x3x0123456789ABCDEF); - __m512 vscaled4x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x4x0123456789ABCDEF); - __m512 vscaled5x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x5x0123456789ABCDEF); + vacc0x0123456789ABCDEF = _mm512_srai_epi32(vacc0x0123456789ABCDEF, 4); + vacc1x0123456789ABCDEF = _mm512_srai_epi32(vacc1x0123456789ABCDEF, 4); + vacc2x0123456789ABCDEF = _mm512_srai_epi32(vacc2x0123456789ABCDEF, 4); + vacc3x0123456789ABCDEF = _mm512_srai_epi32(vacc3x0123456789ABCDEF, 4); + vacc4x0123456789ABCDEF = _mm512_srai_epi32(vacc4x0123456789ABCDEF, 4); + vacc5x0123456789ABCDEF = _mm512_srai_epi32(vacc5x0123456789ABCDEF, 4); + __m512 vscaled0x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x0123456789ABCDEF); + __m512 vscaled1x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc1x0123456789ABCDEF); + __m512 vscaled2x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc2x0123456789ABCDEF); + __m512 vscaled3x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc3x0123456789ABCDEF); + __m512 vscaled4x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc4x0123456789ABCDEF); + __m512 vscaled5x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc5x0123456789ABCDEF); vscaled0x0123456789ABCDEF = _mm512_mul_ps(vscaled0x0123456789ABCDEF, _mm512_set1_ps(quantization_params[0].inv_scale)); vscaled1x0123456789ABCDEF = _mm512_mul_ps(vscaled1x0123456789ABCDEF, _mm512_set1_ps(quantization_params[1].inv_scale)); diff --git a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-7x16c4-minmax-avx512vnni-prfm.c b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-7x16c4-minmax-avx512vnni-prfm.c index 6fef33644d4..08bc535fde2 100644 --- a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-7x16c4-minmax-avx512vnni-prfm.c +++ b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-7x16c4-minmax-avx512vnni-prfm.c @@ -90,7 +90,7 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_7x16c4__avx512vnni_prfm( const __m512 voutput_min = _mm512_set1_ps(params->avx512vnni.min); const __m512 voutput_max = _mm512_set1_ps(params->avx512vnni.max); const __m512i vsign_mask = _mm512_set1_epi8(params->avx512vnni.sign_mask); // 0x80 - const __m512i vvalue_mask = _mm512_set1_epi8(params->avx512vnni.mask); // 0xF0 + const __m512i vmask = _mm512_set1_epi8(params->avx512vnni.mask); // 0xF0 do { const __m512i vksum0123456789ABCDEF = _mm512_load_epi32(w); __m512i vacc0x0x0123456789ABCDEF = _mm512_mullo_epi32(vksum0123456789ABCDEF, vinput_zero_point0); @@ -135,8 +135,8 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_7x16c4__avx512vnni_prfm( const __m512i vbb0123456789ABCDEFx01234567 = _mm512_load_si512(w); const __m512i vbs0123456789ABCDEFx0123 = _mm512_slli_epi32(vbb0123456789ABCDEFx01234567, 4); - const __m512i vb0123456789ABCDEFx4567 = _mm512_and_si512(vbb0123456789ABCDEFx01234567, vvalue_mask); - const __m512i vb0123456789ABCDEFx0123 = _mm512_and_si512(vbs0123456789ABCDEFx0123, vvalue_mask); + const __m512i vb0123456789ABCDEFx4567 = _mm512_and_si512(vbb0123456789ABCDEFx01234567, vmask); + const __m512i vb0123456789ABCDEFx0123 = _mm512_and_si512(vbs0123456789ABCDEFx0123, vmask); vacc0x0x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x0x0123456789ABCDEF, va0x0123, vb0123456789ABCDEFx0123); vacc0x1x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x1x0123456789ABCDEF, va1x0123, vb0123456789ABCDEFx0123); @@ -157,13 +157,13 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_7x16c4__avx512vnni_prfm( w = (const int8_t*) w + 64; k -= 8 * sizeof(int8_t); } - vacc0x0x0123456789ABCDEF = _mm512_add_epi32(vacc0x0x0123456789ABCDEF, vacc1x0x0123456789ABCDEF); - vacc0x1x0123456789ABCDEF = _mm512_add_epi32(vacc0x1x0123456789ABCDEF, vacc1x1x0123456789ABCDEF); - vacc0x2x0123456789ABCDEF = _mm512_add_epi32(vacc0x2x0123456789ABCDEF, vacc1x2x0123456789ABCDEF); - vacc0x3x0123456789ABCDEF = _mm512_add_epi32(vacc0x3x0123456789ABCDEF, vacc1x3x0123456789ABCDEF); - vacc0x4x0123456789ABCDEF = _mm512_add_epi32(vacc0x4x0123456789ABCDEF, vacc1x4x0123456789ABCDEF); - vacc0x5x0123456789ABCDEF = _mm512_add_epi32(vacc0x5x0123456789ABCDEF, vacc1x5x0123456789ABCDEF); - vacc0x6x0123456789ABCDEF = _mm512_add_epi32(vacc0x6x0123456789ABCDEF, vacc1x6x0123456789ABCDEF); + __m512i vacc0x0123456789ABCDEF = _mm512_add_epi32(vacc0x0x0123456789ABCDEF, vacc1x0x0123456789ABCDEF); + __m512i vacc1x0123456789ABCDEF = _mm512_add_epi32(vacc0x1x0123456789ABCDEF, vacc1x1x0123456789ABCDEF); + __m512i vacc2x0123456789ABCDEF = _mm512_add_epi32(vacc0x2x0123456789ABCDEF, vacc1x2x0123456789ABCDEF); + __m512i vacc3x0123456789ABCDEF = _mm512_add_epi32(vacc0x3x0123456789ABCDEF, vacc1x3x0123456789ABCDEF); + __m512i vacc4x0123456789ABCDEF = _mm512_add_epi32(vacc0x4x0123456789ABCDEF, vacc1x4x0123456789ABCDEF); + __m512i vacc5x0123456789ABCDEF = _mm512_add_epi32(vacc0x5x0123456789ABCDEF, vacc1x5x0123456789ABCDEF); + __m512i vacc6x0123456789ABCDEF = _mm512_add_epi32(vacc0x6x0123456789ABCDEF, vacc1x6x0123456789ABCDEF); if (k != 0) { const __m512i va0x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a0)), vsign_mask); @@ -183,35 +183,35 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_7x16c4__avx512vnni_prfm( const __m512i vbb0123456789ABCDEF = _mm512_load_si512(w); const __m512i vbs0123456789ABCDEF = _mm512_slli_epi32(vbb0123456789ABCDEF, 4); - const __m512i vb0123456789ABCDEF = _mm512_and_si512(vbs0123456789ABCDEF, vvalue_mask); + const __m512i vb0123456789ABCDEF = _mm512_and_si512(vbs0123456789ABCDEF, vmask); xnn_prefetch_to_l1((const int8_t*) w + 960); - vacc0x0x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x0x0123456789ABCDEF, va0x0123, vb0123456789ABCDEF); - vacc0x1x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x1x0123456789ABCDEF, va1x0123, vb0123456789ABCDEF); - vacc0x2x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x2x0123456789ABCDEF, va2x0123, vb0123456789ABCDEF); - vacc0x3x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x3x0123456789ABCDEF, va3x0123, vb0123456789ABCDEF); - vacc0x4x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x4x0123456789ABCDEF, va4x0123, vb0123456789ABCDEF); - vacc0x5x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x5x0123456789ABCDEF, va5x0123, vb0123456789ABCDEF); - vacc0x6x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x6x0123456789ABCDEF, va6x0123, vb0123456789ABCDEF); + vacc0x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x0123456789ABCDEF, va0x0123, vb0123456789ABCDEF); + vacc1x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc1x0123456789ABCDEF, va1x0123, vb0123456789ABCDEF); + vacc2x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc2x0123456789ABCDEF, va2x0123, vb0123456789ABCDEF); + vacc3x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc3x0123456789ABCDEF, va3x0123, vb0123456789ABCDEF); + vacc4x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc4x0123456789ABCDEF, va4x0123, vb0123456789ABCDEF); + vacc5x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc5x0123456789ABCDEF, va5x0123, vb0123456789ABCDEF); + vacc6x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc6x0123456789ABCDEF, va6x0123, vb0123456789ABCDEF); w = (const int8_t*) w + 64; k -= 4 * sizeof(int8_t); } - vacc0x0x0123456789ABCDEF = _mm512_srai_epi32(vacc0x0x0123456789ABCDEF, 4); - vacc0x1x0123456789ABCDEF = _mm512_srai_epi32(vacc0x1x0123456789ABCDEF, 4); - vacc0x2x0123456789ABCDEF = _mm512_srai_epi32(vacc0x2x0123456789ABCDEF, 4); - vacc0x3x0123456789ABCDEF = _mm512_srai_epi32(vacc0x3x0123456789ABCDEF, 4); - vacc0x4x0123456789ABCDEF = _mm512_srai_epi32(vacc0x4x0123456789ABCDEF, 4); - vacc0x5x0123456789ABCDEF = _mm512_srai_epi32(vacc0x5x0123456789ABCDEF, 4); - vacc0x6x0123456789ABCDEF = _mm512_srai_epi32(vacc0x6x0123456789ABCDEF, 4); - __m512 vscaled0x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x0x0123456789ABCDEF); - __m512 vscaled1x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x1x0123456789ABCDEF); - __m512 vscaled2x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x2x0123456789ABCDEF); - __m512 vscaled3x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x3x0123456789ABCDEF); - __m512 vscaled4x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x4x0123456789ABCDEF); - __m512 vscaled5x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x5x0123456789ABCDEF); - __m512 vscaled6x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x6x0123456789ABCDEF); + vacc0x0123456789ABCDEF = _mm512_srai_epi32(vacc0x0123456789ABCDEF, 4); + vacc1x0123456789ABCDEF = _mm512_srai_epi32(vacc1x0123456789ABCDEF, 4); + vacc2x0123456789ABCDEF = _mm512_srai_epi32(vacc2x0123456789ABCDEF, 4); + vacc3x0123456789ABCDEF = _mm512_srai_epi32(vacc3x0123456789ABCDEF, 4); + vacc4x0123456789ABCDEF = _mm512_srai_epi32(vacc4x0123456789ABCDEF, 4); + vacc5x0123456789ABCDEF = _mm512_srai_epi32(vacc5x0123456789ABCDEF, 4); + vacc6x0123456789ABCDEF = _mm512_srai_epi32(vacc6x0123456789ABCDEF, 4); + __m512 vscaled0x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x0123456789ABCDEF); + __m512 vscaled1x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc1x0123456789ABCDEF); + __m512 vscaled2x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc2x0123456789ABCDEF); + __m512 vscaled3x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc3x0123456789ABCDEF); + __m512 vscaled4x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc4x0123456789ABCDEF); + __m512 vscaled5x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc5x0123456789ABCDEF); + __m512 vscaled6x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc6x0123456789ABCDEF); vscaled0x0123456789ABCDEF = _mm512_mul_ps(vscaled0x0123456789ABCDEF, _mm512_set1_ps(quantization_params[0].inv_scale)); vscaled1x0123456789ABCDEF = _mm512_mul_ps(vscaled1x0123456789ABCDEF, _mm512_set1_ps(quantization_params[1].inv_scale)); diff --git a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-7x16c4-minmax-avx512vnni.c b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-7x16c4-minmax-avx512vnni.c index 66e7ad26797..eb998bbd75a 100644 --- a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-7x16c4-minmax-avx512vnni.c +++ b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-7x16c4-minmax-avx512vnni.c @@ -89,7 +89,7 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_7x16c4__avx512vnni( const __m512 voutput_min = _mm512_set1_ps(params->avx512vnni.min); const __m512 voutput_max = _mm512_set1_ps(params->avx512vnni.max); const __m512i vsign_mask = _mm512_set1_epi8(params->avx512vnni.sign_mask); // 0x80 - const __m512i vvalue_mask = _mm512_set1_epi8(params->avx512vnni.mask); // 0xF0 + const __m512i vmask = _mm512_set1_epi8(params->avx512vnni.mask); // 0xF0 do { const __m512i vksum0123456789ABCDEF = _mm512_load_epi32(w); __m512i vacc0x0x0123456789ABCDEF = _mm512_mullo_epi32(vksum0123456789ABCDEF, vinput_zero_point0); @@ -134,8 +134,8 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_7x16c4__avx512vnni( const __m512i vbb0123456789ABCDEFx01234567 = _mm512_load_si512(w); const __m512i vbs0123456789ABCDEFx0123 = _mm512_slli_epi32(vbb0123456789ABCDEFx01234567, 4); - const __m512i vb0123456789ABCDEFx4567 = _mm512_and_si512(vbb0123456789ABCDEFx01234567, vvalue_mask); - const __m512i vb0123456789ABCDEFx0123 = _mm512_and_si512(vbs0123456789ABCDEFx0123, vvalue_mask); + const __m512i vb0123456789ABCDEFx4567 = _mm512_and_si512(vbb0123456789ABCDEFx01234567, vmask); + const __m512i vb0123456789ABCDEFx0123 = _mm512_and_si512(vbs0123456789ABCDEFx0123, vmask); vacc0x0x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x0x0123456789ABCDEF, va0x0123, vb0123456789ABCDEFx0123); vacc0x1x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x1x0123456789ABCDEF, va1x0123, vb0123456789ABCDEFx0123); @@ -155,13 +155,13 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_7x16c4__avx512vnni( w = (const int8_t*) w + 64; k -= 8 * sizeof(int8_t); } - vacc0x0x0123456789ABCDEF = _mm512_add_epi32(vacc0x0x0123456789ABCDEF, vacc1x0x0123456789ABCDEF); - vacc0x1x0123456789ABCDEF = _mm512_add_epi32(vacc0x1x0123456789ABCDEF, vacc1x1x0123456789ABCDEF); - vacc0x2x0123456789ABCDEF = _mm512_add_epi32(vacc0x2x0123456789ABCDEF, vacc1x2x0123456789ABCDEF); - vacc0x3x0123456789ABCDEF = _mm512_add_epi32(vacc0x3x0123456789ABCDEF, vacc1x3x0123456789ABCDEF); - vacc0x4x0123456789ABCDEF = _mm512_add_epi32(vacc0x4x0123456789ABCDEF, vacc1x4x0123456789ABCDEF); - vacc0x5x0123456789ABCDEF = _mm512_add_epi32(vacc0x5x0123456789ABCDEF, vacc1x5x0123456789ABCDEF); - vacc0x6x0123456789ABCDEF = _mm512_add_epi32(vacc0x6x0123456789ABCDEF, vacc1x6x0123456789ABCDEF); + __m512i vacc0x0123456789ABCDEF = _mm512_add_epi32(vacc0x0x0123456789ABCDEF, vacc1x0x0123456789ABCDEF); + __m512i vacc1x0123456789ABCDEF = _mm512_add_epi32(vacc0x1x0123456789ABCDEF, vacc1x1x0123456789ABCDEF); + __m512i vacc2x0123456789ABCDEF = _mm512_add_epi32(vacc0x2x0123456789ABCDEF, vacc1x2x0123456789ABCDEF); + __m512i vacc3x0123456789ABCDEF = _mm512_add_epi32(vacc0x3x0123456789ABCDEF, vacc1x3x0123456789ABCDEF); + __m512i vacc4x0123456789ABCDEF = _mm512_add_epi32(vacc0x4x0123456789ABCDEF, vacc1x4x0123456789ABCDEF); + __m512i vacc5x0123456789ABCDEF = _mm512_add_epi32(vacc0x5x0123456789ABCDEF, vacc1x5x0123456789ABCDEF); + __m512i vacc6x0123456789ABCDEF = _mm512_add_epi32(vacc0x6x0123456789ABCDEF, vacc1x6x0123456789ABCDEF); if (k != 0) { const __m512i va0x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a0)), vsign_mask); @@ -181,34 +181,34 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_7x16c4__avx512vnni( const __m512i vbb0123456789ABCDEF = _mm512_load_si512(w); const __m512i vbs0123456789ABCDEF = _mm512_slli_epi32(vbb0123456789ABCDEF, 4); - const __m512i vb0123456789ABCDEF = _mm512_and_si512(vbs0123456789ABCDEF, vvalue_mask); + const __m512i vb0123456789ABCDEF = _mm512_and_si512(vbs0123456789ABCDEF, vmask); - vacc0x0x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x0x0123456789ABCDEF, va0x0123, vb0123456789ABCDEF); - vacc0x1x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x1x0123456789ABCDEF, va1x0123, vb0123456789ABCDEF); - vacc0x2x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x2x0123456789ABCDEF, va2x0123, vb0123456789ABCDEF); - vacc0x3x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x3x0123456789ABCDEF, va3x0123, vb0123456789ABCDEF); - vacc0x4x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x4x0123456789ABCDEF, va4x0123, vb0123456789ABCDEF); - vacc0x5x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x5x0123456789ABCDEF, va5x0123, vb0123456789ABCDEF); - vacc0x6x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x6x0123456789ABCDEF, va6x0123, vb0123456789ABCDEF); + vacc0x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x0123456789ABCDEF, va0x0123, vb0123456789ABCDEF); + vacc1x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc1x0123456789ABCDEF, va1x0123, vb0123456789ABCDEF); + vacc2x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc2x0123456789ABCDEF, va2x0123, vb0123456789ABCDEF); + vacc3x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc3x0123456789ABCDEF, va3x0123, vb0123456789ABCDEF); + vacc4x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc4x0123456789ABCDEF, va4x0123, vb0123456789ABCDEF); + vacc5x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc5x0123456789ABCDEF, va5x0123, vb0123456789ABCDEF); + vacc6x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc6x0123456789ABCDEF, va6x0123, vb0123456789ABCDEF); w = (const int8_t*) w + 64; k -= 4 * sizeof(int8_t); } - vacc0x0x0123456789ABCDEF = _mm512_srai_epi32(vacc0x0x0123456789ABCDEF, 4); - vacc0x1x0123456789ABCDEF = _mm512_srai_epi32(vacc0x1x0123456789ABCDEF, 4); - vacc0x2x0123456789ABCDEF = _mm512_srai_epi32(vacc0x2x0123456789ABCDEF, 4); - vacc0x3x0123456789ABCDEF = _mm512_srai_epi32(vacc0x3x0123456789ABCDEF, 4); - vacc0x4x0123456789ABCDEF = _mm512_srai_epi32(vacc0x4x0123456789ABCDEF, 4); - vacc0x5x0123456789ABCDEF = _mm512_srai_epi32(vacc0x5x0123456789ABCDEF, 4); - vacc0x6x0123456789ABCDEF = _mm512_srai_epi32(vacc0x6x0123456789ABCDEF, 4); - __m512 vscaled0x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x0x0123456789ABCDEF); - __m512 vscaled1x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x1x0123456789ABCDEF); - __m512 vscaled2x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x2x0123456789ABCDEF); - __m512 vscaled3x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x3x0123456789ABCDEF); - __m512 vscaled4x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x4x0123456789ABCDEF); - __m512 vscaled5x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x5x0123456789ABCDEF); - __m512 vscaled6x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x6x0123456789ABCDEF); + vacc0x0123456789ABCDEF = _mm512_srai_epi32(vacc0x0123456789ABCDEF, 4); + vacc1x0123456789ABCDEF = _mm512_srai_epi32(vacc1x0123456789ABCDEF, 4); + vacc2x0123456789ABCDEF = _mm512_srai_epi32(vacc2x0123456789ABCDEF, 4); + vacc3x0123456789ABCDEF = _mm512_srai_epi32(vacc3x0123456789ABCDEF, 4); + vacc4x0123456789ABCDEF = _mm512_srai_epi32(vacc4x0123456789ABCDEF, 4); + vacc5x0123456789ABCDEF = _mm512_srai_epi32(vacc5x0123456789ABCDEF, 4); + vacc6x0123456789ABCDEF = _mm512_srai_epi32(vacc6x0123456789ABCDEF, 4); + __m512 vscaled0x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x0123456789ABCDEF); + __m512 vscaled1x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc1x0123456789ABCDEF); + __m512 vscaled2x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc2x0123456789ABCDEF); + __m512 vscaled3x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc3x0123456789ABCDEF); + __m512 vscaled4x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc4x0123456789ABCDEF); + __m512 vscaled5x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc5x0123456789ABCDEF); + __m512 vscaled6x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc6x0123456789ABCDEF); vscaled0x0123456789ABCDEF = _mm512_mul_ps(vscaled0x0123456789ABCDEF, _mm512_set1_ps(quantization_params[0].inv_scale)); vscaled1x0123456789ABCDEF = _mm512_mul_ps(vscaled1x0123456789ABCDEF, _mm512_set1_ps(quantization_params[1].inv_scale)); diff --git a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-8x16c4-minmax-avx512vnni-prfm.c b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-8x16c4-minmax-avx512vnni-prfm.c index 679d851b2e7..5528667dad8 100644 --- a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-8x16c4-minmax-avx512vnni-prfm.c +++ b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-8x16c4-minmax-avx512vnni-prfm.c @@ -97,7 +97,7 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_8x16c4__avx512vnni_prfm( const __m512 voutput_min = _mm512_set1_ps(params->avx512vnni.min); const __m512 voutput_max = _mm512_set1_ps(params->avx512vnni.max); const __m512i vsign_mask = _mm512_set1_epi8(params->avx512vnni.sign_mask); // 0x80 - const __m512i vvalue_mask = _mm512_set1_epi8(params->avx512vnni.mask); // 0xF0 + const __m512i vmask = _mm512_set1_epi8(params->avx512vnni.mask); // 0xF0 do { const __m512i vksum0123456789ABCDEF = _mm512_load_epi32(w); __m512i vacc0x0x0123456789ABCDEF = _mm512_mullo_epi32(vksum0123456789ABCDEF, vinput_zero_point0); @@ -147,8 +147,8 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_8x16c4__avx512vnni_prfm( const __m512i vbb0123456789ABCDEFx01234567 = _mm512_load_si512(w); const __m512i vbs0123456789ABCDEFx0123 = _mm512_slli_epi32(vbb0123456789ABCDEFx01234567, 4); - const __m512i vb0123456789ABCDEFx4567 = _mm512_and_si512(vbb0123456789ABCDEFx01234567, vvalue_mask); - const __m512i vb0123456789ABCDEFx0123 = _mm512_and_si512(vbs0123456789ABCDEFx0123, vvalue_mask); + const __m512i vb0123456789ABCDEFx4567 = _mm512_and_si512(vbb0123456789ABCDEFx01234567, vmask); + const __m512i vb0123456789ABCDEFx0123 = _mm512_and_si512(vbs0123456789ABCDEFx0123, vmask); vacc0x0x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x0x0123456789ABCDEF, va0x0123, vb0123456789ABCDEFx0123); vacc0x1x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x1x0123456789ABCDEF, va1x0123, vb0123456789ABCDEFx0123); @@ -171,14 +171,14 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_8x16c4__avx512vnni_prfm( w = (const int8_t*) w + 64; k -= 8 * sizeof(int8_t); } - vacc0x0x0123456789ABCDEF = _mm512_add_epi32(vacc0x0x0123456789ABCDEF, vacc1x0x0123456789ABCDEF); - vacc0x1x0123456789ABCDEF = _mm512_add_epi32(vacc0x1x0123456789ABCDEF, vacc1x1x0123456789ABCDEF); - vacc0x2x0123456789ABCDEF = _mm512_add_epi32(vacc0x2x0123456789ABCDEF, vacc1x2x0123456789ABCDEF); - vacc0x3x0123456789ABCDEF = _mm512_add_epi32(vacc0x3x0123456789ABCDEF, vacc1x3x0123456789ABCDEF); - vacc0x4x0123456789ABCDEF = _mm512_add_epi32(vacc0x4x0123456789ABCDEF, vacc1x4x0123456789ABCDEF); - vacc0x5x0123456789ABCDEF = _mm512_add_epi32(vacc0x5x0123456789ABCDEF, vacc1x5x0123456789ABCDEF); - vacc0x6x0123456789ABCDEF = _mm512_add_epi32(vacc0x6x0123456789ABCDEF, vacc1x6x0123456789ABCDEF); - vacc0x7x0123456789ABCDEF = _mm512_add_epi32(vacc0x7x0123456789ABCDEF, vacc1x7x0123456789ABCDEF); + __m512i vacc0x0123456789ABCDEF = _mm512_add_epi32(vacc0x0x0123456789ABCDEF, vacc1x0x0123456789ABCDEF); + __m512i vacc1x0123456789ABCDEF = _mm512_add_epi32(vacc0x1x0123456789ABCDEF, vacc1x1x0123456789ABCDEF); + __m512i vacc2x0123456789ABCDEF = _mm512_add_epi32(vacc0x2x0123456789ABCDEF, vacc1x2x0123456789ABCDEF); + __m512i vacc3x0123456789ABCDEF = _mm512_add_epi32(vacc0x3x0123456789ABCDEF, vacc1x3x0123456789ABCDEF); + __m512i vacc4x0123456789ABCDEF = _mm512_add_epi32(vacc0x4x0123456789ABCDEF, vacc1x4x0123456789ABCDEF); + __m512i vacc5x0123456789ABCDEF = _mm512_add_epi32(vacc0x5x0123456789ABCDEF, vacc1x5x0123456789ABCDEF); + __m512i vacc6x0123456789ABCDEF = _mm512_add_epi32(vacc0x6x0123456789ABCDEF, vacc1x6x0123456789ABCDEF); + __m512i vacc7x0123456789ABCDEF = _mm512_add_epi32(vacc0x7x0123456789ABCDEF, vacc1x7x0123456789ABCDEF); if (k != 0) { const __m512i va0x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a0)), vsign_mask); @@ -200,38 +200,38 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_8x16c4__avx512vnni_prfm( const __m512i vbb0123456789ABCDEF = _mm512_load_si512(w); const __m512i vbs0123456789ABCDEF = _mm512_slli_epi32(vbb0123456789ABCDEF, 4); - const __m512i vb0123456789ABCDEF = _mm512_and_si512(vbs0123456789ABCDEF, vvalue_mask); + const __m512i vb0123456789ABCDEF = _mm512_and_si512(vbs0123456789ABCDEF, vmask); xnn_prefetch_to_l1((const int8_t*) w + 960); - vacc0x0x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x0x0123456789ABCDEF, va0x0123, vb0123456789ABCDEF); - vacc0x1x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x1x0123456789ABCDEF, va1x0123, vb0123456789ABCDEF); - vacc0x2x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x2x0123456789ABCDEF, va2x0123, vb0123456789ABCDEF); - vacc0x3x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x3x0123456789ABCDEF, va3x0123, vb0123456789ABCDEF); - vacc0x4x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x4x0123456789ABCDEF, va4x0123, vb0123456789ABCDEF); - vacc0x5x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x5x0123456789ABCDEF, va5x0123, vb0123456789ABCDEF); - vacc0x6x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x6x0123456789ABCDEF, va6x0123, vb0123456789ABCDEF); - vacc0x7x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x7x0123456789ABCDEF, va7x0123, vb0123456789ABCDEF); + vacc0x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x0123456789ABCDEF, va0x0123, vb0123456789ABCDEF); + vacc1x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc1x0123456789ABCDEF, va1x0123, vb0123456789ABCDEF); + vacc2x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc2x0123456789ABCDEF, va2x0123, vb0123456789ABCDEF); + vacc3x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc3x0123456789ABCDEF, va3x0123, vb0123456789ABCDEF); + vacc4x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc4x0123456789ABCDEF, va4x0123, vb0123456789ABCDEF); + vacc5x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc5x0123456789ABCDEF, va5x0123, vb0123456789ABCDEF); + vacc6x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc6x0123456789ABCDEF, va6x0123, vb0123456789ABCDEF); + vacc7x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc7x0123456789ABCDEF, va7x0123, vb0123456789ABCDEF); w = (const int8_t*) w + 64; k -= 4 * sizeof(int8_t); } - vacc0x0x0123456789ABCDEF = _mm512_srai_epi32(vacc0x0x0123456789ABCDEF, 4); - vacc0x1x0123456789ABCDEF = _mm512_srai_epi32(vacc0x1x0123456789ABCDEF, 4); - vacc0x2x0123456789ABCDEF = _mm512_srai_epi32(vacc0x2x0123456789ABCDEF, 4); - vacc0x3x0123456789ABCDEF = _mm512_srai_epi32(vacc0x3x0123456789ABCDEF, 4); - vacc0x4x0123456789ABCDEF = _mm512_srai_epi32(vacc0x4x0123456789ABCDEF, 4); - vacc0x5x0123456789ABCDEF = _mm512_srai_epi32(vacc0x5x0123456789ABCDEF, 4); - vacc0x6x0123456789ABCDEF = _mm512_srai_epi32(vacc0x6x0123456789ABCDEF, 4); - vacc0x7x0123456789ABCDEF = _mm512_srai_epi32(vacc0x7x0123456789ABCDEF, 4); - __m512 vscaled0x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x0x0123456789ABCDEF); - __m512 vscaled1x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x1x0123456789ABCDEF); - __m512 vscaled2x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x2x0123456789ABCDEF); - __m512 vscaled3x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x3x0123456789ABCDEF); - __m512 vscaled4x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x4x0123456789ABCDEF); - __m512 vscaled5x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x5x0123456789ABCDEF); - __m512 vscaled6x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x6x0123456789ABCDEF); - __m512 vscaled7x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x7x0123456789ABCDEF); + vacc0x0123456789ABCDEF = _mm512_srai_epi32(vacc0x0123456789ABCDEF, 4); + vacc1x0123456789ABCDEF = _mm512_srai_epi32(vacc1x0123456789ABCDEF, 4); + vacc2x0123456789ABCDEF = _mm512_srai_epi32(vacc2x0123456789ABCDEF, 4); + vacc3x0123456789ABCDEF = _mm512_srai_epi32(vacc3x0123456789ABCDEF, 4); + vacc4x0123456789ABCDEF = _mm512_srai_epi32(vacc4x0123456789ABCDEF, 4); + vacc5x0123456789ABCDEF = _mm512_srai_epi32(vacc5x0123456789ABCDEF, 4); + vacc6x0123456789ABCDEF = _mm512_srai_epi32(vacc6x0123456789ABCDEF, 4); + vacc7x0123456789ABCDEF = _mm512_srai_epi32(vacc7x0123456789ABCDEF, 4); + __m512 vscaled0x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x0123456789ABCDEF); + __m512 vscaled1x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc1x0123456789ABCDEF); + __m512 vscaled2x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc2x0123456789ABCDEF); + __m512 vscaled3x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc3x0123456789ABCDEF); + __m512 vscaled4x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc4x0123456789ABCDEF); + __m512 vscaled5x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc5x0123456789ABCDEF); + __m512 vscaled6x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc6x0123456789ABCDEF); + __m512 vscaled7x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc7x0123456789ABCDEF); vscaled0x0123456789ABCDEF = _mm512_mul_ps(vscaled0x0123456789ABCDEF, _mm512_set1_ps(quantization_params[0].inv_scale)); vscaled1x0123456789ABCDEF = _mm512_mul_ps(vscaled1x0123456789ABCDEF, _mm512_set1_ps(quantization_params[1].inv_scale)); diff --git a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-8x16c4-minmax-avx512vnni.c b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-8x16c4-minmax-avx512vnni.c index 4aad334b26b..a7d5d4857f3 100644 --- a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-8x16c4-minmax-avx512vnni.c +++ b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-8x16c4-minmax-avx512vnni.c @@ -96,7 +96,7 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_8x16c4__avx512vnni( const __m512 voutput_min = _mm512_set1_ps(params->avx512vnni.min); const __m512 voutput_max = _mm512_set1_ps(params->avx512vnni.max); const __m512i vsign_mask = _mm512_set1_epi8(params->avx512vnni.sign_mask); // 0x80 - const __m512i vvalue_mask = _mm512_set1_epi8(params->avx512vnni.mask); // 0xF0 + const __m512i vmask = _mm512_set1_epi8(params->avx512vnni.mask); // 0xF0 do { const __m512i vksum0123456789ABCDEF = _mm512_load_epi32(w); __m512i vacc0x0x0123456789ABCDEF = _mm512_mullo_epi32(vksum0123456789ABCDEF, vinput_zero_point0); @@ -146,8 +146,8 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_8x16c4__avx512vnni( const __m512i vbb0123456789ABCDEFx01234567 = _mm512_load_si512(w); const __m512i vbs0123456789ABCDEFx0123 = _mm512_slli_epi32(vbb0123456789ABCDEFx01234567, 4); - const __m512i vb0123456789ABCDEFx4567 = _mm512_and_si512(vbb0123456789ABCDEFx01234567, vvalue_mask); - const __m512i vb0123456789ABCDEFx0123 = _mm512_and_si512(vbs0123456789ABCDEFx0123, vvalue_mask); + const __m512i vb0123456789ABCDEFx4567 = _mm512_and_si512(vbb0123456789ABCDEFx01234567, vmask); + const __m512i vb0123456789ABCDEFx0123 = _mm512_and_si512(vbs0123456789ABCDEFx0123, vmask); vacc0x0x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x0x0123456789ABCDEF, va0x0123, vb0123456789ABCDEFx0123); vacc0x1x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x1x0123456789ABCDEF, va1x0123, vb0123456789ABCDEFx0123); @@ -169,14 +169,14 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_8x16c4__avx512vnni( w = (const int8_t*) w + 64; k -= 8 * sizeof(int8_t); } - vacc0x0x0123456789ABCDEF = _mm512_add_epi32(vacc0x0x0123456789ABCDEF, vacc1x0x0123456789ABCDEF); - vacc0x1x0123456789ABCDEF = _mm512_add_epi32(vacc0x1x0123456789ABCDEF, vacc1x1x0123456789ABCDEF); - vacc0x2x0123456789ABCDEF = _mm512_add_epi32(vacc0x2x0123456789ABCDEF, vacc1x2x0123456789ABCDEF); - vacc0x3x0123456789ABCDEF = _mm512_add_epi32(vacc0x3x0123456789ABCDEF, vacc1x3x0123456789ABCDEF); - vacc0x4x0123456789ABCDEF = _mm512_add_epi32(vacc0x4x0123456789ABCDEF, vacc1x4x0123456789ABCDEF); - vacc0x5x0123456789ABCDEF = _mm512_add_epi32(vacc0x5x0123456789ABCDEF, vacc1x5x0123456789ABCDEF); - vacc0x6x0123456789ABCDEF = _mm512_add_epi32(vacc0x6x0123456789ABCDEF, vacc1x6x0123456789ABCDEF); - vacc0x7x0123456789ABCDEF = _mm512_add_epi32(vacc0x7x0123456789ABCDEF, vacc1x7x0123456789ABCDEF); + __m512i vacc0x0123456789ABCDEF = _mm512_add_epi32(vacc0x0x0123456789ABCDEF, vacc1x0x0123456789ABCDEF); + __m512i vacc1x0123456789ABCDEF = _mm512_add_epi32(vacc0x1x0123456789ABCDEF, vacc1x1x0123456789ABCDEF); + __m512i vacc2x0123456789ABCDEF = _mm512_add_epi32(vacc0x2x0123456789ABCDEF, vacc1x2x0123456789ABCDEF); + __m512i vacc3x0123456789ABCDEF = _mm512_add_epi32(vacc0x3x0123456789ABCDEF, vacc1x3x0123456789ABCDEF); + __m512i vacc4x0123456789ABCDEF = _mm512_add_epi32(vacc0x4x0123456789ABCDEF, vacc1x4x0123456789ABCDEF); + __m512i vacc5x0123456789ABCDEF = _mm512_add_epi32(vacc0x5x0123456789ABCDEF, vacc1x5x0123456789ABCDEF); + __m512i vacc6x0123456789ABCDEF = _mm512_add_epi32(vacc0x6x0123456789ABCDEF, vacc1x6x0123456789ABCDEF); + __m512i vacc7x0123456789ABCDEF = _mm512_add_epi32(vacc0x7x0123456789ABCDEF, vacc1x7x0123456789ABCDEF); if (k != 0) { const __m512i va0x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a0)), vsign_mask); @@ -198,37 +198,37 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_8x16c4__avx512vnni( const __m512i vbb0123456789ABCDEF = _mm512_load_si512(w); const __m512i vbs0123456789ABCDEF = _mm512_slli_epi32(vbb0123456789ABCDEF, 4); - const __m512i vb0123456789ABCDEF = _mm512_and_si512(vbs0123456789ABCDEF, vvalue_mask); + const __m512i vb0123456789ABCDEF = _mm512_and_si512(vbs0123456789ABCDEF, vmask); - vacc0x0x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x0x0123456789ABCDEF, va0x0123, vb0123456789ABCDEF); - vacc0x1x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x1x0123456789ABCDEF, va1x0123, vb0123456789ABCDEF); - vacc0x2x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x2x0123456789ABCDEF, va2x0123, vb0123456789ABCDEF); - vacc0x3x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x3x0123456789ABCDEF, va3x0123, vb0123456789ABCDEF); - vacc0x4x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x4x0123456789ABCDEF, va4x0123, vb0123456789ABCDEF); - vacc0x5x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x5x0123456789ABCDEF, va5x0123, vb0123456789ABCDEF); - vacc0x6x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x6x0123456789ABCDEF, va6x0123, vb0123456789ABCDEF); - vacc0x7x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x7x0123456789ABCDEF, va7x0123, vb0123456789ABCDEF); + vacc0x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x0123456789ABCDEF, va0x0123, vb0123456789ABCDEF); + vacc1x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc1x0123456789ABCDEF, va1x0123, vb0123456789ABCDEF); + vacc2x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc2x0123456789ABCDEF, va2x0123, vb0123456789ABCDEF); + vacc3x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc3x0123456789ABCDEF, va3x0123, vb0123456789ABCDEF); + vacc4x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc4x0123456789ABCDEF, va4x0123, vb0123456789ABCDEF); + vacc5x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc5x0123456789ABCDEF, va5x0123, vb0123456789ABCDEF); + vacc6x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc6x0123456789ABCDEF, va6x0123, vb0123456789ABCDEF); + vacc7x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc7x0123456789ABCDEF, va7x0123, vb0123456789ABCDEF); w = (const int8_t*) w + 64; k -= 4 * sizeof(int8_t); } - vacc0x0x0123456789ABCDEF = _mm512_srai_epi32(vacc0x0x0123456789ABCDEF, 4); - vacc0x1x0123456789ABCDEF = _mm512_srai_epi32(vacc0x1x0123456789ABCDEF, 4); - vacc0x2x0123456789ABCDEF = _mm512_srai_epi32(vacc0x2x0123456789ABCDEF, 4); - vacc0x3x0123456789ABCDEF = _mm512_srai_epi32(vacc0x3x0123456789ABCDEF, 4); - vacc0x4x0123456789ABCDEF = _mm512_srai_epi32(vacc0x4x0123456789ABCDEF, 4); - vacc0x5x0123456789ABCDEF = _mm512_srai_epi32(vacc0x5x0123456789ABCDEF, 4); - vacc0x6x0123456789ABCDEF = _mm512_srai_epi32(vacc0x6x0123456789ABCDEF, 4); - vacc0x7x0123456789ABCDEF = _mm512_srai_epi32(vacc0x7x0123456789ABCDEF, 4); - __m512 vscaled0x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x0x0123456789ABCDEF); - __m512 vscaled1x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x1x0123456789ABCDEF); - __m512 vscaled2x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x2x0123456789ABCDEF); - __m512 vscaled3x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x3x0123456789ABCDEF); - __m512 vscaled4x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x4x0123456789ABCDEF); - __m512 vscaled5x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x5x0123456789ABCDEF); - __m512 vscaled6x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x6x0123456789ABCDEF); - __m512 vscaled7x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x7x0123456789ABCDEF); + vacc0x0123456789ABCDEF = _mm512_srai_epi32(vacc0x0123456789ABCDEF, 4); + vacc1x0123456789ABCDEF = _mm512_srai_epi32(vacc1x0123456789ABCDEF, 4); + vacc2x0123456789ABCDEF = _mm512_srai_epi32(vacc2x0123456789ABCDEF, 4); + vacc3x0123456789ABCDEF = _mm512_srai_epi32(vacc3x0123456789ABCDEF, 4); + vacc4x0123456789ABCDEF = _mm512_srai_epi32(vacc4x0123456789ABCDEF, 4); + vacc5x0123456789ABCDEF = _mm512_srai_epi32(vacc5x0123456789ABCDEF, 4); + vacc6x0123456789ABCDEF = _mm512_srai_epi32(vacc6x0123456789ABCDEF, 4); + vacc7x0123456789ABCDEF = _mm512_srai_epi32(vacc7x0123456789ABCDEF, 4); + __m512 vscaled0x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x0123456789ABCDEF); + __m512 vscaled1x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc1x0123456789ABCDEF); + __m512 vscaled2x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc2x0123456789ABCDEF); + __m512 vscaled3x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc3x0123456789ABCDEF); + __m512 vscaled4x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc4x0123456789ABCDEF); + __m512 vscaled5x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc5x0123456789ABCDEF); + __m512 vscaled6x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc6x0123456789ABCDEF); + __m512 vscaled7x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc7x0123456789ABCDEF); vscaled0x0123456789ABCDEF = _mm512_mul_ps(vscaled0x0123456789ABCDEF, _mm512_set1_ps(quantization_params[0].inv_scale)); vscaled1x0123456789ABCDEF = _mm512_mul_ps(vscaled1x0123456789ABCDEF, _mm512_set1_ps(quantization_params[1].inv_scale)); diff --git a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-1x16c4-minmax-avx512vnni-prfm.c b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-1x16c4-minmax-avx512vnni-prfm.c index 92d589f547c..268ed16ddbf 100644 --- a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-1x16c4-minmax-avx512vnni-prfm.c +++ b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-1x16c4-minmax-avx512vnni-prfm.c @@ -71,7 +71,7 @@ void xnn_qd8_f32_qc8w_gemm_minmax_ukernel_1x16c4__avx512vnni_prfm( w = (const int8_t*) w + 128; k -= 8 * sizeof(int8_t); } - vacc0x0x0123456789ABCDEF = _mm512_add_epi32(vacc0x0x0123456789ABCDEF, vacc1x0x0123456789ABCDEF); + __m512i vacc0x0123456789ABCDEF = _mm512_add_epi32(vacc0x0x0123456789ABCDEF, vacc1x0x0123456789ABCDEF); if (k != 0) { const __m512i va0x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a0)), vsign_mask); @@ -80,13 +80,13 @@ void xnn_qd8_f32_qc8w_gemm_minmax_ukernel_1x16c4__avx512vnni_prfm( const __m512i vb0123456789ABCDEF = _mm512_load_si512(w); xnn_prefetch_to_l1((const int8_t*) w + 960); - vacc0x0x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x0x0123456789ABCDEF, va0x0123, vb0123456789ABCDEF); + vacc0x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x0123456789ABCDEF, va0x0123, vb0123456789ABCDEF); w = (const int8_t*) w + 64; k -= 4 * sizeof(int8_t); } - __m512 vscaled0x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x0x0123456789ABCDEF); + __m512 vscaled0x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x0123456789ABCDEF); vscaled0x0123456789ABCDEF = _mm512_mul_ps(vscaled0x0123456789ABCDEF, _mm512_set1_ps(quantization_params[0].inv_scale)); diff --git a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-1x16c4-minmax-avx512vnni.c b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-1x16c4-minmax-avx512vnni.c index baf95a264ee..094c28a6b75 100644 --- a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-1x16c4-minmax-avx512vnni.c +++ b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-1x16c4-minmax-avx512vnni.c @@ -68,7 +68,7 @@ void xnn_qd8_f32_qc8w_gemm_minmax_ukernel_1x16c4__avx512vnni( w = (const int8_t*) w + 128; k -= 8 * sizeof(int8_t); } - vacc0x0x0123456789ABCDEF = _mm512_add_epi32(vacc0x0x0123456789ABCDEF, vacc1x0x0123456789ABCDEF); + __m512i vacc0x0123456789ABCDEF = _mm512_add_epi32(vacc0x0x0123456789ABCDEF, vacc1x0x0123456789ABCDEF); if (k != 0) { const __m512i va0x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a0)), vsign_mask); @@ -76,13 +76,13 @@ void xnn_qd8_f32_qc8w_gemm_minmax_ukernel_1x16c4__avx512vnni( const __m512i vb0123456789ABCDEF = _mm512_load_si512(w); - vacc0x0x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x0x0123456789ABCDEF, va0x0123, vb0123456789ABCDEF); + vacc0x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x0123456789ABCDEF, va0x0123, vb0123456789ABCDEF); w = (const int8_t*) w + 64; k -= 4 * sizeof(int8_t); } - __m512 vscaled0x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x0x0123456789ABCDEF); + __m512 vscaled0x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x0123456789ABCDEF); vscaled0x0123456789ABCDEF = _mm512_mul_ps(vscaled0x0123456789ABCDEF, _mm512_set1_ps(quantization_params[0].inv_scale)); diff --git a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-2x16c4-minmax-avx512vnni-prfm.c b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-2x16c4-minmax-avx512vnni-prfm.c index 267dec6cd0b..b889e622c5e 100644 --- a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-2x16c4-minmax-avx512vnni-prfm.c +++ b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-2x16c4-minmax-avx512vnni-prfm.c @@ -85,8 +85,8 @@ void xnn_qd8_f32_qc8w_gemm_minmax_ukernel_2x16c4__avx512vnni_prfm( w = (const int8_t*) w + 128; k -= 8 * sizeof(int8_t); } - vacc0x0x0123456789ABCDEF = _mm512_add_epi32(vacc0x0x0123456789ABCDEF, vacc1x0x0123456789ABCDEF); - vacc0x1x0123456789ABCDEF = _mm512_add_epi32(vacc0x1x0123456789ABCDEF, vacc1x1x0123456789ABCDEF); + __m512i vacc0x0123456789ABCDEF = _mm512_add_epi32(vacc0x0x0123456789ABCDEF, vacc1x0x0123456789ABCDEF); + __m512i vacc1x0123456789ABCDEF = _mm512_add_epi32(vacc0x1x0123456789ABCDEF, vacc1x1x0123456789ABCDEF); if (k != 0) { const __m512i va0x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a0)), vsign_mask); @@ -97,15 +97,15 @@ void xnn_qd8_f32_qc8w_gemm_minmax_ukernel_2x16c4__avx512vnni_prfm( const __m512i vb0123456789ABCDEF = _mm512_load_si512(w); xnn_prefetch_to_l1((const int8_t*) w + 960); - vacc0x0x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x0x0123456789ABCDEF, va0x0123, vb0123456789ABCDEF); - vacc0x1x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x1x0123456789ABCDEF, va1x0123, vb0123456789ABCDEF); + vacc0x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x0123456789ABCDEF, va0x0123, vb0123456789ABCDEF); + vacc1x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc1x0123456789ABCDEF, va1x0123, vb0123456789ABCDEF); w = (const int8_t*) w + 64; k -= 4 * sizeof(int8_t); } - __m512 vscaled0x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x0x0123456789ABCDEF); - __m512 vscaled1x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x1x0123456789ABCDEF); + __m512 vscaled0x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x0123456789ABCDEF); + __m512 vscaled1x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc1x0123456789ABCDEF); vscaled0x0123456789ABCDEF = _mm512_mul_ps(vscaled0x0123456789ABCDEF, _mm512_set1_ps(quantization_params[0].inv_scale)); vscaled1x0123456789ABCDEF = _mm512_mul_ps(vscaled1x0123456789ABCDEF, _mm512_set1_ps(quantization_params[1].inv_scale)); diff --git a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-2x16c4-minmax-avx512vnni.c b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-2x16c4-minmax-avx512vnni.c index 38b064450b0..f471216af8a 100644 --- a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-2x16c4-minmax-avx512vnni.c +++ b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-2x16c4-minmax-avx512vnni.c @@ -82,8 +82,8 @@ void xnn_qd8_f32_qc8w_gemm_minmax_ukernel_2x16c4__avx512vnni( w = (const int8_t*) w + 128; k -= 8 * sizeof(int8_t); } - vacc0x0x0123456789ABCDEF = _mm512_add_epi32(vacc0x0x0123456789ABCDEF, vacc1x0x0123456789ABCDEF); - vacc0x1x0123456789ABCDEF = _mm512_add_epi32(vacc0x1x0123456789ABCDEF, vacc1x1x0123456789ABCDEF); + __m512i vacc0x0123456789ABCDEF = _mm512_add_epi32(vacc0x0x0123456789ABCDEF, vacc1x0x0123456789ABCDEF); + __m512i vacc1x0123456789ABCDEF = _mm512_add_epi32(vacc0x1x0123456789ABCDEF, vacc1x1x0123456789ABCDEF); if (k != 0) { const __m512i va0x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a0)), vsign_mask); @@ -93,15 +93,15 @@ void xnn_qd8_f32_qc8w_gemm_minmax_ukernel_2x16c4__avx512vnni( const __m512i vb0123456789ABCDEF = _mm512_load_si512(w); - vacc0x0x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x0x0123456789ABCDEF, va0x0123, vb0123456789ABCDEF); - vacc0x1x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x1x0123456789ABCDEF, va1x0123, vb0123456789ABCDEF); + vacc0x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x0123456789ABCDEF, va0x0123, vb0123456789ABCDEF); + vacc1x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc1x0123456789ABCDEF, va1x0123, vb0123456789ABCDEF); w = (const int8_t*) w + 64; k -= 4 * sizeof(int8_t); } - __m512 vscaled0x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x0x0123456789ABCDEF); - __m512 vscaled1x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x1x0123456789ABCDEF); + __m512 vscaled0x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x0123456789ABCDEF); + __m512 vscaled1x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc1x0123456789ABCDEF); vscaled0x0123456789ABCDEF = _mm512_mul_ps(vscaled0x0123456789ABCDEF, _mm512_set1_ps(quantization_params[0].inv_scale)); vscaled1x0123456789ABCDEF = _mm512_mul_ps(vscaled1x0123456789ABCDEF, _mm512_set1_ps(quantization_params[1].inv_scale)); diff --git a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-3x16c4-minmax-avx512vnni-prfm.c b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-3x16c4-minmax-avx512vnni-prfm.c index 780fc0bc563..bfea4bae9ad 100644 --- a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-3x16c4-minmax-avx512vnni-prfm.c +++ b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-3x16c4-minmax-avx512vnni-prfm.c @@ -99,9 +99,9 @@ void xnn_qd8_f32_qc8w_gemm_minmax_ukernel_3x16c4__avx512vnni_prfm( w = (const int8_t*) w + 128; k -= 8 * sizeof(int8_t); } - vacc0x0x0123456789ABCDEF = _mm512_add_epi32(vacc0x0x0123456789ABCDEF, vacc1x0x0123456789ABCDEF); - vacc0x1x0123456789ABCDEF = _mm512_add_epi32(vacc0x1x0123456789ABCDEF, vacc1x1x0123456789ABCDEF); - vacc0x2x0123456789ABCDEF = _mm512_add_epi32(vacc0x2x0123456789ABCDEF, vacc1x2x0123456789ABCDEF); + __m512i vacc0x0123456789ABCDEF = _mm512_add_epi32(vacc0x0x0123456789ABCDEF, vacc1x0x0123456789ABCDEF); + __m512i vacc1x0123456789ABCDEF = _mm512_add_epi32(vacc0x1x0123456789ABCDEF, vacc1x1x0123456789ABCDEF); + __m512i vacc2x0123456789ABCDEF = _mm512_add_epi32(vacc0x2x0123456789ABCDEF, vacc1x2x0123456789ABCDEF); if (k != 0) { const __m512i va0x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a0)), vsign_mask); @@ -114,17 +114,17 @@ void xnn_qd8_f32_qc8w_gemm_minmax_ukernel_3x16c4__avx512vnni_prfm( const __m512i vb0123456789ABCDEF = _mm512_load_si512(w); xnn_prefetch_to_l1((const int8_t*) w + 960); - vacc0x0x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x0x0123456789ABCDEF, va0x0123, vb0123456789ABCDEF); - vacc0x1x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x1x0123456789ABCDEF, va1x0123, vb0123456789ABCDEF); - vacc0x2x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x2x0123456789ABCDEF, va2x0123, vb0123456789ABCDEF); + vacc0x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x0123456789ABCDEF, va0x0123, vb0123456789ABCDEF); + vacc1x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc1x0123456789ABCDEF, va1x0123, vb0123456789ABCDEF); + vacc2x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc2x0123456789ABCDEF, va2x0123, vb0123456789ABCDEF); w = (const int8_t*) w + 64; k -= 4 * sizeof(int8_t); } - __m512 vscaled0x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x0x0123456789ABCDEF); - __m512 vscaled1x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x1x0123456789ABCDEF); - __m512 vscaled2x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x2x0123456789ABCDEF); + __m512 vscaled0x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x0123456789ABCDEF); + __m512 vscaled1x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc1x0123456789ABCDEF); + __m512 vscaled2x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc2x0123456789ABCDEF); vscaled0x0123456789ABCDEF = _mm512_mul_ps(vscaled0x0123456789ABCDEF, _mm512_set1_ps(quantization_params[0].inv_scale)); vscaled1x0123456789ABCDEF = _mm512_mul_ps(vscaled1x0123456789ABCDEF, _mm512_set1_ps(quantization_params[1].inv_scale)); diff --git a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-3x16c4-minmax-avx512vnni.c b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-3x16c4-minmax-avx512vnni.c index e2154861077..17adf7e29da 100644 --- a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-3x16c4-minmax-avx512vnni.c +++ b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-3x16c4-minmax-avx512vnni.c @@ -96,9 +96,9 @@ void xnn_qd8_f32_qc8w_gemm_minmax_ukernel_3x16c4__avx512vnni( w = (const int8_t*) w + 128; k -= 8 * sizeof(int8_t); } - vacc0x0x0123456789ABCDEF = _mm512_add_epi32(vacc0x0x0123456789ABCDEF, vacc1x0x0123456789ABCDEF); - vacc0x1x0123456789ABCDEF = _mm512_add_epi32(vacc0x1x0123456789ABCDEF, vacc1x1x0123456789ABCDEF); - vacc0x2x0123456789ABCDEF = _mm512_add_epi32(vacc0x2x0123456789ABCDEF, vacc1x2x0123456789ABCDEF); + __m512i vacc0x0123456789ABCDEF = _mm512_add_epi32(vacc0x0x0123456789ABCDEF, vacc1x0x0123456789ABCDEF); + __m512i vacc1x0123456789ABCDEF = _mm512_add_epi32(vacc0x1x0123456789ABCDEF, vacc1x1x0123456789ABCDEF); + __m512i vacc2x0123456789ABCDEF = _mm512_add_epi32(vacc0x2x0123456789ABCDEF, vacc1x2x0123456789ABCDEF); if (k != 0) { const __m512i va0x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a0)), vsign_mask); @@ -110,17 +110,17 @@ void xnn_qd8_f32_qc8w_gemm_minmax_ukernel_3x16c4__avx512vnni( const __m512i vb0123456789ABCDEF = _mm512_load_si512(w); - vacc0x0x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x0x0123456789ABCDEF, va0x0123, vb0123456789ABCDEF); - vacc0x1x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x1x0123456789ABCDEF, va1x0123, vb0123456789ABCDEF); - vacc0x2x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x2x0123456789ABCDEF, va2x0123, vb0123456789ABCDEF); + vacc0x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x0123456789ABCDEF, va0x0123, vb0123456789ABCDEF); + vacc1x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc1x0123456789ABCDEF, va1x0123, vb0123456789ABCDEF); + vacc2x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc2x0123456789ABCDEF, va2x0123, vb0123456789ABCDEF); w = (const int8_t*) w + 64; k -= 4 * sizeof(int8_t); } - __m512 vscaled0x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x0x0123456789ABCDEF); - __m512 vscaled1x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x1x0123456789ABCDEF); - __m512 vscaled2x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x2x0123456789ABCDEF); + __m512 vscaled0x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x0123456789ABCDEF); + __m512 vscaled1x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc1x0123456789ABCDEF); + __m512 vscaled2x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc2x0123456789ABCDEF); vscaled0x0123456789ABCDEF = _mm512_mul_ps(vscaled0x0123456789ABCDEF, _mm512_set1_ps(quantization_params[0].inv_scale)); vscaled1x0123456789ABCDEF = _mm512_mul_ps(vscaled1x0123456789ABCDEF, _mm512_set1_ps(quantization_params[1].inv_scale)); diff --git a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-4x16c4-minmax-avx512vnni-prfm.c b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-4x16c4-minmax-avx512vnni-prfm.c index 3220ebf5ad5..ef7b3ba78b7 100644 --- a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-4x16c4-minmax-avx512vnni-prfm.c +++ b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-4x16c4-minmax-avx512vnni-prfm.c @@ -113,10 +113,10 @@ void xnn_qd8_f32_qc8w_gemm_minmax_ukernel_4x16c4__avx512vnni_prfm( w = (const int8_t*) w + 128; k -= 8 * sizeof(int8_t); } - vacc0x0x0123456789ABCDEF = _mm512_add_epi32(vacc0x0x0123456789ABCDEF, vacc1x0x0123456789ABCDEF); - vacc0x1x0123456789ABCDEF = _mm512_add_epi32(vacc0x1x0123456789ABCDEF, vacc1x1x0123456789ABCDEF); - vacc0x2x0123456789ABCDEF = _mm512_add_epi32(vacc0x2x0123456789ABCDEF, vacc1x2x0123456789ABCDEF); - vacc0x3x0123456789ABCDEF = _mm512_add_epi32(vacc0x3x0123456789ABCDEF, vacc1x3x0123456789ABCDEF); + __m512i vacc0x0123456789ABCDEF = _mm512_add_epi32(vacc0x0x0123456789ABCDEF, vacc1x0x0123456789ABCDEF); + __m512i vacc1x0123456789ABCDEF = _mm512_add_epi32(vacc0x1x0123456789ABCDEF, vacc1x1x0123456789ABCDEF); + __m512i vacc2x0123456789ABCDEF = _mm512_add_epi32(vacc0x2x0123456789ABCDEF, vacc1x2x0123456789ABCDEF); + __m512i vacc3x0123456789ABCDEF = _mm512_add_epi32(vacc0x3x0123456789ABCDEF, vacc1x3x0123456789ABCDEF); if (k != 0) { const __m512i va0x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a0)), vsign_mask); @@ -131,19 +131,19 @@ void xnn_qd8_f32_qc8w_gemm_minmax_ukernel_4x16c4__avx512vnni_prfm( const __m512i vb0123456789ABCDEF = _mm512_load_si512(w); xnn_prefetch_to_l1((const int8_t*) w + 960); - vacc0x0x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x0x0123456789ABCDEF, va0x0123, vb0123456789ABCDEF); - vacc0x1x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x1x0123456789ABCDEF, va1x0123, vb0123456789ABCDEF); - vacc0x2x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x2x0123456789ABCDEF, va2x0123, vb0123456789ABCDEF); - vacc0x3x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x3x0123456789ABCDEF, va3x0123, vb0123456789ABCDEF); + vacc0x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x0123456789ABCDEF, va0x0123, vb0123456789ABCDEF); + vacc1x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc1x0123456789ABCDEF, va1x0123, vb0123456789ABCDEF); + vacc2x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc2x0123456789ABCDEF, va2x0123, vb0123456789ABCDEF); + vacc3x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc3x0123456789ABCDEF, va3x0123, vb0123456789ABCDEF); w = (const int8_t*) w + 64; k -= 4 * sizeof(int8_t); } - __m512 vscaled0x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x0x0123456789ABCDEF); - __m512 vscaled1x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x1x0123456789ABCDEF); - __m512 vscaled2x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x2x0123456789ABCDEF); - __m512 vscaled3x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x3x0123456789ABCDEF); + __m512 vscaled0x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x0123456789ABCDEF); + __m512 vscaled1x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc1x0123456789ABCDEF); + __m512 vscaled2x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc2x0123456789ABCDEF); + __m512 vscaled3x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc3x0123456789ABCDEF); vscaled0x0123456789ABCDEF = _mm512_mul_ps(vscaled0x0123456789ABCDEF, _mm512_set1_ps(quantization_params[0].inv_scale)); vscaled1x0123456789ABCDEF = _mm512_mul_ps(vscaled1x0123456789ABCDEF, _mm512_set1_ps(quantization_params[1].inv_scale)); diff --git a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-4x16c4-minmax-avx512vnni.c b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-4x16c4-minmax-avx512vnni.c index 6023ff4e184..684fbe2483c 100644 --- a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-4x16c4-minmax-avx512vnni.c +++ b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-4x16c4-minmax-avx512vnni.c @@ -110,10 +110,10 @@ void xnn_qd8_f32_qc8w_gemm_minmax_ukernel_4x16c4__avx512vnni( w = (const int8_t*) w + 128; k -= 8 * sizeof(int8_t); } - vacc0x0x0123456789ABCDEF = _mm512_add_epi32(vacc0x0x0123456789ABCDEF, vacc1x0x0123456789ABCDEF); - vacc0x1x0123456789ABCDEF = _mm512_add_epi32(vacc0x1x0123456789ABCDEF, vacc1x1x0123456789ABCDEF); - vacc0x2x0123456789ABCDEF = _mm512_add_epi32(vacc0x2x0123456789ABCDEF, vacc1x2x0123456789ABCDEF); - vacc0x3x0123456789ABCDEF = _mm512_add_epi32(vacc0x3x0123456789ABCDEF, vacc1x3x0123456789ABCDEF); + __m512i vacc0x0123456789ABCDEF = _mm512_add_epi32(vacc0x0x0123456789ABCDEF, vacc1x0x0123456789ABCDEF); + __m512i vacc1x0123456789ABCDEF = _mm512_add_epi32(vacc0x1x0123456789ABCDEF, vacc1x1x0123456789ABCDEF); + __m512i vacc2x0123456789ABCDEF = _mm512_add_epi32(vacc0x2x0123456789ABCDEF, vacc1x2x0123456789ABCDEF); + __m512i vacc3x0123456789ABCDEF = _mm512_add_epi32(vacc0x3x0123456789ABCDEF, vacc1x3x0123456789ABCDEF); if (k != 0) { const __m512i va0x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a0)), vsign_mask); @@ -127,19 +127,19 @@ void xnn_qd8_f32_qc8w_gemm_minmax_ukernel_4x16c4__avx512vnni( const __m512i vb0123456789ABCDEF = _mm512_load_si512(w); - vacc0x0x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x0x0123456789ABCDEF, va0x0123, vb0123456789ABCDEF); - vacc0x1x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x1x0123456789ABCDEF, va1x0123, vb0123456789ABCDEF); - vacc0x2x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x2x0123456789ABCDEF, va2x0123, vb0123456789ABCDEF); - vacc0x3x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x3x0123456789ABCDEF, va3x0123, vb0123456789ABCDEF); + vacc0x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x0123456789ABCDEF, va0x0123, vb0123456789ABCDEF); + vacc1x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc1x0123456789ABCDEF, va1x0123, vb0123456789ABCDEF); + vacc2x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc2x0123456789ABCDEF, va2x0123, vb0123456789ABCDEF); + vacc3x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc3x0123456789ABCDEF, va3x0123, vb0123456789ABCDEF); w = (const int8_t*) w + 64; k -= 4 * sizeof(int8_t); } - __m512 vscaled0x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x0x0123456789ABCDEF); - __m512 vscaled1x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x1x0123456789ABCDEF); - __m512 vscaled2x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x2x0123456789ABCDEF); - __m512 vscaled3x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x3x0123456789ABCDEF); + __m512 vscaled0x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x0123456789ABCDEF); + __m512 vscaled1x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc1x0123456789ABCDEF); + __m512 vscaled2x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc2x0123456789ABCDEF); + __m512 vscaled3x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc3x0123456789ABCDEF); vscaled0x0123456789ABCDEF = _mm512_mul_ps(vscaled0x0123456789ABCDEF, _mm512_set1_ps(quantization_params[0].inv_scale)); vscaled1x0123456789ABCDEF = _mm512_mul_ps(vscaled1x0123456789ABCDEF, _mm512_set1_ps(quantization_params[1].inv_scale)); diff --git a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-5x16c4-minmax-avx512vnni-prfm.c b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-5x16c4-minmax-avx512vnni-prfm.c index fbcf9cdff13..d939e3e686b 100644 --- a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-5x16c4-minmax-avx512vnni-prfm.c +++ b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-5x16c4-minmax-avx512vnni-prfm.c @@ -127,11 +127,11 @@ void xnn_qd8_f32_qc8w_gemm_minmax_ukernel_5x16c4__avx512vnni_prfm( w = (const int8_t*) w + 128; k -= 8 * sizeof(int8_t); } - vacc0x0x0123456789ABCDEF = _mm512_add_epi32(vacc0x0x0123456789ABCDEF, vacc1x0x0123456789ABCDEF); - vacc0x1x0123456789ABCDEF = _mm512_add_epi32(vacc0x1x0123456789ABCDEF, vacc1x1x0123456789ABCDEF); - vacc0x2x0123456789ABCDEF = _mm512_add_epi32(vacc0x2x0123456789ABCDEF, vacc1x2x0123456789ABCDEF); - vacc0x3x0123456789ABCDEF = _mm512_add_epi32(vacc0x3x0123456789ABCDEF, vacc1x3x0123456789ABCDEF); - vacc0x4x0123456789ABCDEF = _mm512_add_epi32(vacc0x4x0123456789ABCDEF, vacc1x4x0123456789ABCDEF); + __m512i vacc0x0123456789ABCDEF = _mm512_add_epi32(vacc0x0x0123456789ABCDEF, vacc1x0x0123456789ABCDEF); + __m512i vacc1x0123456789ABCDEF = _mm512_add_epi32(vacc0x1x0123456789ABCDEF, vacc1x1x0123456789ABCDEF); + __m512i vacc2x0123456789ABCDEF = _mm512_add_epi32(vacc0x2x0123456789ABCDEF, vacc1x2x0123456789ABCDEF); + __m512i vacc3x0123456789ABCDEF = _mm512_add_epi32(vacc0x3x0123456789ABCDEF, vacc1x3x0123456789ABCDEF); + __m512i vacc4x0123456789ABCDEF = _mm512_add_epi32(vacc0x4x0123456789ABCDEF, vacc1x4x0123456789ABCDEF); if (k != 0) { const __m512i va0x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a0)), vsign_mask); @@ -148,21 +148,21 @@ void xnn_qd8_f32_qc8w_gemm_minmax_ukernel_5x16c4__avx512vnni_prfm( const __m512i vb0123456789ABCDEF = _mm512_load_si512(w); xnn_prefetch_to_l1((const int8_t*) w + 960); - vacc0x0x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x0x0123456789ABCDEF, va0x0123, vb0123456789ABCDEF); - vacc0x1x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x1x0123456789ABCDEF, va1x0123, vb0123456789ABCDEF); - vacc0x2x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x2x0123456789ABCDEF, va2x0123, vb0123456789ABCDEF); - vacc0x3x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x3x0123456789ABCDEF, va3x0123, vb0123456789ABCDEF); - vacc0x4x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x4x0123456789ABCDEF, va4x0123, vb0123456789ABCDEF); + vacc0x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x0123456789ABCDEF, va0x0123, vb0123456789ABCDEF); + vacc1x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc1x0123456789ABCDEF, va1x0123, vb0123456789ABCDEF); + vacc2x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc2x0123456789ABCDEF, va2x0123, vb0123456789ABCDEF); + vacc3x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc3x0123456789ABCDEF, va3x0123, vb0123456789ABCDEF); + vacc4x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc4x0123456789ABCDEF, va4x0123, vb0123456789ABCDEF); w = (const int8_t*) w + 64; k -= 4 * sizeof(int8_t); } - __m512 vscaled0x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x0x0123456789ABCDEF); - __m512 vscaled1x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x1x0123456789ABCDEF); - __m512 vscaled2x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x2x0123456789ABCDEF); - __m512 vscaled3x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x3x0123456789ABCDEF); - __m512 vscaled4x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x4x0123456789ABCDEF); + __m512 vscaled0x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x0123456789ABCDEF); + __m512 vscaled1x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc1x0123456789ABCDEF); + __m512 vscaled2x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc2x0123456789ABCDEF); + __m512 vscaled3x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc3x0123456789ABCDEF); + __m512 vscaled4x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc4x0123456789ABCDEF); vscaled0x0123456789ABCDEF = _mm512_mul_ps(vscaled0x0123456789ABCDEF, _mm512_set1_ps(quantization_params[0].inv_scale)); vscaled1x0123456789ABCDEF = _mm512_mul_ps(vscaled1x0123456789ABCDEF, _mm512_set1_ps(quantization_params[1].inv_scale)); diff --git a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-5x16c4-minmax-avx512vnni.c b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-5x16c4-minmax-avx512vnni.c index 66ad0baf5f6..0f8a397da66 100644 --- a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-5x16c4-minmax-avx512vnni.c +++ b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-5x16c4-minmax-avx512vnni.c @@ -124,11 +124,11 @@ void xnn_qd8_f32_qc8w_gemm_minmax_ukernel_5x16c4__avx512vnni( w = (const int8_t*) w + 128; k -= 8 * sizeof(int8_t); } - vacc0x0x0123456789ABCDEF = _mm512_add_epi32(vacc0x0x0123456789ABCDEF, vacc1x0x0123456789ABCDEF); - vacc0x1x0123456789ABCDEF = _mm512_add_epi32(vacc0x1x0123456789ABCDEF, vacc1x1x0123456789ABCDEF); - vacc0x2x0123456789ABCDEF = _mm512_add_epi32(vacc0x2x0123456789ABCDEF, vacc1x2x0123456789ABCDEF); - vacc0x3x0123456789ABCDEF = _mm512_add_epi32(vacc0x3x0123456789ABCDEF, vacc1x3x0123456789ABCDEF); - vacc0x4x0123456789ABCDEF = _mm512_add_epi32(vacc0x4x0123456789ABCDEF, vacc1x4x0123456789ABCDEF); + __m512i vacc0x0123456789ABCDEF = _mm512_add_epi32(vacc0x0x0123456789ABCDEF, vacc1x0x0123456789ABCDEF); + __m512i vacc1x0123456789ABCDEF = _mm512_add_epi32(vacc0x1x0123456789ABCDEF, vacc1x1x0123456789ABCDEF); + __m512i vacc2x0123456789ABCDEF = _mm512_add_epi32(vacc0x2x0123456789ABCDEF, vacc1x2x0123456789ABCDEF); + __m512i vacc3x0123456789ABCDEF = _mm512_add_epi32(vacc0x3x0123456789ABCDEF, vacc1x3x0123456789ABCDEF); + __m512i vacc4x0123456789ABCDEF = _mm512_add_epi32(vacc0x4x0123456789ABCDEF, vacc1x4x0123456789ABCDEF); if (k != 0) { const __m512i va0x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a0)), vsign_mask); @@ -144,21 +144,21 @@ void xnn_qd8_f32_qc8w_gemm_minmax_ukernel_5x16c4__avx512vnni( const __m512i vb0123456789ABCDEF = _mm512_load_si512(w); - vacc0x0x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x0x0123456789ABCDEF, va0x0123, vb0123456789ABCDEF); - vacc0x1x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x1x0123456789ABCDEF, va1x0123, vb0123456789ABCDEF); - vacc0x2x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x2x0123456789ABCDEF, va2x0123, vb0123456789ABCDEF); - vacc0x3x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x3x0123456789ABCDEF, va3x0123, vb0123456789ABCDEF); - vacc0x4x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x4x0123456789ABCDEF, va4x0123, vb0123456789ABCDEF); + vacc0x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x0123456789ABCDEF, va0x0123, vb0123456789ABCDEF); + vacc1x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc1x0123456789ABCDEF, va1x0123, vb0123456789ABCDEF); + vacc2x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc2x0123456789ABCDEF, va2x0123, vb0123456789ABCDEF); + vacc3x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc3x0123456789ABCDEF, va3x0123, vb0123456789ABCDEF); + vacc4x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc4x0123456789ABCDEF, va4x0123, vb0123456789ABCDEF); w = (const int8_t*) w + 64; k -= 4 * sizeof(int8_t); } - __m512 vscaled0x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x0x0123456789ABCDEF); - __m512 vscaled1x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x1x0123456789ABCDEF); - __m512 vscaled2x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x2x0123456789ABCDEF); - __m512 vscaled3x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x3x0123456789ABCDEF); - __m512 vscaled4x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x4x0123456789ABCDEF); + __m512 vscaled0x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x0123456789ABCDEF); + __m512 vscaled1x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc1x0123456789ABCDEF); + __m512 vscaled2x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc2x0123456789ABCDEF); + __m512 vscaled3x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc3x0123456789ABCDEF); + __m512 vscaled4x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc4x0123456789ABCDEF); vscaled0x0123456789ABCDEF = _mm512_mul_ps(vscaled0x0123456789ABCDEF, _mm512_set1_ps(quantization_params[0].inv_scale)); vscaled1x0123456789ABCDEF = _mm512_mul_ps(vscaled1x0123456789ABCDEF, _mm512_set1_ps(quantization_params[1].inv_scale)); diff --git a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-6x16c4-minmax-avx512vnni-prfm.c b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-6x16c4-minmax-avx512vnni-prfm.c index 661ad91e7e9..97e3e96bfda 100644 --- a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-6x16c4-minmax-avx512vnni-prfm.c +++ b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-6x16c4-minmax-avx512vnni-prfm.c @@ -141,12 +141,12 @@ void xnn_qd8_f32_qc8w_gemm_minmax_ukernel_6x16c4__avx512vnni_prfm( w = (const int8_t*) w + 128; k -= 8 * sizeof(int8_t); } - vacc0x0x0123456789ABCDEF = _mm512_add_epi32(vacc0x0x0123456789ABCDEF, vacc1x0x0123456789ABCDEF); - vacc0x1x0123456789ABCDEF = _mm512_add_epi32(vacc0x1x0123456789ABCDEF, vacc1x1x0123456789ABCDEF); - vacc0x2x0123456789ABCDEF = _mm512_add_epi32(vacc0x2x0123456789ABCDEF, vacc1x2x0123456789ABCDEF); - vacc0x3x0123456789ABCDEF = _mm512_add_epi32(vacc0x3x0123456789ABCDEF, vacc1x3x0123456789ABCDEF); - vacc0x4x0123456789ABCDEF = _mm512_add_epi32(vacc0x4x0123456789ABCDEF, vacc1x4x0123456789ABCDEF); - vacc0x5x0123456789ABCDEF = _mm512_add_epi32(vacc0x5x0123456789ABCDEF, vacc1x5x0123456789ABCDEF); + __m512i vacc0x0123456789ABCDEF = _mm512_add_epi32(vacc0x0x0123456789ABCDEF, vacc1x0x0123456789ABCDEF); + __m512i vacc1x0123456789ABCDEF = _mm512_add_epi32(vacc0x1x0123456789ABCDEF, vacc1x1x0123456789ABCDEF); + __m512i vacc2x0123456789ABCDEF = _mm512_add_epi32(vacc0x2x0123456789ABCDEF, vacc1x2x0123456789ABCDEF); + __m512i vacc3x0123456789ABCDEF = _mm512_add_epi32(vacc0x3x0123456789ABCDEF, vacc1x3x0123456789ABCDEF); + __m512i vacc4x0123456789ABCDEF = _mm512_add_epi32(vacc0x4x0123456789ABCDEF, vacc1x4x0123456789ABCDEF); + __m512i vacc5x0123456789ABCDEF = _mm512_add_epi32(vacc0x5x0123456789ABCDEF, vacc1x5x0123456789ABCDEF); if (k != 0) { const __m512i va0x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a0)), vsign_mask); @@ -165,23 +165,23 @@ void xnn_qd8_f32_qc8w_gemm_minmax_ukernel_6x16c4__avx512vnni_prfm( const __m512i vb0123456789ABCDEF = _mm512_load_si512(w); xnn_prefetch_to_l1((const int8_t*) w + 960); - vacc0x0x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x0x0123456789ABCDEF, va0x0123, vb0123456789ABCDEF); - vacc0x1x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x1x0123456789ABCDEF, va1x0123, vb0123456789ABCDEF); - vacc0x2x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x2x0123456789ABCDEF, va2x0123, vb0123456789ABCDEF); - vacc0x3x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x3x0123456789ABCDEF, va3x0123, vb0123456789ABCDEF); - vacc0x4x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x4x0123456789ABCDEF, va4x0123, vb0123456789ABCDEF); - vacc0x5x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x5x0123456789ABCDEF, va5x0123, vb0123456789ABCDEF); + vacc0x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x0123456789ABCDEF, va0x0123, vb0123456789ABCDEF); + vacc1x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc1x0123456789ABCDEF, va1x0123, vb0123456789ABCDEF); + vacc2x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc2x0123456789ABCDEF, va2x0123, vb0123456789ABCDEF); + vacc3x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc3x0123456789ABCDEF, va3x0123, vb0123456789ABCDEF); + vacc4x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc4x0123456789ABCDEF, va4x0123, vb0123456789ABCDEF); + vacc5x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc5x0123456789ABCDEF, va5x0123, vb0123456789ABCDEF); w = (const int8_t*) w + 64; k -= 4 * sizeof(int8_t); } - __m512 vscaled0x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x0x0123456789ABCDEF); - __m512 vscaled1x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x1x0123456789ABCDEF); - __m512 vscaled2x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x2x0123456789ABCDEF); - __m512 vscaled3x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x3x0123456789ABCDEF); - __m512 vscaled4x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x4x0123456789ABCDEF); - __m512 vscaled5x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x5x0123456789ABCDEF); + __m512 vscaled0x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x0123456789ABCDEF); + __m512 vscaled1x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc1x0123456789ABCDEF); + __m512 vscaled2x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc2x0123456789ABCDEF); + __m512 vscaled3x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc3x0123456789ABCDEF); + __m512 vscaled4x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc4x0123456789ABCDEF); + __m512 vscaled5x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc5x0123456789ABCDEF); vscaled0x0123456789ABCDEF = _mm512_mul_ps(vscaled0x0123456789ABCDEF, _mm512_set1_ps(quantization_params[0].inv_scale)); vscaled1x0123456789ABCDEF = _mm512_mul_ps(vscaled1x0123456789ABCDEF, _mm512_set1_ps(quantization_params[1].inv_scale)); diff --git a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-6x16c4-minmax-avx512vnni.c b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-6x16c4-minmax-avx512vnni.c index 16275913fdc..7c2a863826a 100644 --- a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-6x16c4-minmax-avx512vnni.c +++ b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-6x16c4-minmax-avx512vnni.c @@ -138,12 +138,12 @@ void xnn_qd8_f32_qc8w_gemm_minmax_ukernel_6x16c4__avx512vnni( w = (const int8_t*) w + 128; k -= 8 * sizeof(int8_t); } - vacc0x0x0123456789ABCDEF = _mm512_add_epi32(vacc0x0x0123456789ABCDEF, vacc1x0x0123456789ABCDEF); - vacc0x1x0123456789ABCDEF = _mm512_add_epi32(vacc0x1x0123456789ABCDEF, vacc1x1x0123456789ABCDEF); - vacc0x2x0123456789ABCDEF = _mm512_add_epi32(vacc0x2x0123456789ABCDEF, vacc1x2x0123456789ABCDEF); - vacc0x3x0123456789ABCDEF = _mm512_add_epi32(vacc0x3x0123456789ABCDEF, vacc1x3x0123456789ABCDEF); - vacc0x4x0123456789ABCDEF = _mm512_add_epi32(vacc0x4x0123456789ABCDEF, vacc1x4x0123456789ABCDEF); - vacc0x5x0123456789ABCDEF = _mm512_add_epi32(vacc0x5x0123456789ABCDEF, vacc1x5x0123456789ABCDEF); + __m512i vacc0x0123456789ABCDEF = _mm512_add_epi32(vacc0x0x0123456789ABCDEF, vacc1x0x0123456789ABCDEF); + __m512i vacc1x0123456789ABCDEF = _mm512_add_epi32(vacc0x1x0123456789ABCDEF, vacc1x1x0123456789ABCDEF); + __m512i vacc2x0123456789ABCDEF = _mm512_add_epi32(vacc0x2x0123456789ABCDEF, vacc1x2x0123456789ABCDEF); + __m512i vacc3x0123456789ABCDEF = _mm512_add_epi32(vacc0x3x0123456789ABCDEF, vacc1x3x0123456789ABCDEF); + __m512i vacc4x0123456789ABCDEF = _mm512_add_epi32(vacc0x4x0123456789ABCDEF, vacc1x4x0123456789ABCDEF); + __m512i vacc5x0123456789ABCDEF = _mm512_add_epi32(vacc0x5x0123456789ABCDEF, vacc1x5x0123456789ABCDEF); if (k != 0) { const __m512i va0x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a0)), vsign_mask); @@ -161,23 +161,23 @@ void xnn_qd8_f32_qc8w_gemm_minmax_ukernel_6x16c4__avx512vnni( const __m512i vb0123456789ABCDEF = _mm512_load_si512(w); - vacc0x0x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x0x0123456789ABCDEF, va0x0123, vb0123456789ABCDEF); - vacc0x1x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x1x0123456789ABCDEF, va1x0123, vb0123456789ABCDEF); - vacc0x2x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x2x0123456789ABCDEF, va2x0123, vb0123456789ABCDEF); - vacc0x3x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x3x0123456789ABCDEF, va3x0123, vb0123456789ABCDEF); - vacc0x4x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x4x0123456789ABCDEF, va4x0123, vb0123456789ABCDEF); - vacc0x5x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x5x0123456789ABCDEF, va5x0123, vb0123456789ABCDEF); + vacc0x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x0123456789ABCDEF, va0x0123, vb0123456789ABCDEF); + vacc1x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc1x0123456789ABCDEF, va1x0123, vb0123456789ABCDEF); + vacc2x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc2x0123456789ABCDEF, va2x0123, vb0123456789ABCDEF); + vacc3x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc3x0123456789ABCDEF, va3x0123, vb0123456789ABCDEF); + vacc4x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc4x0123456789ABCDEF, va4x0123, vb0123456789ABCDEF); + vacc5x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc5x0123456789ABCDEF, va5x0123, vb0123456789ABCDEF); w = (const int8_t*) w + 64; k -= 4 * sizeof(int8_t); } - __m512 vscaled0x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x0x0123456789ABCDEF); - __m512 vscaled1x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x1x0123456789ABCDEF); - __m512 vscaled2x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x2x0123456789ABCDEF); - __m512 vscaled3x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x3x0123456789ABCDEF); - __m512 vscaled4x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x4x0123456789ABCDEF); - __m512 vscaled5x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x5x0123456789ABCDEF); + __m512 vscaled0x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x0123456789ABCDEF); + __m512 vscaled1x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc1x0123456789ABCDEF); + __m512 vscaled2x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc2x0123456789ABCDEF); + __m512 vscaled3x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc3x0123456789ABCDEF); + __m512 vscaled4x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc4x0123456789ABCDEF); + __m512 vscaled5x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc5x0123456789ABCDEF); vscaled0x0123456789ABCDEF = _mm512_mul_ps(vscaled0x0123456789ABCDEF, _mm512_set1_ps(quantization_params[0].inv_scale)); vscaled1x0123456789ABCDEF = _mm512_mul_ps(vscaled1x0123456789ABCDEF, _mm512_set1_ps(quantization_params[1].inv_scale)); diff --git a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-7x16c4-minmax-avx512vnni-prfm.c b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-7x16c4-minmax-avx512vnni-prfm.c index 2f9f4413853..b2750231cc2 100644 --- a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-7x16c4-minmax-avx512vnni-prfm.c +++ b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-7x16c4-minmax-avx512vnni-prfm.c @@ -155,13 +155,13 @@ void xnn_qd8_f32_qc8w_gemm_minmax_ukernel_7x16c4__avx512vnni_prfm( w = (const int8_t*) w + 128; k -= 8 * sizeof(int8_t); } - vacc0x0x0123456789ABCDEF = _mm512_add_epi32(vacc0x0x0123456789ABCDEF, vacc1x0x0123456789ABCDEF); - vacc0x1x0123456789ABCDEF = _mm512_add_epi32(vacc0x1x0123456789ABCDEF, vacc1x1x0123456789ABCDEF); - vacc0x2x0123456789ABCDEF = _mm512_add_epi32(vacc0x2x0123456789ABCDEF, vacc1x2x0123456789ABCDEF); - vacc0x3x0123456789ABCDEF = _mm512_add_epi32(vacc0x3x0123456789ABCDEF, vacc1x3x0123456789ABCDEF); - vacc0x4x0123456789ABCDEF = _mm512_add_epi32(vacc0x4x0123456789ABCDEF, vacc1x4x0123456789ABCDEF); - vacc0x5x0123456789ABCDEF = _mm512_add_epi32(vacc0x5x0123456789ABCDEF, vacc1x5x0123456789ABCDEF); - vacc0x6x0123456789ABCDEF = _mm512_add_epi32(vacc0x6x0123456789ABCDEF, vacc1x6x0123456789ABCDEF); + __m512i vacc0x0123456789ABCDEF = _mm512_add_epi32(vacc0x0x0123456789ABCDEF, vacc1x0x0123456789ABCDEF); + __m512i vacc1x0123456789ABCDEF = _mm512_add_epi32(vacc0x1x0123456789ABCDEF, vacc1x1x0123456789ABCDEF); + __m512i vacc2x0123456789ABCDEF = _mm512_add_epi32(vacc0x2x0123456789ABCDEF, vacc1x2x0123456789ABCDEF); + __m512i vacc3x0123456789ABCDEF = _mm512_add_epi32(vacc0x3x0123456789ABCDEF, vacc1x3x0123456789ABCDEF); + __m512i vacc4x0123456789ABCDEF = _mm512_add_epi32(vacc0x4x0123456789ABCDEF, vacc1x4x0123456789ABCDEF); + __m512i vacc5x0123456789ABCDEF = _mm512_add_epi32(vacc0x5x0123456789ABCDEF, vacc1x5x0123456789ABCDEF); + __m512i vacc6x0123456789ABCDEF = _mm512_add_epi32(vacc0x6x0123456789ABCDEF, vacc1x6x0123456789ABCDEF); if (k != 0) { const __m512i va0x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a0)), vsign_mask); @@ -182,25 +182,25 @@ void xnn_qd8_f32_qc8w_gemm_minmax_ukernel_7x16c4__avx512vnni_prfm( const __m512i vb0123456789ABCDEF = _mm512_load_si512(w); xnn_prefetch_to_l1((const int8_t*) w + 960); - vacc0x0x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x0x0123456789ABCDEF, va0x0123, vb0123456789ABCDEF); - vacc0x1x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x1x0123456789ABCDEF, va1x0123, vb0123456789ABCDEF); - vacc0x2x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x2x0123456789ABCDEF, va2x0123, vb0123456789ABCDEF); - vacc0x3x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x3x0123456789ABCDEF, va3x0123, vb0123456789ABCDEF); - vacc0x4x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x4x0123456789ABCDEF, va4x0123, vb0123456789ABCDEF); - vacc0x5x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x5x0123456789ABCDEF, va5x0123, vb0123456789ABCDEF); - vacc0x6x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x6x0123456789ABCDEF, va6x0123, vb0123456789ABCDEF); + vacc0x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x0123456789ABCDEF, va0x0123, vb0123456789ABCDEF); + vacc1x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc1x0123456789ABCDEF, va1x0123, vb0123456789ABCDEF); + vacc2x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc2x0123456789ABCDEF, va2x0123, vb0123456789ABCDEF); + vacc3x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc3x0123456789ABCDEF, va3x0123, vb0123456789ABCDEF); + vacc4x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc4x0123456789ABCDEF, va4x0123, vb0123456789ABCDEF); + vacc5x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc5x0123456789ABCDEF, va5x0123, vb0123456789ABCDEF); + vacc6x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc6x0123456789ABCDEF, va6x0123, vb0123456789ABCDEF); w = (const int8_t*) w + 64; k -= 4 * sizeof(int8_t); } - __m512 vscaled0x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x0x0123456789ABCDEF); - __m512 vscaled1x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x1x0123456789ABCDEF); - __m512 vscaled2x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x2x0123456789ABCDEF); - __m512 vscaled3x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x3x0123456789ABCDEF); - __m512 vscaled4x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x4x0123456789ABCDEF); - __m512 vscaled5x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x5x0123456789ABCDEF); - __m512 vscaled6x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x6x0123456789ABCDEF); + __m512 vscaled0x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x0123456789ABCDEF); + __m512 vscaled1x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc1x0123456789ABCDEF); + __m512 vscaled2x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc2x0123456789ABCDEF); + __m512 vscaled3x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc3x0123456789ABCDEF); + __m512 vscaled4x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc4x0123456789ABCDEF); + __m512 vscaled5x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc5x0123456789ABCDEF); + __m512 vscaled6x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc6x0123456789ABCDEF); vscaled0x0123456789ABCDEF = _mm512_mul_ps(vscaled0x0123456789ABCDEF, _mm512_set1_ps(quantization_params[0].inv_scale)); vscaled1x0123456789ABCDEF = _mm512_mul_ps(vscaled1x0123456789ABCDEF, _mm512_set1_ps(quantization_params[1].inv_scale)); diff --git a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-7x16c4-minmax-avx512vnni.c b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-7x16c4-minmax-avx512vnni.c index daf0fe67e86..291b6e3d60a 100644 --- a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-7x16c4-minmax-avx512vnni.c +++ b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-7x16c4-minmax-avx512vnni.c @@ -152,13 +152,13 @@ void xnn_qd8_f32_qc8w_gemm_minmax_ukernel_7x16c4__avx512vnni( w = (const int8_t*) w + 128; k -= 8 * sizeof(int8_t); } - vacc0x0x0123456789ABCDEF = _mm512_add_epi32(vacc0x0x0123456789ABCDEF, vacc1x0x0123456789ABCDEF); - vacc0x1x0123456789ABCDEF = _mm512_add_epi32(vacc0x1x0123456789ABCDEF, vacc1x1x0123456789ABCDEF); - vacc0x2x0123456789ABCDEF = _mm512_add_epi32(vacc0x2x0123456789ABCDEF, vacc1x2x0123456789ABCDEF); - vacc0x3x0123456789ABCDEF = _mm512_add_epi32(vacc0x3x0123456789ABCDEF, vacc1x3x0123456789ABCDEF); - vacc0x4x0123456789ABCDEF = _mm512_add_epi32(vacc0x4x0123456789ABCDEF, vacc1x4x0123456789ABCDEF); - vacc0x5x0123456789ABCDEF = _mm512_add_epi32(vacc0x5x0123456789ABCDEF, vacc1x5x0123456789ABCDEF); - vacc0x6x0123456789ABCDEF = _mm512_add_epi32(vacc0x6x0123456789ABCDEF, vacc1x6x0123456789ABCDEF); + __m512i vacc0x0123456789ABCDEF = _mm512_add_epi32(vacc0x0x0123456789ABCDEF, vacc1x0x0123456789ABCDEF); + __m512i vacc1x0123456789ABCDEF = _mm512_add_epi32(vacc0x1x0123456789ABCDEF, vacc1x1x0123456789ABCDEF); + __m512i vacc2x0123456789ABCDEF = _mm512_add_epi32(vacc0x2x0123456789ABCDEF, vacc1x2x0123456789ABCDEF); + __m512i vacc3x0123456789ABCDEF = _mm512_add_epi32(vacc0x3x0123456789ABCDEF, vacc1x3x0123456789ABCDEF); + __m512i vacc4x0123456789ABCDEF = _mm512_add_epi32(vacc0x4x0123456789ABCDEF, vacc1x4x0123456789ABCDEF); + __m512i vacc5x0123456789ABCDEF = _mm512_add_epi32(vacc0x5x0123456789ABCDEF, vacc1x5x0123456789ABCDEF); + __m512i vacc6x0123456789ABCDEF = _mm512_add_epi32(vacc0x6x0123456789ABCDEF, vacc1x6x0123456789ABCDEF); if (k != 0) { const __m512i va0x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a0)), vsign_mask); @@ -178,25 +178,25 @@ void xnn_qd8_f32_qc8w_gemm_minmax_ukernel_7x16c4__avx512vnni( const __m512i vb0123456789ABCDEF = _mm512_load_si512(w); - vacc0x0x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x0x0123456789ABCDEF, va0x0123, vb0123456789ABCDEF); - vacc0x1x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x1x0123456789ABCDEF, va1x0123, vb0123456789ABCDEF); - vacc0x2x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x2x0123456789ABCDEF, va2x0123, vb0123456789ABCDEF); - vacc0x3x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x3x0123456789ABCDEF, va3x0123, vb0123456789ABCDEF); - vacc0x4x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x4x0123456789ABCDEF, va4x0123, vb0123456789ABCDEF); - vacc0x5x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x5x0123456789ABCDEF, va5x0123, vb0123456789ABCDEF); - vacc0x6x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x6x0123456789ABCDEF, va6x0123, vb0123456789ABCDEF); + vacc0x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x0123456789ABCDEF, va0x0123, vb0123456789ABCDEF); + vacc1x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc1x0123456789ABCDEF, va1x0123, vb0123456789ABCDEF); + vacc2x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc2x0123456789ABCDEF, va2x0123, vb0123456789ABCDEF); + vacc3x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc3x0123456789ABCDEF, va3x0123, vb0123456789ABCDEF); + vacc4x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc4x0123456789ABCDEF, va4x0123, vb0123456789ABCDEF); + vacc5x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc5x0123456789ABCDEF, va5x0123, vb0123456789ABCDEF); + vacc6x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc6x0123456789ABCDEF, va6x0123, vb0123456789ABCDEF); w = (const int8_t*) w + 64; k -= 4 * sizeof(int8_t); } - __m512 vscaled0x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x0x0123456789ABCDEF); - __m512 vscaled1x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x1x0123456789ABCDEF); - __m512 vscaled2x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x2x0123456789ABCDEF); - __m512 vscaled3x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x3x0123456789ABCDEF); - __m512 vscaled4x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x4x0123456789ABCDEF); - __m512 vscaled5x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x5x0123456789ABCDEF); - __m512 vscaled6x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x6x0123456789ABCDEF); + __m512 vscaled0x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x0123456789ABCDEF); + __m512 vscaled1x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc1x0123456789ABCDEF); + __m512 vscaled2x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc2x0123456789ABCDEF); + __m512 vscaled3x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc3x0123456789ABCDEF); + __m512 vscaled4x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc4x0123456789ABCDEF); + __m512 vscaled5x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc5x0123456789ABCDEF); + __m512 vscaled6x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc6x0123456789ABCDEF); vscaled0x0123456789ABCDEF = _mm512_mul_ps(vscaled0x0123456789ABCDEF, _mm512_set1_ps(quantization_params[0].inv_scale)); vscaled1x0123456789ABCDEF = _mm512_mul_ps(vscaled1x0123456789ABCDEF, _mm512_set1_ps(quantization_params[1].inv_scale)); diff --git a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-8x16c4-minmax-avx512vnni-prfm.c b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-8x16c4-minmax-avx512vnni-prfm.c index 1322e73c90b..a434395035a 100644 --- a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-8x16c4-minmax-avx512vnni-prfm.c +++ b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-8x16c4-minmax-avx512vnni-prfm.c @@ -169,14 +169,14 @@ void xnn_qd8_f32_qc8w_gemm_minmax_ukernel_8x16c4__avx512vnni_prfm( w = (const int8_t*) w + 128; k -= 8 * sizeof(int8_t); } - vacc0x0x0123456789ABCDEF = _mm512_add_epi32(vacc0x0x0123456789ABCDEF, vacc1x0x0123456789ABCDEF); - vacc0x1x0123456789ABCDEF = _mm512_add_epi32(vacc0x1x0123456789ABCDEF, vacc1x1x0123456789ABCDEF); - vacc0x2x0123456789ABCDEF = _mm512_add_epi32(vacc0x2x0123456789ABCDEF, vacc1x2x0123456789ABCDEF); - vacc0x3x0123456789ABCDEF = _mm512_add_epi32(vacc0x3x0123456789ABCDEF, vacc1x3x0123456789ABCDEF); - vacc0x4x0123456789ABCDEF = _mm512_add_epi32(vacc0x4x0123456789ABCDEF, vacc1x4x0123456789ABCDEF); - vacc0x5x0123456789ABCDEF = _mm512_add_epi32(vacc0x5x0123456789ABCDEF, vacc1x5x0123456789ABCDEF); - vacc0x6x0123456789ABCDEF = _mm512_add_epi32(vacc0x6x0123456789ABCDEF, vacc1x6x0123456789ABCDEF); - vacc0x7x0123456789ABCDEF = _mm512_add_epi32(vacc0x7x0123456789ABCDEF, vacc1x7x0123456789ABCDEF); + __m512i vacc0x0123456789ABCDEF = _mm512_add_epi32(vacc0x0x0123456789ABCDEF, vacc1x0x0123456789ABCDEF); + __m512i vacc1x0123456789ABCDEF = _mm512_add_epi32(vacc0x1x0123456789ABCDEF, vacc1x1x0123456789ABCDEF); + __m512i vacc2x0123456789ABCDEF = _mm512_add_epi32(vacc0x2x0123456789ABCDEF, vacc1x2x0123456789ABCDEF); + __m512i vacc3x0123456789ABCDEF = _mm512_add_epi32(vacc0x3x0123456789ABCDEF, vacc1x3x0123456789ABCDEF); + __m512i vacc4x0123456789ABCDEF = _mm512_add_epi32(vacc0x4x0123456789ABCDEF, vacc1x4x0123456789ABCDEF); + __m512i vacc5x0123456789ABCDEF = _mm512_add_epi32(vacc0x5x0123456789ABCDEF, vacc1x5x0123456789ABCDEF); + __m512i vacc6x0123456789ABCDEF = _mm512_add_epi32(vacc0x6x0123456789ABCDEF, vacc1x6x0123456789ABCDEF); + __m512i vacc7x0123456789ABCDEF = _mm512_add_epi32(vacc0x7x0123456789ABCDEF, vacc1x7x0123456789ABCDEF); if (k != 0) { const __m512i va0x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a0)), vsign_mask); @@ -199,27 +199,27 @@ void xnn_qd8_f32_qc8w_gemm_minmax_ukernel_8x16c4__avx512vnni_prfm( const __m512i vb0123456789ABCDEF = _mm512_load_si512(w); xnn_prefetch_to_l1((const int8_t*) w + 960); - vacc0x0x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x0x0123456789ABCDEF, va0x0123, vb0123456789ABCDEF); - vacc0x1x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x1x0123456789ABCDEF, va1x0123, vb0123456789ABCDEF); - vacc0x2x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x2x0123456789ABCDEF, va2x0123, vb0123456789ABCDEF); - vacc0x3x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x3x0123456789ABCDEF, va3x0123, vb0123456789ABCDEF); - vacc0x4x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x4x0123456789ABCDEF, va4x0123, vb0123456789ABCDEF); - vacc0x5x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x5x0123456789ABCDEF, va5x0123, vb0123456789ABCDEF); - vacc0x6x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x6x0123456789ABCDEF, va6x0123, vb0123456789ABCDEF); - vacc0x7x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x7x0123456789ABCDEF, va7x0123, vb0123456789ABCDEF); + vacc0x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x0123456789ABCDEF, va0x0123, vb0123456789ABCDEF); + vacc1x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc1x0123456789ABCDEF, va1x0123, vb0123456789ABCDEF); + vacc2x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc2x0123456789ABCDEF, va2x0123, vb0123456789ABCDEF); + vacc3x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc3x0123456789ABCDEF, va3x0123, vb0123456789ABCDEF); + vacc4x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc4x0123456789ABCDEF, va4x0123, vb0123456789ABCDEF); + vacc5x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc5x0123456789ABCDEF, va5x0123, vb0123456789ABCDEF); + vacc6x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc6x0123456789ABCDEF, va6x0123, vb0123456789ABCDEF); + vacc7x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc7x0123456789ABCDEF, va7x0123, vb0123456789ABCDEF); w = (const int8_t*) w + 64; k -= 4 * sizeof(int8_t); } - __m512 vscaled0x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x0x0123456789ABCDEF); - __m512 vscaled1x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x1x0123456789ABCDEF); - __m512 vscaled2x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x2x0123456789ABCDEF); - __m512 vscaled3x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x3x0123456789ABCDEF); - __m512 vscaled4x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x4x0123456789ABCDEF); - __m512 vscaled5x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x5x0123456789ABCDEF); - __m512 vscaled6x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x6x0123456789ABCDEF); - __m512 vscaled7x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x7x0123456789ABCDEF); + __m512 vscaled0x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x0123456789ABCDEF); + __m512 vscaled1x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc1x0123456789ABCDEF); + __m512 vscaled2x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc2x0123456789ABCDEF); + __m512 vscaled3x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc3x0123456789ABCDEF); + __m512 vscaled4x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc4x0123456789ABCDEF); + __m512 vscaled5x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc5x0123456789ABCDEF); + __m512 vscaled6x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc6x0123456789ABCDEF); + __m512 vscaled7x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc7x0123456789ABCDEF); vscaled0x0123456789ABCDEF = _mm512_mul_ps(vscaled0x0123456789ABCDEF, _mm512_set1_ps(quantization_params[0].inv_scale)); vscaled1x0123456789ABCDEF = _mm512_mul_ps(vscaled1x0123456789ABCDEF, _mm512_set1_ps(quantization_params[1].inv_scale)); diff --git a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-8x16c4-minmax-avx512vnni.c b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-8x16c4-minmax-avx512vnni.c index b4998b507cc..37e0f545482 100644 --- a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-8x16c4-minmax-avx512vnni.c +++ b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-8x16c4-minmax-avx512vnni.c @@ -166,14 +166,14 @@ void xnn_qd8_f32_qc8w_gemm_minmax_ukernel_8x16c4__avx512vnni( w = (const int8_t*) w + 128; k -= 8 * sizeof(int8_t); } - vacc0x0x0123456789ABCDEF = _mm512_add_epi32(vacc0x0x0123456789ABCDEF, vacc1x0x0123456789ABCDEF); - vacc0x1x0123456789ABCDEF = _mm512_add_epi32(vacc0x1x0123456789ABCDEF, vacc1x1x0123456789ABCDEF); - vacc0x2x0123456789ABCDEF = _mm512_add_epi32(vacc0x2x0123456789ABCDEF, vacc1x2x0123456789ABCDEF); - vacc0x3x0123456789ABCDEF = _mm512_add_epi32(vacc0x3x0123456789ABCDEF, vacc1x3x0123456789ABCDEF); - vacc0x4x0123456789ABCDEF = _mm512_add_epi32(vacc0x4x0123456789ABCDEF, vacc1x4x0123456789ABCDEF); - vacc0x5x0123456789ABCDEF = _mm512_add_epi32(vacc0x5x0123456789ABCDEF, vacc1x5x0123456789ABCDEF); - vacc0x6x0123456789ABCDEF = _mm512_add_epi32(vacc0x6x0123456789ABCDEF, vacc1x6x0123456789ABCDEF); - vacc0x7x0123456789ABCDEF = _mm512_add_epi32(vacc0x7x0123456789ABCDEF, vacc1x7x0123456789ABCDEF); + __m512i vacc0x0123456789ABCDEF = _mm512_add_epi32(vacc0x0x0123456789ABCDEF, vacc1x0x0123456789ABCDEF); + __m512i vacc1x0123456789ABCDEF = _mm512_add_epi32(vacc0x1x0123456789ABCDEF, vacc1x1x0123456789ABCDEF); + __m512i vacc2x0123456789ABCDEF = _mm512_add_epi32(vacc0x2x0123456789ABCDEF, vacc1x2x0123456789ABCDEF); + __m512i vacc3x0123456789ABCDEF = _mm512_add_epi32(vacc0x3x0123456789ABCDEF, vacc1x3x0123456789ABCDEF); + __m512i vacc4x0123456789ABCDEF = _mm512_add_epi32(vacc0x4x0123456789ABCDEF, vacc1x4x0123456789ABCDEF); + __m512i vacc5x0123456789ABCDEF = _mm512_add_epi32(vacc0x5x0123456789ABCDEF, vacc1x5x0123456789ABCDEF); + __m512i vacc6x0123456789ABCDEF = _mm512_add_epi32(vacc0x6x0123456789ABCDEF, vacc1x6x0123456789ABCDEF); + __m512i vacc7x0123456789ABCDEF = _mm512_add_epi32(vacc0x7x0123456789ABCDEF, vacc1x7x0123456789ABCDEF); if (k != 0) { const __m512i va0x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a0)), vsign_mask); @@ -195,27 +195,27 @@ void xnn_qd8_f32_qc8w_gemm_minmax_ukernel_8x16c4__avx512vnni( const __m512i vb0123456789ABCDEF = _mm512_load_si512(w); - vacc0x0x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x0x0123456789ABCDEF, va0x0123, vb0123456789ABCDEF); - vacc0x1x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x1x0123456789ABCDEF, va1x0123, vb0123456789ABCDEF); - vacc0x2x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x2x0123456789ABCDEF, va2x0123, vb0123456789ABCDEF); - vacc0x3x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x3x0123456789ABCDEF, va3x0123, vb0123456789ABCDEF); - vacc0x4x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x4x0123456789ABCDEF, va4x0123, vb0123456789ABCDEF); - vacc0x5x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x5x0123456789ABCDEF, va5x0123, vb0123456789ABCDEF); - vacc0x6x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x6x0123456789ABCDEF, va6x0123, vb0123456789ABCDEF); - vacc0x7x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x7x0123456789ABCDEF, va7x0123, vb0123456789ABCDEF); + vacc0x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x0123456789ABCDEF, va0x0123, vb0123456789ABCDEF); + vacc1x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc1x0123456789ABCDEF, va1x0123, vb0123456789ABCDEF); + vacc2x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc2x0123456789ABCDEF, va2x0123, vb0123456789ABCDEF); + vacc3x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc3x0123456789ABCDEF, va3x0123, vb0123456789ABCDEF); + vacc4x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc4x0123456789ABCDEF, va4x0123, vb0123456789ABCDEF); + vacc5x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc5x0123456789ABCDEF, va5x0123, vb0123456789ABCDEF); + vacc6x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc6x0123456789ABCDEF, va6x0123, vb0123456789ABCDEF); + vacc7x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc7x0123456789ABCDEF, va7x0123, vb0123456789ABCDEF); w = (const int8_t*) w + 64; k -= 4 * sizeof(int8_t); } - __m512 vscaled0x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x0x0123456789ABCDEF); - __m512 vscaled1x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x1x0123456789ABCDEF); - __m512 vscaled2x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x2x0123456789ABCDEF); - __m512 vscaled3x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x3x0123456789ABCDEF); - __m512 vscaled4x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x4x0123456789ABCDEF); - __m512 vscaled5x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x5x0123456789ABCDEF); - __m512 vscaled6x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x6x0123456789ABCDEF); - __m512 vscaled7x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x7x0123456789ABCDEF); + __m512 vscaled0x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x0123456789ABCDEF); + __m512 vscaled1x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc1x0123456789ABCDEF); + __m512 vscaled2x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc2x0123456789ABCDEF); + __m512 vscaled3x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc3x0123456789ABCDEF); + __m512 vscaled4x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc4x0123456789ABCDEF); + __m512 vscaled5x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc5x0123456789ABCDEF); + __m512 vscaled6x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc6x0123456789ABCDEF); + __m512 vscaled7x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc7x0123456789ABCDEF); vscaled0x0123456789ABCDEF = _mm512_mul_ps(vscaled0x0123456789ABCDEF, _mm512_set1_ps(quantization_params[0].inv_scale)); vscaled1x0123456789ABCDEF = _mm512_mul_ps(vscaled1x0123456789ABCDEF, _mm512_set1_ps(quantization_params[1].inv_scale)); diff --git a/src/qs8-gemm/MRx16c4-avx512vnni.c.in b/src/qs8-gemm/MRx16c4-avx512vnni.c.in index 2711ae7ee79..332325fa12b 100644 --- a/src/qs8-gemm/MRx16c4-avx512vnni.c.in +++ b/src/qs8-gemm/MRx16c4-avx512vnni.c.in @@ -80,7 +80,7 @@ void xnn_${DATATYPE_SPEC}_gemm_minmax${REQUANTIZATION_SPEC}_ukernel_${MR}x16c4__ const __m512 voutput_max = _mm512_set1_ps(params->avx512vnni.max); const __m512i vsign_mask = _mm512_set1_epi8(params->avx512vnni.sign_mask); // 0x80 $if DATATYPE == "QC4": - const __m512i vvalue_mask = _mm512_set1_epi8(params->avx512vnni.mask); // 0xF0 + const __m512i vmask = _mm512_set1_epi8(params->avx512vnni.mask); // 0xF0 $else: const __m512i vsign_mask =_mm512_set1_epi8(params->${PARAMS_STRUCT}.sign_mask); // 0x80 $if DATATYPE != "QC8": @@ -88,9 +88,8 @@ void xnn_${DATATYPE_SPEC}_gemm_minmax${REQUANTIZATION_SPEC}_ukernel_${MR}x16c4__ const __m512 voutput_max_less_zero_point = _mm512_load_ps(params->${PARAMS_STRUCT}.output_max_less_zero_point); const __m256i voutput_zero_point = _mm256_load_si256((const __m256i*) params->${PARAMS_STRUCT}.output_zero_point); const __m128i voutput_min = _mm_load_si128((const __m128i*) params->${PARAMS_STRUCT}.output_min); - const __m128i vshuffle_control_mask = _mm_loadu_si128((const __m128i*) params->${PARAMS_STRUCT}.shuffle_control_mask); - $if DATATYPE == "QU8": - const __m512i vb_zero_point = _mm512_load_si512(params->${PARAMS_STRUCT}.kernel_zero_point); + $if DATATYPE == "QU8": + const __m512i vb_zero_point = _mm512_load_si512(params->${PARAMS_STRUCT}.kernel_zero_point); do { $if DATATYPE in ["QD8", "QC4"]: const __m512i vksum0123456789ABCDEF = _mm512_load_epi32(w); @@ -115,8 +114,8 @@ void xnn_${DATATYPE_SPEC}_gemm_minmax${REQUANTIZATION_SPEC}_ukernel_${MR}x16c4__ $if DATATYPE == "QC4": const __m512i vbb0123456789ABCDEFx01234567 = _mm512_load_si512(w); const __m512i vbs0123456789ABCDEFx0123 = _mm512_slli_epi32(vbb0123456789ABCDEFx01234567, 4); - const __m512i vb0123456789ABCDEFx4567 = _mm512_and_si512(vbb0123456789ABCDEFx01234567, vvalue_mask); - const __m512i vb0123456789ABCDEFx0123 = _mm512_and_si512(vbs0123456789ABCDEFx0123, vvalue_mask); + const __m512i vb0123456789ABCDEFx4567 = _mm512_and_si512(vbb0123456789ABCDEFx01234567, vmask); + const __m512i vb0123456789ABCDEFx0123 = _mm512_and_si512(vbs0123456789ABCDEFx0123, vmask); $else: const __m512i vb0123456789ABCDEFx0123 = _mm512_load_si512(w); const __m512i vb0123456789ABCDEFx4567 = _mm512_load_si512((const ${XINT8_T}*) w + 64); @@ -137,7 +136,7 @@ void xnn_${DATATYPE_SPEC}_gemm_minmax${REQUANTIZATION_SPEC}_ukernel_${MR}x16c4__ k -= 8 * sizeof(${XINT8_T}); } $for M in range(MR): - vacc0x${M}x0123456789ABCDEF = _mm512_add_epi32(vacc0x${M}x0123456789ABCDEF, vacc1x${M}x0123456789ABCDEF); + __m512i vacc${M}x0123456789ABCDEF = _mm512_add_epi32(vacc0x${M}x0123456789ABCDEF, vacc1x${M}x0123456789ABCDEF); if (k != 0) { $for M in range(MR): @@ -147,14 +146,14 @@ void xnn_${DATATYPE_SPEC}_gemm_minmax${REQUANTIZATION_SPEC}_ukernel_${MR}x16c4__ $if DATATYPE == "QC4": const __m512i vbb0123456789ABCDEF = _mm512_load_si512(w); const __m512i vbs0123456789ABCDEF = _mm512_slli_epi32(vbb0123456789ABCDEF, 4); - const __m512i vb0123456789ABCDEF = _mm512_and_si512(vbs0123456789ABCDEF, vvalue_mask); + const __m512i vb0123456789ABCDEF = _mm512_and_si512(vbs0123456789ABCDEF, vmask); $else: const __m512i vb0123456789ABCDEF = _mm512_load_si512(w); $if PREFETCH: xnn_prefetch_to_l1((const ${XINT8_T}*) w + 960); $for M in range(MR): - vacc0x${M}x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x${M}x0123456789ABCDEF, va${M}x0123, vb0123456789ABCDEF); + vacc${M}x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc${M}x0123456789ABCDEF, va${M}x0123, vb0123456789ABCDEF); w = (const ${XINT8_T}*) w + 64; k -= 4 * sizeof(${XINT8_T}); @@ -162,9 +161,9 @@ void xnn_${DATATYPE_SPEC}_gemm_minmax${REQUANTIZATION_SPEC}_ukernel_${MR}x16c4__ $if DATATYPE == "QC4": $for M in range(MR): - vacc0x${M}x0123456789ABCDEF = _mm512_srai_epi32(vacc0x${M}x0123456789ABCDEF, 4); + vacc${M}x0123456789ABCDEF = _mm512_srai_epi32(vacc${M}x0123456789ABCDEF, 4); $for M in range(MR): - __m512 vscaled${M}x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x${M}x0123456789ABCDEF); + __m512 vscaled${M}x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc${M}x0123456789ABCDEF); $if DATATYPE in ["QD8", "QC4"]: $for M in range(MR): @@ -215,19 +214,19 @@ void xnn_${DATATYPE_SPEC}_gemm_minmax${REQUANTIZATION_SPEC}_ukernel_${MR}x16c4__ vscaled${M}x0123456789ABCDEF = _mm512_min_ps(vscaled${M}x0123456789ABCDEF, voutput_max_less_zero_point); $for M in range(MR): - vacc0x${M}x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled${M}x0123456789ABCDEF); + vacc${M}x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled${M}x0123456789ABCDEF); $for M in range(MR): - __m256i vacc0x${M}x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x${M}x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x${M}x0123456789ABCDEF, 1)); + __m256i vacc${M}x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc${M}x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc${M}x0123456789ABCDEF, 1)); $for M in range(MR): - vacc0x${M}x012389AB4567CDEF = _mm256_adds_epi16(vacc0x${M}x012389AB4567CDEF, voutput_zero_point); + vacc${M}x012389AB4567CDEF = _mm256_adds_epi16(vacc${M}x012389AB4567CDEF, voutput_zero_point); $for M in range(MR): - const __m128i vout${M}x012389AB4567CDEF = ${_MM_PACKXS_EPI16}(_mm256_castsi256_si128(vacc0x${M}x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x${M}x012389AB4567CDEF, 1)); + const __m128i vout${M}x012389AB4567CDEF = ${_MM_PACKXS_EPI16}(_mm256_castsi256_si128(vacc${M}x012389AB4567CDEF), _mm256_extracti128_si256(vacc${M}x012389AB4567CDEF, 1)); $for M in range(MR): - __m128i vout${M}x0123456789ABCDEF = _mm_shuffle_epi8(vout${M}x012389AB4567CDEF, vshuffle_control_mask); + __m128i vout${M}x0123456789ABCDEF = _mm_shuffle_epi32(vout${M}x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); $for M in range(MR): vout${M}x0123456789ABCDEF = ${_MM_MAX_EPX8}(vout${M}x0123456789ABCDEF, voutput_min); @@ -235,12 +234,8 @@ void xnn_${DATATYPE_SPEC}_gemm_minmax${REQUANTIZATION_SPEC}_ukernel_${MR}x16c4__ if (nc >= 16) { $for M in range(MR): _mm_storeu_si128((__m128i*) c${M}, vout${M}x0123456789ABCDEF); - - $for M in range(MR): - a${M} = (const ${XINT8_T}*) ((uintptr_t) a${M} - kc); - - $for M in range(MR): c${M} = (${OUT_T}*) ((uintptr_t) c${M} + cn_stride); + a${M} = (const ${XINT8_T}*) ((uintptr_t) a${M} - kc); nc -= 16; } else { diff --git a/src/qs8-gemm/MRx16c8-avx512skx.c.in b/src/qs8-gemm/MRx16c8-avx512skx.c.in index fcd76fe0458..a5b5c801fa4 100644 --- a/src/qs8-gemm/MRx16c8-avx512skx.c.in +++ b/src/qs8-gemm/MRx16c8-avx512skx.c.in @@ -90,7 +90,6 @@ void xnn_${DATATYPE_SPEC}_gemm${GEMM_SUFFIX}_minmax${REQUANTIZATION_SPEC}_ukerne const __m128i voutput_min = _mm_load_si128((const __m128i*) params->${PARAMS_STRUCT}.output_min); $if DATATYPE == "QU8": const __m512i vb_zero_point = _mm512_load_si512(params->${PARAMS_STRUCT}.kernel_zero_point); - const __m128i vshuffle_control_mask = _mm_set_epi8(15, 14, 13, 12, 7, 6, 5, 4, 11, 10, 9, 8, 3, 2, 1, 0); do { $if DATATYPE in ["QD8", "QC4"]: @@ -256,16 +255,16 @@ void xnn_${DATATYPE_SPEC}_gemm${GEMM_SUFFIX}_minmax${REQUANTIZATION_SPEC}_ukerne vacc${M}x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled${M}x0123456789ABCDEF); $for M in range(MR): - __m256i vacc${M}x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc${M}x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc${M}x0123456789ABCDEF, 1)); + __m256i vacc${M}x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc${M}x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc${M}x0123456789ABCDEF, 1)); $for M in range(MR): - vacc${M}x0123456789AB4567CDEF = _mm256_adds_epi16(vacc${M}x0123456789AB4567CDEF, voutput_zero_point); + vacc${M}x012389AB4567CDEF = _mm256_adds_epi16(vacc${M}x012389AB4567CDEF, voutput_zero_point); $for M in range(MR): - const __m128i vout${M}x0123456789AB4567CDEF = ${_MM_PACKXS_EPI16}(_mm256_castsi256_si128(vacc${M}x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc${M}x0123456789AB4567CDEF, 1)); + const __m128i vout${M}x012389AB4567CDEF = ${_MM_PACKXS_EPI16}(_mm256_castsi256_si128(vacc${M}x012389AB4567CDEF), _mm256_extracti128_si256(vacc${M}x012389AB4567CDEF, 1)); $for M in range(MR): - __m128i vout${M}x0123456789ABCDEF = _mm_shuffle_epi8(vout${M}x0123456789AB4567CDEF, vshuffle_control_mask); + __m128i vout${M}x0123456789ABCDEF = _mm_shuffle_epi32(vout${M}x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); $for M in range(MR): vout${M}x0123456789ABCDEF = ${_MM_MAX_EPX8}(vout${M}x0123456789ABCDEF, voutput_min); @@ -273,12 +272,8 @@ void xnn_${DATATYPE_SPEC}_gemm${GEMM_SUFFIX}_minmax${REQUANTIZATION_SPEC}_ukerne if (nc >= 16) { $for M in range(MR): _mm_storeu_si128((__m128i*) c${M}, vout${M}x0123456789ABCDEF); - - $for M in range(MR): - a${M} = (const ${XINT8_T}*) ((uintptr_t) a${M} - kc); - - $for M in range(MR): c${M} = (${OUT_T}*) ((uintptr_t) c${M} + cn_stride); + a${M} = (const ${XINT8_T}*) ((uintptr_t) a${M} - kc); nc -= 16; } else { diff --git a/src/qs8-gemm/MRx16c8-avx512vnni.c.in b/src/qs8-gemm/MRx16c8-avx512vnni.c.in index 2fe8098a113..bdf60a4e9e6 100644 --- a/src/qs8-gemm/MRx16c8-avx512vnni.c.in +++ b/src/qs8-gemm/MRx16c8-avx512vnni.c.in @@ -88,9 +88,8 @@ void xnn_${DATATYPE_SPEC}_gemm_minmax${REQUANTIZATION_SPEC}_ukernel_${MR}x16c8__ const __m512 voutput_max_less_zero_point = _mm512_load_ps(params->${PARAMS_STRUCT}.output_max_less_zero_point); const __m256i voutput_zero_point = _mm256_load_si256((const __m256i*) params->${PARAMS_STRUCT}.output_zero_point); const __m128i voutput_min = _mm_load_si128((const __m128i*) params->${PARAMS_STRUCT}.output_min); - const __m128i vshuffle_control_mask = _mm_loadu_si128((const __m128i*) params->${PARAMS_STRUCT}.shuffle_control_mask); - $if DATATYPE == "QU8": - const __m512i vb_zero_point = _mm512_load_si512(params->${PARAMS_STRUCT}.kernel_zero_point); + $if DATATYPE == "QU8": + const __m512i vb_zero_point = _mm512_load_si512(params->${PARAMS_STRUCT}.kernel_zero_point); do { $if DATATYPE in ["QD8", "QC4"]: const __m512i vksum0123456789ABCDEF = _mm512_load_epi32(w); @@ -240,16 +239,16 @@ void xnn_${DATATYPE_SPEC}_gemm_minmax${REQUANTIZATION_SPEC}_ukernel_${MR}x16c8__ vacc${M}x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled${M}x0123456789ABCDEF); $for M in range(MR): - __m256i vacc${M}x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc${M}x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc${M}x0123456789ABCDEF, 1)); + __m256i vacc${M}x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc${M}x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc${M}x0123456789ABCDEF, 1)); $for M in range(MR): - vacc${M}x0123456789AB4567CDEF = _mm256_adds_epi16(vacc${M}x0123456789AB4567CDEF, voutput_zero_point); + vacc${M}x012389AB4567CDEF = _mm256_adds_epi16(vacc${M}x012389AB4567CDEF, voutput_zero_point); $for M in range(MR): - const __m128i vout${M}x0123456789AB4567CDEF = ${_MM_PACKXS_EPI16}(_mm256_castsi256_si128(vacc${M}x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc${M}x0123456789AB4567CDEF, 1)); + const __m128i vout${M}x012389AB4567CDEF = ${_MM_PACKXS_EPI16}(_mm256_castsi256_si128(vacc${M}x012389AB4567CDEF), _mm256_extracti128_si256(vacc${M}x012389AB4567CDEF, 1)); $for M in range(MR): - __m128i vout${M}x0123456789ABCDEF = _mm_shuffle_epi8(vout${M}x0123456789AB4567CDEF, vshuffle_control_mask); + __m128i vout${M}x0123456789ABCDEF = _mm_shuffle_epi32(vout${M}x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); $for M in range(MR): vout${M}x0123456789ABCDEF = ${_MM_MAX_EPX8}(vout${M}x0123456789ABCDEF, voutput_min); @@ -257,12 +256,8 @@ void xnn_${DATATYPE_SPEC}_gemm_minmax${REQUANTIZATION_SPEC}_ukernel_${MR}x16c8__ if (nc >= 16) { $for M in range(MR): _mm_storeu_si128((__m128i*) c${M}, vout${M}x0123456789ABCDEF); - - $for M in range(MR): - a${M} = (const ${XINT8_T}*) ((uintptr_t) a${M} - kc); - - $for M in range(MR): c${M} = (${OUT_T}*) ((uintptr_t) c${M} + cn_stride); + a${M} = (const ${XINT8_T}*) ((uintptr_t) a${M} - kc); nc -= 16; } else { diff --git a/src/qs8-gemm/gen/qs8-gemm-1x16c4-minmax-fp32-avx512vnni-prfm.c b/src/qs8-gemm/gen/qs8-gemm-1x16c4-minmax-fp32-avx512vnni-prfm.c index ad48933af54..87d24cfe8a6 100644 --- a/src/qs8-gemm/gen/qs8-gemm-1x16c4-minmax-fp32-avx512vnni-prfm.c +++ b/src/qs8-gemm/gen/qs8-gemm-1x16c4-minmax-fp32-avx512vnni-prfm.c @@ -48,7 +48,6 @@ void xnn_qs8_gemm_minmax_fp32_ukernel_1x16c4__avx512vnni_prfm( const __m512 voutput_max_less_zero_point = _mm512_load_ps(params->fp32_avx512vnni.output_max_less_zero_point); const __m256i voutput_zero_point = _mm256_load_si256((const __m256i*) params->fp32_avx512vnni.output_zero_point); const __m128i voutput_min = _mm_load_si128((const __m128i*) params->fp32_avx512vnni.output_min); - const __m128i vshuffle_control_mask = _mm_loadu_si128((const __m128i*) params->fp32_avx512vnni.shuffle_control_mask); do { __m512i vacc0x0x0123456789ABCDEF = _mm512_load_epi32(w); __m512i vacc1x0x0123456789ABCDEF = _mm512_setzero_epi32(); @@ -71,7 +70,7 @@ void xnn_qs8_gemm_minmax_fp32_ukernel_1x16c4__avx512vnni_prfm( w = (const int8_t*) w + 128; k -= 8 * sizeof(int8_t); } - vacc0x0x0123456789ABCDEF = _mm512_add_epi32(vacc0x0x0123456789ABCDEF, vacc1x0x0123456789ABCDEF); + __m512i vacc0x0123456789ABCDEF = _mm512_add_epi32(vacc0x0x0123456789ABCDEF, vacc1x0x0123456789ABCDEF); if (k != 0) { const __m512i va0x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a0)), vsign_mask); @@ -80,36 +79,34 @@ void xnn_qs8_gemm_minmax_fp32_ukernel_1x16c4__avx512vnni_prfm( const __m512i vb0123456789ABCDEF = _mm512_load_si512(w); xnn_prefetch_to_l1((const int8_t*) w + 960); - vacc0x0x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x0x0123456789ABCDEF, va0x0123, vb0123456789ABCDEF); + vacc0x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x0123456789ABCDEF, va0x0123, vb0123456789ABCDEF); w = (const int8_t*) w + 64; k -= 4 * sizeof(int8_t); } - __m512 vscaled0x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x0x0123456789ABCDEF); + __m512 vscaled0x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x0123456789ABCDEF); vscaled0x0123456789ABCDEF = _mm512_mul_ps(vscaled0x0123456789ABCDEF, vscale); vscaled0x0123456789ABCDEF = _mm512_min_ps(vscaled0x0123456789ABCDEF, voutput_max_less_zero_point); - vacc0x0x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled0x0123456789ABCDEF); + vacc0x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled0x0123456789ABCDEF); - __m256i vacc0x0x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0x0123456789ABCDEF, 1)); + __m256i vacc0x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); - vacc0x0x012389AB4567CDEF = _mm256_adds_epi16(vacc0x0x012389AB4567CDEF, voutput_zero_point); + vacc0x012389AB4567CDEF = _mm256_adds_epi16(vacc0x012389AB4567CDEF, voutput_zero_point); - const __m128i vout0x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x0x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x0x012389AB4567CDEF, 1)); + const __m128i vout0x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x012389AB4567CDEF, 1)); - __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi8(vout0x012389AB4567CDEF, vshuffle_control_mask); + __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi32(vout0x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); vout0x0123456789ABCDEF = _mm_max_epi8(vout0x0123456789ABCDEF, voutput_min); if (nc >= 16) { _mm_storeu_si128((__m128i*) c0, vout0x0123456789ABCDEF); - - a0 = (const int8_t*) ((uintptr_t) a0 - kc); - c0 = (int8_t*) ((uintptr_t) c0 + cn_stride); + a0 = (const int8_t*) ((uintptr_t) a0 - kc); nc -= 16; } else { diff --git a/src/qs8-gemm/gen/qs8-gemm-1x16c4-minmax-fp32-avx512vnni.c b/src/qs8-gemm/gen/qs8-gemm-1x16c4-minmax-fp32-avx512vnni.c index c7ca3eafbff..f2319c8d584 100644 --- a/src/qs8-gemm/gen/qs8-gemm-1x16c4-minmax-fp32-avx512vnni.c +++ b/src/qs8-gemm/gen/qs8-gemm-1x16c4-minmax-fp32-avx512vnni.c @@ -47,7 +47,6 @@ void xnn_qs8_gemm_minmax_fp32_ukernel_1x16c4__avx512vnni( const __m512 voutput_max_less_zero_point = _mm512_load_ps(params->fp32_avx512vnni.output_max_less_zero_point); const __m256i voutput_zero_point = _mm256_load_si256((const __m256i*) params->fp32_avx512vnni.output_zero_point); const __m128i voutput_min = _mm_load_si128((const __m128i*) params->fp32_avx512vnni.output_min); - const __m128i vshuffle_control_mask = _mm_loadu_si128((const __m128i*) params->fp32_avx512vnni.shuffle_control_mask); do { __m512i vacc0x0x0123456789ABCDEF = _mm512_load_epi32(w); __m512i vacc1x0x0123456789ABCDEF = _mm512_setzero_epi32(); @@ -68,7 +67,7 @@ void xnn_qs8_gemm_minmax_fp32_ukernel_1x16c4__avx512vnni( w = (const int8_t*) w + 128; k -= 8 * sizeof(int8_t); } - vacc0x0x0123456789ABCDEF = _mm512_add_epi32(vacc0x0x0123456789ABCDEF, vacc1x0x0123456789ABCDEF); + __m512i vacc0x0123456789ABCDEF = _mm512_add_epi32(vacc0x0x0123456789ABCDEF, vacc1x0x0123456789ABCDEF); if (k != 0) { const __m512i va0x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a0)), vsign_mask); @@ -76,36 +75,34 @@ void xnn_qs8_gemm_minmax_fp32_ukernel_1x16c4__avx512vnni( const __m512i vb0123456789ABCDEF = _mm512_load_si512(w); - vacc0x0x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x0x0123456789ABCDEF, va0x0123, vb0123456789ABCDEF); + vacc0x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x0123456789ABCDEF, va0x0123, vb0123456789ABCDEF); w = (const int8_t*) w + 64; k -= 4 * sizeof(int8_t); } - __m512 vscaled0x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x0x0123456789ABCDEF); + __m512 vscaled0x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x0123456789ABCDEF); vscaled0x0123456789ABCDEF = _mm512_mul_ps(vscaled0x0123456789ABCDEF, vscale); vscaled0x0123456789ABCDEF = _mm512_min_ps(vscaled0x0123456789ABCDEF, voutput_max_less_zero_point); - vacc0x0x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled0x0123456789ABCDEF); + vacc0x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled0x0123456789ABCDEF); - __m256i vacc0x0x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0x0123456789ABCDEF, 1)); + __m256i vacc0x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); - vacc0x0x012389AB4567CDEF = _mm256_adds_epi16(vacc0x0x012389AB4567CDEF, voutput_zero_point); + vacc0x012389AB4567CDEF = _mm256_adds_epi16(vacc0x012389AB4567CDEF, voutput_zero_point); - const __m128i vout0x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x0x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x0x012389AB4567CDEF, 1)); + const __m128i vout0x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x012389AB4567CDEF, 1)); - __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi8(vout0x012389AB4567CDEF, vshuffle_control_mask); + __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi32(vout0x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); vout0x0123456789ABCDEF = _mm_max_epi8(vout0x0123456789ABCDEF, voutput_min); if (nc >= 16) { _mm_storeu_si128((__m128i*) c0, vout0x0123456789ABCDEF); - - a0 = (const int8_t*) ((uintptr_t) a0 - kc); - c0 = (int8_t*) ((uintptr_t) c0 + cn_stride); + a0 = (const int8_t*) ((uintptr_t) a0 - kc); nc -= 16; } else { diff --git a/src/qs8-gemm/gen/qs8-gemm-1x16c8-minmax-fp32-avx512skx-prfm.c b/src/qs8-gemm/gen/qs8-gemm-1x16c8-minmax-fp32-avx512skx-prfm.c index 5b4ad46feac..7c8fbaba509 100644 --- a/src/qs8-gemm/gen/qs8-gemm-1x16c8-minmax-fp32-avx512skx-prfm.c +++ b/src/qs8-gemm/gen/qs8-gemm-1x16c8-minmax-fp32-avx512skx-prfm.c @@ -46,7 +46,6 @@ void xnn_qs8_gemm_minmax_fp32_ukernel_1x16c8__avx512skx_prfm( const __m512 voutput_max_less_zero_point = _mm512_load_ps(params->fp32_avx512.output_max_less_zero_point); const __m256i voutput_zero_point = _mm256_load_si256((const __m256i*) params->fp32_avx512.output_zero_point); const __m128i voutput_min = _mm_load_si128((const __m128i*) params->fp32_avx512.output_min); - const __m128i vshuffle_control_mask = _mm_set_epi8(15, 14, 13, 12, 7, 6, 5, 4, 11, 10, 9, 8, 3, 2, 1, 0); do { __m512i vacc0x0123 = _mm512_maskz_expandloadu_epi32(vbias_mask, w); @@ -97,22 +96,20 @@ void xnn_qs8_gemm_minmax_fp32_ukernel_1x16c8__avx512skx_prfm( vacc0x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled0x0123456789ABCDEF); - __m256i vacc0x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); + __m256i vacc0x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); - vacc0x0123456789AB4567CDEF = _mm256_adds_epi16(vacc0x0123456789AB4567CDEF, voutput_zero_point); + vacc0x012389AB4567CDEF = _mm256_adds_epi16(vacc0x012389AB4567CDEF, voutput_zero_point); - const __m128i vout0x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc0x0123456789AB4567CDEF, 1)); + const __m128i vout0x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x012389AB4567CDEF, 1)); - __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi8(vout0x0123456789AB4567CDEF, vshuffle_control_mask); + __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi32(vout0x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); vout0x0123456789ABCDEF = _mm_max_epi8(vout0x0123456789ABCDEF, voutput_min); if (nc >= 16) { _mm_storeu_si128((__m128i*) c0, vout0x0123456789ABCDEF); - - a0 = (const int8_t*) ((uintptr_t) a0 - kc); - c0 = (int8_t*) ((uintptr_t) c0 + cn_stride); + a0 = (const int8_t*) ((uintptr_t) a0 - kc); nc -= 16; } else { diff --git a/src/qs8-gemm/gen/qs8-gemm-1x16c8-minmax-fp32-avx512skx.c b/src/qs8-gemm/gen/qs8-gemm-1x16c8-minmax-fp32-avx512skx.c index 77e1e4170db..d75ee96808f 100644 --- a/src/qs8-gemm/gen/qs8-gemm-1x16c8-minmax-fp32-avx512skx.c +++ b/src/qs8-gemm/gen/qs8-gemm-1x16c8-minmax-fp32-avx512skx.c @@ -45,7 +45,6 @@ void xnn_qs8_gemm_minmax_fp32_ukernel_1x16c8__avx512skx( const __m512 voutput_max_less_zero_point = _mm512_load_ps(params->fp32_avx512.output_max_less_zero_point); const __m256i voutput_zero_point = _mm256_load_si256((const __m256i*) params->fp32_avx512.output_zero_point); const __m128i voutput_min = _mm_load_si128((const __m128i*) params->fp32_avx512.output_min); - const __m128i vshuffle_control_mask = _mm_set_epi8(15, 14, 13, 12, 7, 6, 5, 4, 11, 10, 9, 8, 3, 2, 1, 0); do { __m512i vacc0x0123 = _mm512_maskz_expandloadu_epi32(vbias_mask, w); @@ -94,22 +93,20 @@ void xnn_qs8_gemm_minmax_fp32_ukernel_1x16c8__avx512skx( vacc0x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled0x0123456789ABCDEF); - __m256i vacc0x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); + __m256i vacc0x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); - vacc0x0123456789AB4567CDEF = _mm256_adds_epi16(vacc0x0123456789AB4567CDEF, voutput_zero_point); + vacc0x012389AB4567CDEF = _mm256_adds_epi16(vacc0x012389AB4567CDEF, voutput_zero_point); - const __m128i vout0x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc0x0123456789AB4567CDEF, 1)); + const __m128i vout0x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x012389AB4567CDEF, 1)); - __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi8(vout0x0123456789AB4567CDEF, vshuffle_control_mask); + __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi32(vout0x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); vout0x0123456789ABCDEF = _mm_max_epi8(vout0x0123456789ABCDEF, voutput_min); if (nc >= 16) { _mm_storeu_si128((__m128i*) c0, vout0x0123456789ABCDEF); - - a0 = (const int8_t*) ((uintptr_t) a0 - kc); - c0 = (int8_t*) ((uintptr_t) c0 + cn_stride); + a0 = (const int8_t*) ((uintptr_t) a0 - kc); nc -= 16; } else { diff --git a/src/qs8-gemm/gen/qs8-gemm-1x16c8-minmax-fp32-avx512vnni-prfm.c b/src/qs8-gemm/gen/qs8-gemm-1x16c8-minmax-fp32-avx512vnni-prfm.c index 78af81ef20f..27ee2300beb 100644 --- a/src/qs8-gemm/gen/qs8-gemm-1x16c8-minmax-fp32-avx512vnni-prfm.c +++ b/src/qs8-gemm/gen/qs8-gemm-1x16c8-minmax-fp32-avx512vnni-prfm.c @@ -48,7 +48,6 @@ void xnn_qs8_gemm_minmax_fp32_ukernel_1x16c8__avx512vnni_prfm( const __m512 voutput_max_less_zero_point = _mm512_load_ps(params->fp32_avx512vnni.output_max_less_zero_point); const __m256i voutput_zero_point = _mm256_load_si256((const __m256i*) params->fp32_avx512vnni.output_zero_point); const __m128i voutput_min = _mm_load_si128((const __m128i*) params->fp32_avx512vnni.output_min); - const __m128i vshuffle_control_mask = _mm_loadu_si128((const __m128i*) params->fp32_avx512vnni.shuffle_control_mask); do { __m512i vacc0x01234567 = _mm512_cvtepu32_epi64(_mm256_load_si256((const __m256i*) w)); __m512i vacc0x89ABCDEF = _mm512_cvtepu32_epi64(_mm256_load_si256((const __m256i*) ((const int32_t*) w + 8))); @@ -108,22 +107,20 @@ void xnn_qs8_gemm_minmax_fp32_ukernel_1x16c8__avx512vnni_prfm( vacc0x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled0x0123456789ABCDEF); - __m256i vacc0x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); + __m256i vacc0x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); - vacc0x0123456789AB4567CDEF = _mm256_adds_epi16(vacc0x0123456789AB4567CDEF, voutput_zero_point); + vacc0x012389AB4567CDEF = _mm256_adds_epi16(vacc0x012389AB4567CDEF, voutput_zero_point); - const __m128i vout0x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc0x0123456789AB4567CDEF, 1)); + const __m128i vout0x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x012389AB4567CDEF, 1)); - __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi8(vout0x0123456789AB4567CDEF, vshuffle_control_mask); + __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi32(vout0x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); vout0x0123456789ABCDEF = _mm_max_epi8(vout0x0123456789ABCDEF, voutput_min); if (nc >= 16) { _mm_storeu_si128((__m128i*) c0, vout0x0123456789ABCDEF); - - a0 = (const int8_t*) ((uintptr_t) a0 - kc); - c0 = (int8_t*) ((uintptr_t) c0 + cn_stride); + a0 = (const int8_t*) ((uintptr_t) a0 - kc); nc -= 16; } else { diff --git a/src/qs8-gemm/gen/qs8-gemm-1x16c8-minmax-fp32-avx512vnni.c b/src/qs8-gemm/gen/qs8-gemm-1x16c8-minmax-fp32-avx512vnni.c index ffe373e5acc..6021ea48f6d 100644 --- a/src/qs8-gemm/gen/qs8-gemm-1x16c8-minmax-fp32-avx512vnni.c +++ b/src/qs8-gemm/gen/qs8-gemm-1x16c8-minmax-fp32-avx512vnni.c @@ -47,7 +47,6 @@ void xnn_qs8_gemm_minmax_fp32_ukernel_1x16c8__avx512vnni( const __m512 voutput_max_less_zero_point = _mm512_load_ps(params->fp32_avx512vnni.output_max_less_zero_point); const __m256i voutput_zero_point = _mm256_load_si256((const __m256i*) params->fp32_avx512vnni.output_zero_point); const __m128i voutput_min = _mm_load_si128((const __m128i*) params->fp32_avx512vnni.output_min); - const __m128i vshuffle_control_mask = _mm_loadu_si128((const __m128i*) params->fp32_avx512vnni.shuffle_control_mask); do { __m512i vacc0x01234567 = _mm512_cvtepu32_epi64(_mm256_load_si256((const __m256i*) w)); __m512i vacc0x89ABCDEF = _mm512_cvtepu32_epi64(_mm256_load_si256((const __m256i*) ((const int32_t*) w + 8))); @@ -101,22 +100,20 @@ void xnn_qs8_gemm_minmax_fp32_ukernel_1x16c8__avx512vnni( vacc0x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled0x0123456789ABCDEF); - __m256i vacc0x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); + __m256i vacc0x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); - vacc0x0123456789AB4567CDEF = _mm256_adds_epi16(vacc0x0123456789AB4567CDEF, voutput_zero_point); + vacc0x012389AB4567CDEF = _mm256_adds_epi16(vacc0x012389AB4567CDEF, voutput_zero_point); - const __m128i vout0x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc0x0123456789AB4567CDEF, 1)); + const __m128i vout0x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x012389AB4567CDEF, 1)); - __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi8(vout0x0123456789AB4567CDEF, vshuffle_control_mask); + __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi32(vout0x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); vout0x0123456789ABCDEF = _mm_max_epi8(vout0x0123456789ABCDEF, voutput_min); if (nc >= 16) { _mm_storeu_si128((__m128i*) c0, vout0x0123456789ABCDEF); - - a0 = (const int8_t*) ((uintptr_t) a0 - kc); - c0 = (int8_t*) ((uintptr_t) c0 + cn_stride); + a0 = (const int8_t*) ((uintptr_t) a0 - kc); nc -= 16; } else { diff --git a/src/qs8-gemm/gen/qs8-gemm-2x16c4-minmax-fp32-avx512vnni-prfm.c b/src/qs8-gemm/gen/qs8-gemm-2x16c4-minmax-fp32-avx512vnni-prfm.c index b6e42c45349..da33aaece61 100644 --- a/src/qs8-gemm/gen/qs8-gemm-2x16c4-minmax-fp32-avx512vnni-prfm.c +++ b/src/qs8-gemm/gen/qs8-gemm-2x16c4-minmax-fp32-avx512vnni-prfm.c @@ -54,7 +54,6 @@ void xnn_qs8_gemm_minmax_fp32_ukernel_2x16c4__avx512vnni_prfm( const __m512 voutput_max_less_zero_point = _mm512_load_ps(params->fp32_avx512vnni.output_max_less_zero_point); const __m256i voutput_zero_point = _mm256_load_si256((const __m256i*) params->fp32_avx512vnni.output_zero_point); const __m128i voutput_min = _mm_load_si128((const __m128i*) params->fp32_avx512vnni.output_min); - const __m128i vshuffle_control_mask = _mm_loadu_si128((const __m128i*) params->fp32_avx512vnni.shuffle_control_mask); do { __m512i vacc0x0x0123456789ABCDEF = _mm512_load_epi32(w); __m512i vacc1x0x0123456789ABCDEF = _mm512_setzero_epi32(); @@ -84,8 +83,8 @@ void xnn_qs8_gemm_minmax_fp32_ukernel_2x16c4__avx512vnni_prfm( w = (const int8_t*) w + 128; k -= 8 * sizeof(int8_t); } - vacc0x0x0123456789ABCDEF = _mm512_add_epi32(vacc0x0x0123456789ABCDEF, vacc1x0x0123456789ABCDEF); - vacc0x1x0123456789ABCDEF = _mm512_add_epi32(vacc0x1x0123456789ABCDEF, vacc1x1x0123456789ABCDEF); + __m512i vacc0x0123456789ABCDEF = _mm512_add_epi32(vacc0x0x0123456789ABCDEF, vacc1x0x0123456789ABCDEF); + __m512i vacc1x0123456789ABCDEF = _mm512_add_epi32(vacc0x1x0123456789ABCDEF, vacc1x1x0123456789ABCDEF); if (k != 0) { const __m512i va0x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a0)), vsign_mask); @@ -96,15 +95,15 @@ void xnn_qs8_gemm_minmax_fp32_ukernel_2x16c4__avx512vnni_prfm( const __m512i vb0123456789ABCDEF = _mm512_load_si512(w); xnn_prefetch_to_l1((const int8_t*) w + 960); - vacc0x0x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x0x0123456789ABCDEF, va0x0123, vb0123456789ABCDEF); - vacc0x1x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x1x0123456789ABCDEF, va1x0123, vb0123456789ABCDEF); + vacc0x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x0123456789ABCDEF, va0x0123, vb0123456789ABCDEF); + vacc1x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc1x0123456789ABCDEF, va1x0123, vb0123456789ABCDEF); w = (const int8_t*) w + 64; k -= 4 * sizeof(int8_t); } - __m512 vscaled0x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x0x0123456789ABCDEF); - __m512 vscaled1x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x1x0123456789ABCDEF); + __m512 vscaled0x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x0123456789ABCDEF); + __m512 vscaled1x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc1x0123456789ABCDEF); vscaled0x0123456789ABCDEF = _mm512_mul_ps(vscaled0x0123456789ABCDEF, vscale); vscaled1x0123456789ABCDEF = _mm512_mul_ps(vscaled1x0123456789ABCDEF, vscale); @@ -112,33 +111,31 @@ void xnn_qs8_gemm_minmax_fp32_ukernel_2x16c4__avx512vnni_prfm( vscaled0x0123456789ABCDEF = _mm512_min_ps(vscaled0x0123456789ABCDEF, voutput_max_less_zero_point); vscaled1x0123456789ABCDEF = _mm512_min_ps(vscaled1x0123456789ABCDEF, voutput_max_less_zero_point); - vacc0x0x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled0x0123456789ABCDEF); - vacc0x1x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled1x0123456789ABCDEF); + vacc0x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled0x0123456789ABCDEF); + vacc1x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled1x0123456789ABCDEF); - __m256i vacc0x0x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0x0123456789ABCDEF, 1)); - __m256i vacc0x1x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x1x0123456789ABCDEF, 1)); + __m256i vacc0x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); + __m256i vacc1x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); - vacc0x0x012389AB4567CDEF = _mm256_adds_epi16(vacc0x0x012389AB4567CDEF, voutput_zero_point); - vacc0x1x012389AB4567CDEF = _mm256_adds_epi16(vacc0x1x012389AB4567CDEF, voutput_zero_point); + vacc0x012389AB4567CDEF = _mm256_adds_epi16(vacc0x012389AB4567CDEF, voutput_zero_point); + vacc1x012389AB4567CDEF = _mm256_adds_epi16(vacc1x012389AB4567CDEF, voutput_zero_point); - const __m128i vout0x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x0x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x0x012389AB4567CDEF, 1)); - const __m128i vout1x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x1x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x1x012389AB4567CDEF, 1)); + const __m128i vout0x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x012389AB4567CDEF, 1)); + const __m128i vout1x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc1x012389AB4567CDEF), _mm256_extracti128_si256(vacc1x012389AB4567CDEF, 1)); - __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi8(vout0x012389AB4567CDEF, vshuffle_control_mask); - __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi8(vout1x012389AB4567CDEF, vshuffle_control_mask); + __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi32(vout0x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi32(vout1x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); vout0x0123456789ABCDEF = _mm_max_epi8(vout0x0123456789ABCDEF, voutput_min); vout1x0123456789ABCDEF = _mm_max_epi8(vout1x0123456789ABCDEF, voutput_min); if (nc >= 16) { _mm_storeu_si128((__m128i*) c0, vout0x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c1, vout1x0123456789ABCDEF); - - a0 = (const int8_t*) ((uintptr_t) a0 - kc); - a1 = (const int8_t*) ((uintptr_t) a1 - kc); - c0 = (int8_t*) ((uintptr_t) c0 + cn_stride); + a0 = (const int8_t*) ((uintptr_t) a0 - kc); + _mm_storeu_si128((__m128i*) c1, vout1x0123456789ABCDEF); c1 = (int8_t*) ((uintptr_t) c1 + cn_stride); + a1 = (const int8_t*) ((uintptr_t) a1 - kc); nc -= 16; } else { diff --git a/src/qs8-gemm/gen/qs8-gemm-2x16c4-minmax-fp32-avx512vnni.c b/src/qs8-gemm/gen/qs8-gemm-2x16c4-minmax-fp32-avx512vnni.c index 7d214c8a000..d344a7bca78 100644 --- a/src/qs8-gemm/gen/qs8-gemm-2x16c4-minmax-fp32-avx512vnni.c +++ b/src/qs8-gemm/gen/qs8-gemm-2x16c4-minmax-fp32-avx512vnni.c @@ -53,7 +53,6 @@ void xnn_qs8_gemm_minmax_fp32_ukernel_2x16c4__avx512vnni( const __m512 voutput_max_less_zero_point = _mm512_load_ps(params->fp32_avx512vnni.output_max_less_zero_point); const __m256i voutput_zero_point = _mm256_load_si256((const __m256i*) params->fp32_avx512vnni.output_zero_point); const __m128i voutput_min = _mm_load_si128((const __m128i*) params->fp32_avx512vnni.output_min); - const __m128i vshuffle_control_mask = _mm_loadu_si128((const __m128i*) params->fp32_avx512vnni.shuffle_control_mask); do { __m512i vacc0x0x0123456789ABCDEF = _mm512_load_epi32(w); __m512i vacc1x0x0123456789ABCDEF = _mm512_setzero_epi32(); @@ -81,8 +80,8 @@ void xnn_qs8_gemm_minmax_fp32_ukernel_2x16c4__avx512vnni( w = (const int8_t*) w + 128; k -= 8 * sizeof(int8_t); } - vacc0x0x0123456789ABCDEF = _mm512_add_epi32(vacc0x0x0123456789ABCDEF, vacc1x0x0123456789ABCDEF); - vacc0x1x0123456789ABCDEF = _mm512_add_epi32(vacc0x1x0123456789ABCDEF, vacc1x1x0123456789ABCDEF); + __m512i vacc0x0123456789ABCDEF = _mm512_add_epi32(vacc0x0x0123456789ABCDEF, vacc1x0x0123456789ABCDEF); + __m512i vacc1x0123456789ABCDEF = _mm512_add_epi32(vacc0x1x0123456789ABCDEF, vacc1x1x0123456789ABCDEF); if (k != 0) { const __m512i va0x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a0)), vsign_mask); @@ -92,15 +91,15 @@ void xnn_qs8_gemm_minmax_fp32_ukernel_2x16c4__avx512vnni( const __m512i vb0123456789ABCDEF = _mm512_load_si512(w); - vacc0x0x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x0x0123456789ABCDEF, va0x0123, vb0123456789ABCDEF); - vacc0x1x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x1x0123456789ABCDEF, va1x0123, vb0123456789ABCDEF); + vacc0x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x0123456789ABCDEF, va0x0123, vb0123456789ABCDEF); + vacc1x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc1x0123456789ABCDEF, va1x0123, vb0123456789ABCDEF); w = (const int8_t*) w + 64; k -= 4 * sizeof(int8_t); } - __m512 vscaled0x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x0x0123456789ABCDEF); - __m512 vscaled1x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x1x0123456789ABCDEF); + __m512 vscaled0x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x0123456789ABCDEF); + __m512 vscaled1x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc1x0123456789ABCDEF); vscaled0x0123456789ABCDEF = _mm512_mul_ps(vscaled0x0123456789ABCDEF, vscale); vscaled1x0123456789ABCDEF = _mm512_mul_ps(vscaled1x0123456789ABCDEF, vscale); @@ -108,33 +107,31 @@ void xnn_qs8_gemm_minmax_fp32_ukernel_2x16c4__avx512vnni( vscaled0x0123456789ABCDEF = _mm512_min_ps(vscaled0x0123456789ABCDEF, voutput_max_less_zero_point); vscaled1x0123456789ABCDEF = _mm512_min_ps(vscaled1x0123456789ABCDEF, voutput_max_less_zero_point); - vacc0x0x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled0x0123456789ABCDEF); - vacc0x1x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled1x0123456789ABCDEF); + vacc0x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled0x0123456789ABCDEF); + vacc1x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled1x0123456789ABCDEF); - __m256i vacc0x0x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0x0123456789ABCDEF, 1)); - __m256i vacc0x1x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x1x0123456789ABCDEF, 1)); + __m256i vacc0x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); + __m256i vacc1x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); - vacc0x0x012389AB4567CDEF = _mm256_adds_epi16(vacc0x0x012389AB4567CDEF, voutput_zero_point); - vacc0x1x012389AB4567CDEF = _mm256_adds_epi16(vacc0x1x012389AB4567CDEF, voutput_zero_point); + vacc0x012389AB4567CDEF = _mm256_adds_epi16(vacc0x012389AB4567CDEF, voutput_zero_point); + vacc1x012389AB4567CDEF = _mm256_adds_epi16(vacc1x012389AB4567CDEF, voutput_zero_point); - const __m128i vout0x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x0x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x0x012389AB4567CDEF, 1)); - const __m128i vout1x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x1x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x1x012389AB4567CDEF, 1)); + const __m128i vout0x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x012389AB4567CDEF, 1)); + const __m128i vout1x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc1x012389AB4567CDEF), _mm256_extracti128_si256(vacc1x012389AB4567CDEF, 1)); - __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi8(vout0x012389AB4567CDEF, vshuffle_control_mask); - __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi8(vout1x012389AB4567CDEF, vshuffle_control_mask); + __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi32(vout0x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi32(vout1x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); vout0x0123456789ABCDEF = _mm_max_epi8(vout0x0123456789ABCDEF, voutput_min); vout1x0123456789ABCDEF = _mm_max_epi8(vout1x0123456789ABCDEF, voutput_min); if (nc >= 16) { _mm_storeu_si128((__m128i*) c0, vout0x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c1, vout1x0123456789ABCDEF); - - a0 = (const int8_t*) ((uintptr_t) a0 - kc); - a1 = (const int8_t*) ((uintptr_t) a1 - kc); - c0 = (int8_t*) ((uintptr_t) c0 + cn_stride); + a0 = (const int8_t*) ((uintptr_t) a0 - kc); + _mm_storeu_si128((__m128i*) c1, vout1x0123456789ABCDEF); c1 = (int8_t*) ((uintptr_t) c1 + cn_stride); + a1 = (const int8_t*) ((uintptr_t) a1 - kc); nc -= 16; } else { diff --git a/src/qs8-gemm/gen/qs8-gemm-2x16c8-minmax-fp32-avx512skx-prfm.c b/src/qs8-gemm/gen/qs8-gemm-2x16c8-minmax-fp32-avx512skx-prfm.c index 3e17fcce0b5..a811c24cf96 100644 --- a/src/qs8-gemm/gen/qs8-gemm-2x16c8-minmax-fp32-avx512skx-prfm.c +++ b/src/qs8-gemm/gen/qs8-gemm-2x16c8-minmax-fp32-avx512skx-prfm.c @@ -52,7 +52,6 @@ void xnn_qs8_gemm_minmax_fp32_ukernel_2x16c8__avx512skx_prfm( const __m512 voutput_max_less_zero_point = _mm512_load_ps(params->fp32_avx512.output_max_less_zero_point); const __m256i voutput_zero_point = _mm256_load_si256((const __m256i*) params->fp32_avx512.output_zero_point); const __m128i voutput_min = _mm_load_si128((const __m128i*) params->fp32_avx512.output_min); - const __m128i vshuffle_control_mask = _mm_set_epi8(15, 14, 13, 12, 7, 6, 5, 4, 11, 10, 9, 8, 3, 2, 1, 0); do { __m512i vacc0x0123 = _mm512_maskz_expandloadu_epi32(vbias_mask, w); @@ -121,30 +120,28 @@ void xnn_qs8_gemm_minmax_fp32_ukernel_2x16c8__avx512skx_prfm( vacc0x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled0x0123456789ABCDEF); vacc1x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled1x0123456789ABCDEF); - __m256i vacc0x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); - __m256i vacc1x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); + __m256i vacc0x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); + __m256i vacc1x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); - vacc0x0123456789AB4567CDEF = _mm256_adds_epi16(vacc0x0123456789AB4567CDEF, voutput_zero_point); - vacc1x0123456789AB4567CDEF = _mm256_adds_epi16(vacc1x0123456789AB4567CDEF, voutput_zero_point); + vacc0x012389AB4567CDEF = _mm256_adds_epi16(vacc0x012389AB4567CDEF, voutput_zero_point); + vacc1x012389AB4567CDEF = _mm256_adds_epi16(vacc1x012389AB4567CDEF, voutput_zero_point); - const __m128i vout0x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc0x0123456789AB4567CDEF, 1)); - const __m128i vout1x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc1x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc1x0123456789AB4567CDEF, 1)); + const __m128i vout0x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x012389AB4567CDEF, 1)); + const __m128i vout1x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc1x012389AB4567CDEF), _mm256_extracti128_si256(vacc1x012389AB4567CDEF, 1)); - __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi8(vout0x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi8(vout1x0123456789AB4567CDEF, vshuffle_control_mask); + __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi32(vout0x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi32(vout1x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); vout0x0123456789ABCDEF = _mm_max_epi8(vout0x0123456789ABCDEF, voutput_min); vout1x0123456789ABCDEF = _mm_max_epi8(vout1x0123456789ABCDEF, voutput_min); if (nc >= 16) { _mm_storeu_si128((__m128i*) c0, vout0x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c1, vout1x0123456789ABCDEF); - - a0 = (const int8_t*) ((uintptr_t) a0 - kc); - a1 = (const int8_t*) ((uintptr_t) a1 - kc); - c0 = (int8_t*) ((uintptr_t) c0 + cn_stride); + a0 = (const int8_t*) ((uintptr_t) a0 - kc); + _mm_storeu_si128((__m128i*) c1, vout1x0123456789ABCDEF); c1 = (int8_t*) ((uintptr_t) c1 + cn_stride); + a1 = (const int8_t*) ((uintptr_t) a1 - kc); nc -= 16; } else { diff --git a/src/qs8-gemm/gen/qs8-gemm-2x16c8-minmax-fp32-avx512skx.c b/src/qs8-gemm/gen/qs8-gemm-2x16c8-minmax-fp32-avx512skx.c index 846ad724b7f..459c037af8f 100644 --- a/src/qs8-gemm/gen/qs8-gemm-2x16c8-minmax-fp32-avx512skx.c +++ b/src/qs8-gemm/gen/qs8-gemm-2x16c8-minmax-fp32-avx512skx.c @@ -51,7 +51,6 @@ void xnn_qs8_gemm_minmax_fp32_ukernel_2x16c8__avx512skx( const __m512 voutput_max_less_zero_point = _mm512_load_ps(params->fp32_avx512.output_max_less_zero_point); const __m256i voutput_zero_point = _mm256_load_si256((const __m256i*) params->fp32_avx512.output_zero_point); const __m128i voutput_min = _mm_load_si128((const __m128i*) params->fp32_avx512.output_min); - const __m128i vshuffle_control_mask = _mm_set_epi8(15, 14, 13, 12, 7, 6, 5, 4, 11, 10, 9, 8, 3, 2, 1, 0); do { __m512i vacc0x0123 = _mm512_maskz_expandloadu_epi32(vbias_mask, w); @@ -118,30 +117,28 @@ void xnn_qs8_gemm_minmax_fp32_ukernel_2x16c8__avx512skx( vacc0x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled0x0123456789ABCDEF); vacc1x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled1x0123456789ABCDEF); - __m256i vacc0x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); - __m256i vacc1x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); + __m256i vacc0x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); + __m256i vacc1x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); - vacc0x0123456789AB4567CDEF = _mm256_adds_epi16(vacc0x0123456789AB4567CDEF, voutput_zero_point); - vacc1x0123456789AB4567CDEF = _mm256_adds_epi16(vacc1x0123456789AB4567CDEF, voutput_zero_point); + vacc0x012389AB4567CDEF = _mm256_adds_epi16(vacc0x012389AB4567CDEF, voutput_zero_point); + vacc1x012389AB4567CDEF = _mm256_adds_epi16(vacc1x012389AB4567CDEF, voutput_zero_point); - const __m128i vout0x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc0x0123456789AB4567CDEF, 1)); - const __m128i vout1x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc1x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc1x0123456789AB4567CDEF, 1)); + const __m128i vout0x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x012389AB4567CDEF, 1)); + const __m128i vout1x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc1x012389AB4567CDEF), _mm256_extracti128_si256(vacc1x012389AB4567CDEF, 1)); - __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi8(vout0x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi8(vout1x0123456789AB4567CDEF, vshuffle_control_mask); + __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi32(vout0x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi32(vout1x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); vout0x0123456789ABCDEF = _mm_max_epi8(vout0x0123456789ABCDEF, voutput_min); vout1x0123456789ABCDEF = _mm_max_epi8(vout1x0123456789ABCDEF, voutput_min); if (nc >= 16) { _mm_storeu_si128((__m128i*) c0, vout0x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c1, vout1x0123456789ABCDEF); - - a0 = (const int8_t*) ((uintptr_t) a0 - kc); - a1 = (const int8_t*) ((uintptr_t) a1 - kc); - c0 = (int8_t*) ((uintptr_t) c0 + cn_stride); + a0 = (const int8_t*) ((uintptr_t) a0 - kc); + _mm_storeu_si128((__m128i*) c1, vout1x0123456789ABCDEF); c1 = (int8_t*) ((uintptr_t) c1 + cn_stride); + a1 = (const int8_t*) ((uintptr_t) a1 - kc); nc -= 16; } else { diff --git a/src/qs8-gemm/gen/qs8-gemm-2x16c8-minmax-fp32-avx512vnni-prfm.c b/src/qs8-gemm/gen/qs8-gemm-2x16c8-minmax-fp32-avx512vnni-prfm.c index 6eb2fb5f5c1..26facf1c1af 100644 --- a/src/qs8-gemm/gen/qs8-gemm-2x16c8-minmax-fp32-avx512vnni-prfm.c +++ b/src/qs8-gemm/gen/qs8-gemm-2x16c8-minmax-fp32-avx512vnni-prfm.c @@ -54,7 +54,6 @@ void xnn_qs8_gemm_minmax_fp32_ukernel_2x16c8__avx512vnni_prfm( const __m512 voutput_max_less_zero_point = _mm512_load_ps(params->fp32_avx512vnni.output_max_less_zero_point); const __m256i voutput_zero_point = _mm256_load_si256((const __m256i*) params->fp32_avx512vnni.output_zero_point); const __m128i voutput_min = _mm_load_si128((const __m128i*) params->fp32_avx512vnni.output_min); - const __m128i vshuffle_control_mask = _mm_loadu_si128((const __m128i*) params->fp32_avx512vnni.shuffle_control_mask); do { __m512i vacc0x01234567 = _mm512_cvtepu32_epi64(_mm256_load_si256((const __m256i*) w)); __m512i vacc0x89ABCDEF = _mm512_cvtepu32_epi64(_mm256_load_si256((const __m256i*) ((const int32_t*) w + 8))); @@ -134,30 +133,28 @@ void xnn_qs8_gemm_minmax_fp32_ukernel_2x16c8__avx512vnni_prfm( vacc0x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled0x0123456789ABCDEF); vacc1x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled1x0123456789ABCDEF); - __m256i vacc0x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); - __m256i vacc1x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); + __m256i vacc0x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); + __m256i vacc1x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); - vacc0x0123456789AB4567CDEF = _mm256_adds_epi16(vacc0x0123456789AB4567CDEF, voutput_zero_point); - vacc1x0123456789AB4567CDEF = _mm256_adds_epi16(vacc1x0123456789AB4567CDEF, voutput_zero_point); + vacc0x012389AB4567CDEF = _mm256_adds_epi16(vacc0x012389AB4567CDEF, voutput_zero_point); + vacc1x012389AB4567CDEF = _mm256_adds_epi16(vacc1x012389AB4567CDEF, voutput_zero_point); - const __m128i vout0x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc0x0123456789AB4567CDEF, 1)); - const __m128i vout1x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc1x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc1x0123456789AB4567CDEF, 1)); + const __m128i vout0x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x012389AB4567CDEF, 1)); + const __m128i vout1x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc1x012389AB4567CDEF), _mm256_extracti128_si256(vacc1x012389AB4567CDEF, 1)); - __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi8(vout0x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi8(vout1x0123456789AB4567CDEF, vshuffle_control_mask); + __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi32(vout0x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi32(vout1x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); vout0x0123456789ABCDEF = _mm_max_epi8(vout0x0123456789ABCDEF, voutput_min); vout1x0123456789ABCDEF = _mm_max_epi8(vout1x0123456789ABCDEF, voutput_min); if (nc >= 16) { _mm_storeu_si128((__m128i*) c0, vout0x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c1, vout1x0123456789ABCDEF); - - a0 = (const int8_t*) ((uintptr_t) a0 - kc); - a1 = (const int8_t*) ((uintptr_t) a1 - kc); - c0 = (int8_t*) ((uintptr_t) c0 + cn_stride); + a0 = (const int8_t*) ((uintptr_t) a0 - kc); + _mm_storeu_si128((__m128i*) c1, vout1x0123456789ABCDEF); c1 = (int8_t*) ((uintptr_t) c1 + cn_stride); + a1 = (const int8_t*) ((uintptr_t) a1 - kc); nc -= 16; } else { diff --git a/src/qs8-gemm/gen/qs8-gemm-2x16c8-minmax-fp32-avx512vnni.c b/src/qs8-gemm/gen/qs8-gemm-2x16c8-minmax-fp32-avx512vnni.c index 40e7a94fb50..b198397570e 100644 --- a/src/qs8-gemm/gen/qs8-gemm-2x16c8-minmax-fp32-avx512vnni.c +++ b/src/qs8-gemm/gen/qs8-gemm-2x16c8-minmax-fp32-avx512vnni.c @@ -53,7 +53,6 @@ void xnn_qs8_gemm_minmax_fp32_ukernel_2x16c8__avx512vnni( const __m512 voutput_max_less_zero_point = _mm512_load_ps(params->fp32_avx512vnni.output_max_less_zero_point); const __m256i voutput_zero_point = _mm256_load_si256((const __m256i*) params->fp32_avx512vnni.output_zero_point); const __m128i voutput_min = _mm_load_si128((const __m128i*) params->fp32_avx512vnni.output_min); - const __m128i vshuffle_control_mask = _mm_loadu_si128((const __m128i*) params->fp32_avx512vnni.shuffle_control_mask); do { __m512i vacc0x01234567 = _mm512_cvtepu32_epi64(_mm256_load_si256((const __m256i*) w)); __m512i vacc0x89ABCDEF = _mm512_cvtepu32_epi64(_mm256_load_si256((const __m256i*) ((const int32_t*) w + 8))); @@ -127,30 +126,28 @@ void xnn_qs8_gemm_minmax_fp32_ukernel_2x16c8__avx512vnni( vacc0x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled0x0123456789ABCDEF); vacc1x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled1x0123456789ABCDEF); - __m256i vacc0x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); - __m256i vacc1x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); + __m256i vacc0x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); + __m256i vacc1x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); - vacc0x0123456789AB4567CDEF = _mm256_adds_epi16(vacc0x0123456789AB4567CDEF, voutput_zero_point); - vacc1x0123456789AB4567CDEF = _mm256_adds_epi16(vacc1x0123456789AB4567CDEF, voutput_zero_point); + vacc0x012389AB4567CDEF = _mm256_adds_epi16(vacc0x012389AB4567CDEF, voutput_zero_point); + vacc1x012389AB4567CDEF = _mm256_adds_epi16(vacc1x012389AB4567CDEF, voutput_zero_point); - const __m128i vout0x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc0x0123456789AB4567CDEF, 1)); - const __m128i vout1x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc1x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc1x0123456789AB4567CDEF, 1)); + const __m128i vout0x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x012389AB4567CDEF, 1)); + const __m128i vout1x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc1x012389AB4567CDEF), _mm256_extracti128_si256(vacc1x012389AB4567CDEF, 1)); - __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi8(vout0x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi8(vout1x0123456789AB4567CDEF, vshuffle_control_mask); + __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi32(vout0x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi32(vout1x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); vout0x0123456789ABCDEF = _mm_max_epi8(vout0x0123456789ABCDEF, voutput_min); vout1x0123456789ABCDEF = _mm_max_epi8(vout1x0123456789ABCDEF, voutput_min); if (nc >= 16) { _mm_storeu_si128((__m128i*) c0, vout0x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c1, vout1x0123456789ABCDEF); - - a0 = (const int8_t*) ((uintptr_t) a0 - kc); - a1 = (const int8_t*) ((uintptr_t) a1 - kc); - c0 = (int8_t*) ((uintptr_t) c0 + cn_stride); + a0 = (const int8_t*) ((uintptr_t) a0 - kc); + _mm_storeu_si128((__m128i*) c1, vout1x0123456789ABCDEF); c1 = (int8_t*) ((uintptr_t) c1 + cn_stride); + a1 = (const int8_t*) ((uintptr_t) a1 - kc); nc -= 16; } else { diff --git a/src/qs8-gemm/gen/qs8-gemm-3x16c4-minmax-fp32-avx512vnni-prfm.c b/src/qs8-gemm/gen/qs8-gemm-3x16c4-minmax-fp32-avx512vnni-prfm.c index ad84c753de0..50d9690ddbd 100644 --- a/src/qs8-gemm/gen/qs8-gemm-3x16c4-minmax-fp32-avx512vnni-prfm.c +++ b/src/qs8-gemm/gen/qs8-gemm-3x16c4-minmax-fp32-avx512vnni-prfm.c @@ -60,7 +60,6 @@ void xnn_qs8_gemm_minmax_fp32_ukernel_3x16c4__avx512vnni_prfm( const __m512 voutput_max_less_zero_point = _mm512_load_ps(params->fp32_avx512vnni.output_max_less_zero_point); const __m256i voutput_zero_point = _mm256_load_si256((const __m256i*) params->fp32_avx512vnni.output_zero_point); const __m128i voutput_min = _mm_load_si128((const __m128i*) params->fp32_avx512vnni.output_min); - const __m128i vshuffle_control_mask = _mm_loadu_si128((const __m128i*) params->fp32_avx512vnni.shuffle_control_mask); do { __m512i vacc0x0x0123456789ABCDEF = _mm512_load_epi32(w); __m512i vacc1x0x0123456789ABCDEF = _mm512_setzero_epi32(); @@ -97,9 +96,9 @@ void xnn_qs8_gemm_minmax_fp32_ukernel_3x16c4__avx512vnni_prfm( w = (const int8_t*) w + 128; k -= 8 * sizeof(int8_t); } - vacc0x0x0123456789ABCDEF = _mm512_add_epi32(vacc0x0x0123456789ABCDEF, vacc1x0x0123456789ABCDEF); - vacc0x1x0123456789ABCDEF = _mm512_add_epi32(vacc0x1x0123456789ABCDEF, vacc1x1x0123456789ABCDEF); - vacc0x2x0123456789ABCDEF = _mm512_add_epi32(vacc0x2x0123456789ABCDEF, vacc1x2x0123456789ABCDEF); + __m512i vacc0x0123456789ABCDEF = _mm512_add_epi32(vacc0x0x0123456789ABCDEF, vacc1x0x0123456789ABCDEF); + __m512i vacc1x0123456789ABCDEF = _mm512_add_epi32(vacc0x1x0123456789ABCDEF, vacc1x1x0123456789ABCDEF); + __m512i vacc2x0123456789ABCDEF = _mm512_add_epi32(vacc0x2x0123456789ABCDEF, vacc1x2x0123456789ABCDEF); if (k != 0) { const __m512i va0x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a0)), vsign_mask); @@ -112,17 +111,17 @@ void xnn_qs8_gemm_minmax_fp32_ukernel_3x16c4__avx512vnni_prfm( const __m512i vb0123456789ABCDEF = _mm512_load_si512(w); xnn_prefetch_to_l1((const int8_t*) w + 960); - vacc0x0x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x0x0123456789ABCDEF, va0x0123, vb0123456789ABCDEF); - vacc0x1x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x1x0123456789ABCDEF, va1x0123, vb0123456789ABCDEF); - vacc0x2x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x2x0123456789ABCDEF, va2x0123, vb0123456789ABCDEF); + vacc0x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x0123456789ABCDEF, va0x0123, vb0123456789ABCDEF); + vacc1x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc1x0123456789ABCDEF, va1x0123, vb0123456789ABCDEF); + vacc2x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc2x0123456789ABCDEF, va2x0123, vb0123456789ABCDEF); w = (const int8_t*) w + 64; k -= 4 * sizeof(int8_t); } - __m512 vscaled0x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x0x0123456789ABCDEF); - __m512 vscaled1x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x1x0123456789ABCDEF); - __m512 vscaled2x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x2x0123456789ABCDEF); + __m512 vscaled0x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x0123456789ABCDEF); + __m512 vscaled1x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc1x0123456789ABCDEF); + __m512 vscaled2x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc2x0123456789ABCDEF); vscaled0x0123456789ABCDEF = _mm512_mul_ps(vscaled0x0123456789ABCDEF, vscale); vscaled1x0123456789ABCDEF = _mm512_mul_ps(vscaled1x0123456789ABCDEF, vscale); @@ -132,25 +131,25 @@ void xnn_qs8_gemm_minmax_fp32_ukernel_3x16c4__avx512vnni_prfm( vscaled1x0123456789ABCDEF = _mm512_min_ps(vscaled1x0123456789ABCDEF, voutput_max_less_zero_point); vscaled2x0123456789ABCDEF = _mm512_min_ps(vscaled2x0123456789ABCDEF, voutput_max_less_zero_point); - vacc0x0x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled0x0123456789ABCDEF); - vacc0x1x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled1x0123456789ABCDEF); - vacc0x2x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled2x0123456789ABCDEF); + vacc0x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled0x0123456789ABCDEF); + vacc1x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled1x0123456789ABCDEF); + vacc2x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled2x0123456789ABCDEF); - __m256i vacc0x0x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0x0123456789ABCDEF, 1)); - __m256i vacc0x1x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x1x0123456789ABCDEF, 1)); - __m256i vacc0x2x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x2x0123456789ABCDEF, 1)); + __m256i vacc0x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); + __m256i vacc1x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); + __m256i vacc2x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc2x0123456789ABCDEF, 1)); - vacc0x0x012389AB4567CDEF = _mm256_adds_epi16(vacc0x0x012389AB4567CDEF, voutput_zero_point); - vacc0x1x012389AB4567CDEF = _mm256_adds_epi16(vacc0x1x012389AB4567CDEF, voutput_zero_point); - vacc0x2x012389AB4567CDEF = _mm256_adds_epi16(vacc0x2x012389AB4567CDEF, voutput_zero_point); + vacc0x012389AB4567CDEF = _mm256_adds_epi16(vacc0x012389AB4567CDEF, voutput_zero_point); + vacc1x012389AB4567CDEF = _mm256_adds_epi16(vacc1x012389AB4567CDEF, voutput_zero_point); + vacc2x012389AB4567CDEF = _mm256_adds_epi16(vacc2x012389AB4567CDEF, voutput_zero_point); - const __m128i vout0x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x0x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x0x012389AB4567CDEF, 1)); - const __m128i vout1x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x1x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x1x012389AB4567CDEF, 1)); - const __m128i vout2x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x2x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x2x012389AB4567CDEF, 1)); + const __m128i vout0x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x012389AB4567CDEF, 1)); + const __m128i vout1x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc1x012389AB4567CDEF), _mm256_extracti128_si256(vacc1x012389AB4567CDEF, 1)); + const __m128i vout2x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc2x012389AB4567CDEF), _mm256_extracti128_si256(vacc2x012389AB4567CDEF, 1)); - __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi8(vout0x012389AB4567CDEF, vshuffle_control_mask); - __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi8(vout1x012389AB4567CDEF, vshuffle_control_mask); - __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi8(vout2x012389AB4567CDEF, vshuffle_control_mask); + __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi32(vout0x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi32(vout1x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi32(vout2x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); vout0x0123456789ABCDEF = _mm_max_epi8(vout0x0123456789ABCDEF, voutput_min); vout1x0123456789ABCDEF = _mm_max_epi8(vout1x0123456789ABCDEF, voutput_min); @@ -158,16 +157,14 @@ void xnn_qs8_gemm_minmax_fp32_ukernel_3x16c4__avx512vnni_prfm( if (nc >= 16) { _mm_storeu_si128((__m128i*) c0, vout0x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c1, vout1x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c2, vout2x0123456789ABCDEF); - - a0 = (const int8_t*) ((uintptr_t) a0 - kc); - a1 = (const int8_t*) ((uintptr_t) a1 - kc); - a2 = (const int8_t*) ((uintptr_t) a2 - kc); - c0 = (int8_t*) ((uintptr_t) c0 + cn_stride); + a0 = (const int8_t*) ((uintptr_t) a0 - kc); + _mm_storeu_si128((__m128i*) c1, vout1x0123456789ABCDEF); c1 = (int8_t*) ((uintptr_t) c1 + cn_stride); + a1 = (const int8_t*) ((uintptr_t) a1 - kc); + _mm_storeu_si128((__m128i*) c2, vout2x0123456789ABCDEF); c2 = (int8_t*) ((uintptr_t) c2 + cn_stride); + a2 = (const int8_t*) ((uintptr_t) a2 - kc); nc -= 16; } else { diff --git a/src/qs8-gemm/gen/qs8-gemm-3x16c4-minmax-fp32-avx512vnni.c b/src/qs8-gemm/gen/qs8-gemm-3x16c4-minmax-fp32-avx512vnni.c index 226b8cc1412..6fa5ec1f59e 100644 --- a/src/qs8-gemm/gen/qs8-gemm-3x16c4-minmax-fp32-avx512vnni.c +++ b/src/qs8-gemm/gen/qs8-gemm-3x16c4-minmax-fp32-avx512vnni.c @@ -59,7 +59,6 @@ void xnn_qs8_gemm_minmax_fp32_ukernel_3x16c4__avx512vnni( const __m512 voutput_max_less_zero_point = _mm512_load_ps(params->fp32_avx512vnni.output_max_less_zero_point); const __m256i voutput_zero_point = _mm256_load_si256((const __m256i*) params->fp32_avx512vnni.output_zero_point); const __m128i voutput_min = _mm_load_si128((const __m128i*) params->fp32_avx512vnni.output_min); - const __m128i vshuffle_control_mask = _mm_loadu_si128((const __m128i*) params->fp32_avx512vnni.shuffle_control_mask); do { __m512i vacc0x0x0123456789ABCDEF = _mm512_load_epi32(w); __m512i vacc1x0x0123456789ABCDEF = _mm512_setzero_epi32(); @@ -94,9 +93,9 @@ void xnn_qs8_gemm_minmax_fp32_ukernel_3x16c4__avx512vnni( w = (const int8_t*) w + 128; k -= 8 * sizeof(int8_t); } - vacc0x0x0123456789ABCDEF = _mm512_add_epi32(vacc0x0x0123456789ABCDEF, vacc1x0x0123456789ABCDEF); - vacc0x1x0123456789ABCDEF = _mm512_add_epi32(vacc0x1x0123456789ABCDEF, vacc1x1x0123456789ABCDEF); - vacc0x2x0123456789ABCDEF = _mm512_add_epi32(vacc0x2x0123456789ABCDEF, vacc1x2x0123456789ABCDEF); + __m512i vacc0x0123456789ABCDEF = _mm512_add_epi32(vacc0x0x0123456789ABCDEF, vacc1x0x0123456789ABCDEF); + __m512i vacc1x0123456789ABCDEF = _mm512_add_epi32(vacc0x1x0123456789ABCDEF, vacc1x1x0123456789ABCDEF); + __m512i vacc2x0123456789ABCDEF = _mm512_add_epi32(vacc0x2x0123456789ABCDEF, vacc1x2x0123456789ABCDEF); if (k != 0) { const __m512i va0x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a0)), vsign_mask); @@ -108,17 +107,17 @@ void xnn_qs8_gemm_minmax_fp32_ukernel_3x16c4__avx512vnni( const __m512i vb0123456789ABCDEF = _mm512_load_si512(w); - vacc0x0x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x0x0123456789ABCDEF, va0x0123, vb0123456789ABCDEF); - vacc0x1x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x1x0123456789ABCDEF, va1x0123, vb0123456789ABCDEF); - vacc0x2x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x2x0123456789ABCDEF, va2x0123, vb0123456789ABCDEF); + vacc0x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x0123456789ABCDEF, va0x0123, vb0123456789ABCDEF); + vacc1x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc1x0123456789ABCDEF, va1x0123, vb0123456789ABCDEF); + vacc2x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc2x0123456789ABCDEF, va2x0123, vb0123456789ABCDEF); w = (const int8_t*) w + 64; k -= 4 * sizeof(int8_t); } - __m512 vscaled0x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x0x0123456789ABCDEF); - __m512 vscaled1x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x1x0123456789ABCDEF); - __m512 vscaled2x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x2x0123456789ABCDEF); + __m512 vscaled0x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x0123456789ABCDEF); + __m512 vscaled1x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc1x0123456789ABCDEF); + __m512 vscaled2x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc2x0123456789ABCDEF); vscaled0x0123456789ABCDEF = _mm512_mul_ps(vscaled0x0123456789ABCDEF, vscale); vscaled1x0123456789ABCDEF = _mm512_mul_ps(vscaled1x0123456789ABCDEF, vscale); @@ -128,25 +127,25 @@ void xnn_qs8_gemm_minmax_fp32_ukernel_3x16c4__avx512vnni( vscaled1x0123456789ABCDEF = _mm512_min_ps(vscaled1x0123456789ABCDEF, voutput_max_less_zero_point); vscaled2x0123456789ABCDEF = _mm512_min_ps(vscaled2x0123456789ABCDEF, voutput_max_less_zero_point); - vacc0x0x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled0x0123456789ABCDEF); - vacc0x1x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled1x0123456789ABCDEF); - vacc0x2x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled2x0123456789ABCDEF); + vacc0x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled0x0123456789ABCDEF); + vacc1x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled1x0123456789ABCDEF); + vacc2x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled2x0123456789ABCDEF); - __m256i vacc0x0x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0x0123456789ABCDEF, 1)); - __m256i vacc0x1x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x1x0123456789ABCDEF, 1)); - __m256i vacc0x2x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x2x0123456789ABCDEF, 1)); + __m256i vacc0x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); + __m256i vacc1x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); + __m256i vacc2x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc2x0123456789ABCDEF, 1)); - vacc0x0x012389AB4567CDEF = _mm256_adds_epi16(vacc0x0x012389AB4567CDEF, voutput_zero_point); - vacc0x1x012389AB4567CDEF = _mm256_adds_epi16(vacc0x1x012389AB4567CDEF, voutput_zero_point); - vacc0x2x012389AB4567CDEF = _mm256_adds_epi16(vacc0x2x012389AB4567CDEF, voutput_zero_point); + vacc0x012389AB4567CDEF = _mm256_adds_epi16(vacc0x012389AB4567CDEF, voutput_zero_point); + vacc1x012389AB4567CDEF = _mm256_adds_epi16(vacc1x012389AB4567CDEF, voutput_zero_point); + vacc2x012389AB4567CDEF = _mm256_adds_epi16(vacc2x012389AB4567CDEF, voutput_zero_point); - const __m128i vout0x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x0x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x0x012389AB4567CDEF, 1)); - const __m128i vout1x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x1x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x1x012389AB4567CDEF, 1)); - const __m128i vout2x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x2x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x2x012389AB4567CDEF, 1)); + const __m128i vout0x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x012389AB4567CDEF, 1)); + const __m128i vout1x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc1x012389AB4567CDEF), _mm256_extracti128_si256(vacc1x012389AB4567CDEF, 1)); + const __m128i vout2x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc2x012389AB4567CDEF), _mm256_extracti128_si256(vacc2x012389AB4567CDEF, 1)); - __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi8(vout0x012389AB4567CDEF, vshuffle_control_mask); - __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi8(vout1x012389AB4567CDEF, vshuffle_control_mask); - __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi8(vout2x012389AB4567CDEF, vshuffle_control_mask); + __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi32(vout0x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi32(vout1x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi32(vout2x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); vout0x0123456789ABCDEF = _mm_max_epi8(vout0x0123456789ABCDEF, voutput_min); vout1x0123456789ABCDEF = _mm_max_epi8(vout1x0123456789ABCDEF, voutput_min); @@ -154,16 +153,14 @@ void xnn_qs8_gemm_minmax_fp32_ukernel_3x16c4__avx512vnni( if (nc >= 16) { _mm_storeu_si128((__m128i*) c0, vout0x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c1, vout1x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c2, vout2x0123456789ABCDEF); - - a0 = (const int8_t*) ((uintptr_t) a0 - kc); - a1 = (const int8_t*) ((uintptr_t) a1 - kc); - a2 = (const int8_t*) ((uintptr_t) a2 - kc); - c0 = (int8_t*) ((uintptr_t) c0 + cn_stride); + a0 = (const int8_t*) ((uintptr_t) a0 - kc); + _mm_storeu_si128((__m128i*) c1, vout1x0123456789ABCDEF); c1 = (int8_t*) ((uintptr_t) c1 + cn_stride); + a1 = (const int8_t*) ((uintptr_t) a1 - kc); + _mm_storeu_si128((__m128i*) c2, vout2x0123456789ABCDEF); c2 = (int8_t*) ((uintptr_t) c2 + cn_stride); + a2 = (const int8_t*) ((uintptr_t) a2 - kc); nc -= 16; } else { diff --git a/src/qs8-gemm/gen/qs8-gemm-3x16c8-minmax-fp32-avx512skx-prfm.c b/src/qs8-gemm/gen/qs8-gemm-3x16c8-minmax-fp32-avx512skx-prfm.c index 627cdc660a2..1768b5360f2 100644 --- a/src/qs8-gemm/gen/qs8-gemm-3x16c8-minmax-fp32-avx512skx-prfm.c +++ b/src/qs8-gemm/gen/qs8-gemm-3x16c8-minmax-fp32-avx512skx-prfm.c @@ -58,7 +58,6 @@ void xnn_qs8_gemm_minmax_fp32_ukernel_3x16c8__avx512skx_prfm( const __m512 voutput_max_less_zero_point = _mm512_load_ps(params->fp32_avx512.output_max_less_zero_point); const __m256i voutput_zero_point = _mm256_load_si256((const __m256i*) params->fp32_avx512.output_zero_point); const __m128i voutput_min = _mm_load_si128((const __m128i*) params->fp32_avx512.output_min); - const __m128i vshuffle_control_mask = _mm_set_epi8(15, 14, 13, 12, 7, 6, 5, 4, 11, 10, 9, 8, 3, 2, 1, 0); do { __m512i vacc0x0123 = _mm512_maskz_expandloadu_epi32(vbias_mask, w); @@ -145,21 +144,21 @@ void xnn_qs8_gemm_minmax_fp32_ukernel_3x16c8__avx512skx_prfm( vacc1x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled1x0123456789ABCDEF); vacc2x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled2x0123456789ABCDEF); - __m256i vacc0x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); - __m256i vacc1x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); - __m256i vacc2x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc2x0123456789ABCDEF, 1)); + __m256i vacc0x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); + __m256i vacc1x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); + __m256i vacc2x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc2x0123456789ABCDEF, 1)); - vacc0x0123456789AB4567CDEF = _mm256_adds_epi16(vacc0x0123456789AB4567CDEF, voutput_zero_point); - vacc1x0123456789AB4567CDEF = _mm256_adds_epi16(vacc1x0123456789AB4567CDEF, voutput_zero_point); - vacc2x0123456789AB4567CDEF = _mm256_adds_epi16(vacc2x0123456789AB4567CDEF, voutput_zero_point); + vacc0x012389AB4567CDEF = _mm256_adds_epi16(vacc0x012389AB4567CDEF, voutput_zero_point); + vacc1x012389AB4567CDEF = _mm256_adds_epi16(vacc1x012389AB4567CDEF, voutput_zero_point); + vacc2x012389AB4567CDEF = _mm256_adds_epi16(vacc2x012389AB4567CDEF, voutput_zero_point); - const __m128i vout0x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc0x0123456789AB4567CDEF, 1)); - const __m128i vout1x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc1x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc1x0123456789AB4567CDEF, 1)); - const __m128i vout2x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc2x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc2x0123456789AB4567CDEF, 1)); + const __m128i vout0x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x012389AB4567CDEF, 1)); + const __m128i vout1x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc1x012389AB4567CDEF), _mm256_extracti128_si256(vacc1x012389AB4567CDEF, 1)); + const __m128i vout2x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc2x012389AB4567CDEF), _mm256_extracti128_si256(vacc2x012389AB4567CDEF, 1)); - __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi8(vout0x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi8(vout1x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi8(vout2x0123456789AB4567CDEF, vshuffle_control_mask); + __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi32(vout0x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi32(vout1x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi32(vout2x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); vout0x0123456789ABCDEF = _mm_max_epi8(vout0x0123456789ABCDEF, voutput_min); vout1x0123456789ABCDEF = _mm_max_epi8(vout1x0123456789ABCDEF, voutput_min); @@ -167,16 +166,14 @@ void xnn_qs8_gemm_minmax_fp32_ukernel_3x16c8__avx512skx_prfm( if (nc >= 16) { _mm_storeu_si128((__m128i*) c0, vout0x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c1, vout1x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c2, vout2x0123456789ABCDEF); - - a0 = (const int8_t*) ((uintptr_t) a0 - kc); - a1 = (const int8_t*) ((uintptr_t) a1 - kc); - a2 = (const int8_t*) ((uintptr_t) a2 - kc); - c0 = (int8_t*) ((uintptr_t) c0 + cn_stride); + a0 = (const int8_t*) ((uintptr_t) a0 - kc); + _mm_storeu_si128((__m128i*) c1, vout1x0123456789ABCDEF); c1 = (int8_t*) ((uintptr_t) c1 + cn_stride); + a1 = (const int8_t*) ((uintptr_t) a1 - kc); + _mm_storeu_si128((__m128i*) c2, vout2x0123456789ABCDEF); c2 = (int8_t*) ((uintptr_t) c2 + cn_stride); + a2 = (const int8_t*) ((uintptr_t) a2 - kc); nc -= 16; } else { diff --git a/src/qs8-gemm/gen/qs8-gemm-3x16c8-minmax-fp32-avx512skx.c b/src/qs8-gemm/gen/qs8-gemm-3x16c8-minmax-fp32-avx512skx.c index a2d593138b7..b99ca0c848a 100644 --- a/src/qs8-gemm/gen/qs8-gemm-3x16c8-minmax-fp32-avx512skx.c +++ b/src/qs8-gemm/gen/qs8-gemm-3x16c8-minmax-fp32-avx512skx.c @@ -57,7 +57,6 @@ void xnn_qs8_gemm_minmax_fp32_ukernel_3x16c8__avx512skx( const __m512 voutput_max_less_zero_point = _mm512_load_ps(params->fp32_avx512.output_max_less_zero_point); const __m256i voutput_zero_point = _mm256_load_si256((const __m256i*) params->fp32_avx512.output_zero_point); const __m128i voutput_min = _mm_load_si128((const __m128i*) params->fp32_avx512.output_min); - const __m128i vshuffle_control_mask = _mm_set_epi8(15, 14, 13, 12, 7, 6, 5, 4, 11, 10, 9, 8, 3, 2, 1, 0); do { __m512i vacc0x0123 = _mm512_maskz_expandloadu_epi32(vbias_mask, w); @@ -142,21 +141,21 @@ void xnn_qs8_gemm_minmax_fp32_ukernel_3x16c8__avx512skx( vacc1x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled1x0123456789ABCDEF); vacc2x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled2x0123456789ABCDEF); - __m256i vacc0x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); - __m256i vacc1x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); - __m256i vacc2x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc2x0123456789ABCDEF, 1)); + __m256i vacc0x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); + __m256i vacc1x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); + __m256i vacc2x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc2x0123456789ABCDEF, 1)); - vacc0x0123456789AB4567CDEF = _mm256_adds_epi16(vacc0x0123456789AB4567CDEF, voutput_zero_point); - vacc1x0123456789AB4567CDEF = _mm256_adds_epi16(vacc1x0123456789AB4567CDEF, voutput_zero_point); - vacc2x0123456789AB4567CDEF = _mm256_adds_epi16(vacc2x0123456789AB4567CDEF, voutput_zero_point); + vacc0x012389AB4567CDEF = _mm256_adds_epi16(vacc0x012389AB4567CDEF, voutput_zero_point); + vacc1x012389AB4567CDEF = _mm256_adds_epi16(vacc1x012389AB4567CDEF, voutput_zero_point); + vacc2x012389AB4567CDEF = _mm256_adds_epi16(vacc2x012389AB4567CDEF, voutput_zero_point); - const __m128i vout0x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc0x0123456789AB4567CDEF, 1)); - const __m128i vout1x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc1x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc1x0123456789AB4567CDEF, 1)); - const __m128i vout2x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc2x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc2x0123456789AB4567CDEF, 1)); + const __m128i vout0x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x012389AB4567CDEF, 1)); + const __m128i vout1x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc1x012389AB4567CDEF), _mm256_extracti128_si256(vacc1x012389AB4567CDEF, 1)); + const __m128i vout2x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc2x012389AB4567CDEF), _mm256_extracti128_si256(vacc2x012389AB4567CDEF, 1)); - __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi8(vout0x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi8(vout1x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi8(vout2x0123456789AB4567CDEF, vshuffle_control_mask); + __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi32(vout0x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi32(vout1x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi32(vout2x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); vout0x0123456789ABCDEF = _mm_max_epi8(vout0x0123456789ABCDEF, voutput_min); vout1x0123456789ABCDEF = _mm_max_epi8(vout1x0123456789ABCDEF, voutput_min); @@ -164,16 +163,14 @@ void xnn_qs8_gemm_minmax_fp32_ukernel_3x16c8__avx512skx( if (nc >= 16) { _mm_storeu_si128((__m128i*) c0, vout0x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c1, vout1x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c2, vout2x0123456789ABCDEF); - - a0 = (const int8_t*) ((uintptr_t) a0 - kc); - a1 = (const int8_t*) ((uintptr_t) a1 - kc); - a2 = (const int8_t*) ((uintptr_t) a2 - kc); - c0 = (int8_t*) ((uintptr_t) c0 + cn_stride); + a0 = (const int8_t*) ((uintptr_t) a0 - kc); + _mm_storeu_si128((__m128i*) c1, vout1x0123456789ABCDEF); c1 = (int8_t*) ((uintptr_t) c1 + cn_stride); + a1 = (const int8_t*) ((uintptr_t) a1 - kc); + _mm_storeu_si128((__m128i*) c2, vout2x0123456789ABCDEF); c2 = (int8_t*) ((uintptr_t) c2 + cn_stride); + a2 = (const int8_t*) ((uintptr_t) a2 - kc); nc -= 16; } else { diff --git a/src/qs8-gemm/gen/qs8-gemm-3x16c8-minmax-fp32-avx512vnni-prfm.c b/src/qs8-gemm/gen/qs8-gemm-3x16c8-minmax-fp32-avx512vnni-prfm.c index 253a5c1cd22..e331308dfcc 100644 --- a/src/qs8-gemm/gen/qs8-gemm-3x16c8-minmax-fp32-avx512vnni-prfm.c +++ b/src/qs8-gemm/gen/qs8-gemm-3x16c8-minmax-fp32-avx512vnni-prfm.c @@ -60,7 +60,6 @@ void xnn_qs8_gemm_minmax_fp32_ukernel_3x16c8__avx512vnni_prfm( const __m512 voutput_max_less_zero_point = _mm512_load_ps(params->fp32_avx512vnni.output_max_less_zero_point); const __m256i voutput_zero_point = _mm256_load_si256((const __m256i*) params->fp32_avx512vnni.output_zero_point); const __m128i voutput_min = _mm_load_si128((const __m128i*) params->fp32_avx512vnni.output_min); - const __m128i vshuffle_control_mask = _mm_loadu_si128((const __m128i*) params->fp32_avx512vnni.shuffle_control_mask); do { __m512i vacc0x01234567 = _mm512_cvtepu32_epi64(_mm256_load_si256((const __m256i*) w)); __m512i vacc0x89ABCDEF = _mm512_cvtepu32_epi64(_mm256_load_si256((const __m256i*) ((const int32_t*) w + 8))); @@ -160,21 +159,21 @@ void xnn_qs8_gemm_minmax_fp32_ukernel_3x16c8__avx512vnni_prfm( vacc1x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled1x0123456789ABCDEF); vacc2x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled2x0123456789ABCDEF); - __m256i vacc0x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); - __m256i vacc1x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); - __m256i vacc2x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc2x0123456789ABCDEF, 1)); + __m256i vacc0x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); + __m256i vacc1x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); + __m256i vacc2x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc2x0123456789ABCDEF, 1)); - vacc0x0123456789AB4567CDEF = _mm256_adds_epi16(vacc0x0123456789AB4567CDEF, voutput_zero_point); - vacc1x0123456789AB4567CDEF = _mm256_adds_epi16(vacc1x0123456789AB4567CDEF, voutput_zero_point); - vacc2x0123456789AB4567CDEF = _mm256_adds_epi16(vacc2x0123456789AB4567CDEF, voutput_zero_point); + vacc0x012389AB4567CDEF = _mm256_adds_epi16(vacc0x012389AB4567CDEF, voutput_zero_point); + vacc1x012389AB4567CDEF = _mm256_adds_epi16(vacc1x012389AB4567CDEF, voutput_zero_point); + vacc2x012389AB4567CDEF = _mm256_adds_epi16(vacc2x012389AB4567CDEF, voutput_zero_point); - const __m128i vout0x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc0x0123456789AB4567CDEF, 1)); - const __m128i vout1x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc1x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc1x0123456789AB4567CDEF, 1)); - const __m128i vout2x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc2x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc2x0123456789AB4567CDEF, 1)); + const __m128i vout0x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x012389AB4567CDEF, 1)); + const __m128i vout1x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc1x012389AB4567CDEF), _mm256_extracti128_si256(vacc1x012389AB4567CDEF, 1)); + const __m128i vout2x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc2x012389AB4567CDEF), _mm256_extracti128_si256(vacc2x012389AB4567CDEF, 1)); - __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi8(vout0x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi8(vout1x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi8(vout2x0123456789AB4567CDEF, vshuffle_control_mask); + __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi32(vout0x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi32(vout1x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi32(vout2x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); vout0x0123456789ABCDEF = _mm_max_epi8(vout0x0123456789ABCDEF, voutput_min); vout1x0123456789ABCDEF = _mm_max_epi8(vout1x0123456789ABCDEF, voutput_min); @@ -182,16 +181,14 @@ void xnn_qs8_gemm_minmax_fp32_ukernel_3x16c8__avx512vnni_prfm( if (nc >= 16) { _mm_storeu_si128((__m128i*) c0, vout0x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c1, vout1x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c2, vout2x0123456789ABCDEF); - - a0 = (const int8_t*) ((uintptr_t) a0 - kc); - a1 = (const int8_t*) ((uintptr_t) a1 - kc); - a2 = (const int8_t*) ((uintptr_t) a2 - kc); - c0 = (int8_t*) ((uintptr_t) c0 + cn_stride); + a0 = (const int8_t*) ((uintptr_t) a0 - kc); + _mm_storeu_si128((__m128i*) c1, vout1x0123456789ABCDEF); c1 = (int8_t*) ((uintptr_t) c1 + cn_stride); + a1 = (const int8_t*) ((uintptr_t) a1 - kc); + _mm_storeu_si128((__m128i*) c2, vout2x0123456789ABCDEF); c2 = (int8_t*) ((uintptr_t) c2 + cn_stride); + a2 = (const int8_t*) ((uintptr_t) a2 - kc); nc -= 16; } else { diff --git a/src/qs8-gemm/gen/qs8-gemm-3x16c8-minmax-fp32-avx512vnni.c b/src/qs8-gemm/gen/qs8-gemm-3x16c8-minmax-fp32-avx512vnni.c index 9e2656dc2c8..cdbe49f0229 100644 --- a/src/qs8-gemm/gen/qs8-gemm-3x16c8-minmax-fp32-avx512vnni.c +++ b/src/qs8-gemm/gen/qs8-gemm-3x16c8-minmax-fp32-avx512vnni.c @@ -59,7 +59,6 @@ void xnn_qs8_gemm_minmax_fp32_ukernel_3x16c8__avx512vnni( const __m512 voutput_max_less_zero_point = _mm512_load_ps(params->fp32_avx512vnni.output_max_less_zero_point); const __m256i voutput_zero_point = _mm256_load_si256((const __m256i*) params->fp32_avx512vnni.output_zero_point); const __m128i voutput_min = _mm_load_si128((const __m128i*) params->fp32_avx512vnni.output_min); - const __m128i vshuffle_control_mask = _mm_loadu_si128((const __m128i*) params->fp32_avx512vnni.shuffle_control_mask); do { __m512i vacc0x01234567 = _mm512_cvtepu32_epi64(_mm256_load_si256((const __m256i*) w)); __m512i vacc0x89ABCDEF = _mm512_cvtepu32_epi64(_mm256_load_si256((const __m256i*) ((const int32_t*) w + 8))); @@ -153,21 +152,21 @@ void xnn_qs8_gemm_minmax_fp32_ukernel_3x16c8__avx512vnni( vacc1x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled1x0123456789ABCDEF); vacc2x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled2x0123456789ABCDEF); - __m256i vacc0x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); - __m256i vacc1x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); - __m256i vacc2x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc2x0123456789ABCDEF, 1)); + __m256i vacc0x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); + __m256i vacc1x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); + __m256i vacc2x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc2x0123456789ABCDEF, 1)); - vacc0x0123456789AB4567CDEF = _mm256_adds_epi16(vacc0x0123456789AB4567CDEF, voutput_zero_point); - vacc1x0123456789AB4567CDEF = _mm256_adds_epi16(vacc1x0123456789AB4567CDEF, voutput_zero_point); - vacc2x0123456789AB4567CDEF = _mm256_adds_epi16(vacc2x0123456789AB4567CDEF, voutput_zero_point); + vacc0x012389AB4567CDEF = _mm256_adds_epi16(vacc0x012389AB4567CDEF, voutput_zero_point); + vacc1x012389AB4567CDEF = _mm256_adds_epi16(vacc1x012389AB4567CDEF, voutput_zero_point); + vacc2x012389AB4567CDEF = _mm256_adds_epi16(vacc2x012389AB4567CDEF, voutput_zero_point); - const __m128i vout0x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc0x0123456789AB4567CDEF, 1)); - const __m128i vout1x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc1x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc1x0123456789AB4567CDEF, 1)); - const __m128i vout2x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc2x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc2x0123456789AB4567CDEF, 1)); + const __m128i vout0x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x012389AB4567CDEF, 1)); + const __m128i vout1x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc1x012389AB4567CDEF), _mm256_extracti128_si256(vacc1x012389AB4567CDEF, 1)); + const __m128i vout2x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc2x012389AB4567CDEF), _mm256_extracti128_si256(vacc2x012389AB4567CDEF, 1)); - __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi8(vout0x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi8(vout1x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi8(vout2x0123456789AB4567CDEF, vshuffle_control_mask); + __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi32(vout0x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi32(vout1x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi32(vout2x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); vout0x0123456789ABCDEF = _mm_max_epi8(vout0x0123456789ABCDEF, voutput_min); vout1x0123456789ABCDEF = _mm_max_epi8(vout1x0123456789ABCDEF, voutput_min); @@ -175,16 +174,14 @@ void xnn_qs8_gemm_minmax_fp32_ukernel_3x16c8__avx512vnni( if (nc >= 16) { _mm_storeu_si128((__m128i*) c0, vout0x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c1, vout1x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c2, vout2x0123456789ABCDEF); - - a0 = (const int8_t*) ((uintptr_t) a0 - kc); - a1 = (const int8_t*) ((uintptr_t) a1 - kc); - a2 = (const int8_t*) ((uintptr_t) a2 - kc); - c0 = (int8_t*) ((uintptr_t) c0 + cn_stride); + a0 = (const int8_t*) ((uintptr_t) a0 - kc); + _mm_storeu_si128((__m128i*) c1, vout1x0123456789ABCDEF); c1 = (int8_t*) ((uintptr_t) c1 + cn_stride); + a1 = (const int8_t*) ((uintptr_t) a1 - kc); + _mm_storeu_si128((__m128i*) c2, vout2x0123456789ABCDEF); c2 = (int8_t*) ((uintptr_t) c2 + cn_stride); + a2 = (const int8_t*) ((uintptr_t) a2 - kc); nc -= 16; } else { diff --git a/src/qs8-gemm/gen/qs8-gemm-4x16c4-minmax-fp32-avx512vnni-prfm.c b/src/qs8-gemm/gen/qs8-gemm-4x16c4-minmax-fp32-avx512vnni-prfm.c index 10d94c1485e..12c9189a027 100644 --- a/src/qs8-gemm/gen/qs8-gemm-4x16c4-minmax-fp32-avx512vnni-prfm.c +++ b/src/qs8-gemm/gen/qs8-gemm-4x16c4-minmax-fp32-avx512vnni-prfm.c @@ -66,7 +66,6 @@ void xnn_qs8_gemm_minmax_fp32_ukernel_4x16c4__avx512vnni_prfm( const __m512 voutput_max_less_zero_point = _mm512_load_ps(params->fp32_avx512vnni.output_max_less_zero_point); const __m256i voutput_zero_point = _mm256_load_si256((const __m256i*) params->fp32_avx512vnni.output_zero_point); const __m128i voutput_min = _mm_load_si128((const __m128i*) params->fp32_avx512vnni.output_min); - const __m128i vshuffle_control_mask = _mm_loadu_si128((const __m128i*) params->fp32_avx512vnni.shuffle_control_mask); do { __m512i vacc0x0x0123456789ABCDEF = _mm512_load_epi32(w); __m512i vacc1x0x0123456789ABCDEF = _mm512_setzero_epi32(); @@ -110,10 +109,10 @@ void xnn_qs8_gemm_minmax_fp32_ukernel_4x16c4__avx512vnni_prfm( w = (const int8_t*) w + 128; k -= 8 * sizeof(int8_t); } - vacc0x0x0123456789ABCDEF = _mm512_add_epi32(vacc0x0x0123456789ABCDEF, vacc1x0x0123456789ABCDEF); - vacc0x1x0123456789ABCDEF = _mm512_add_epi32(vacc0x1x0123456789ABCDEF, vacc1x1x0123456789ABCDEF); - vacc0x2x0123456789ABCDEF = _mm512_add_epi32(vacc0x2x0123456789ABCDEF, vacc1x2x0123456789ABCDEF); - vacc0x3x0123456789ABCDEF = _mm512_add_epi32(vacc0x3x0123456789ABCDEF, vacc1x3x0123456789ABCDEF); + __m512i vacc0x0123456789ABCDEF = _mm512_add_epi32(vacc0x0x0123456789ABCDEF, vacc1x0x0123456789ABCDEF); + __m512i vacc1x0123456789ABCDEF = _mm512_add_epi32(vacc0x1x0123456789ABCDEF, vacc1x1x0123456789ABCDEF); + __m512i vacc2x0123456789ABCDEF = _mm512_add_epi32(vacc0x2x0123456789ABCDEF, vacc1x2x0123456789ABCDEF); + __m512i vacc3x0123456789ABCDEF = _mm512_add_epi32(vacc0x3x0123456789ABCDEF, vacc1x3x0123456789ABCDEF); if (k != 0) { const __m512i va0x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a0)), vsign_mask); @@ -128,19 +127,19 @@ void xnn_qs8_gemm_minmax_fp32_ukernel_4x16c4__avx512vnni_prfm( const __m512i vb0123456789ABCDEF = _mm512_load_si512(w); xnn_prefetch_to_l1((const int8_t*) w + 960); - vacc0x0x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x0x0123456789ABCDEF, va0x0123, vb0123456789ABCDEF); - vacc0x1x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x1x0123456789ABCDEF, va1x0123, vb0123456789ABCDEF); - vacc0x2x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x2x0123456789ABCDEF, va2x0123, vb0123456789ABCDEF); - vacc0x3x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x3x0123456789ABCDEF, va3x0123, vb0123456789ABCDEF); + vacc0x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x0123456789ABCDEF, va0x0123, vb0123456789ABCDEF); + vacc1x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc1x0123456789ABCDEF, va1x0123, vb0123456789ABCDEF); + vacc2x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc2x0123456789ABCDEF, va2x0123, vb0123456789ABCDEF); + vacc3x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc3x0123456789ABCDEF, va3x0123, vb0123456789ABCDEF); w = (const int8_t*) w + 64; k -= 4 * sizeof(int8_t); } - __m512 vscaled0x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x0x0123456789ABCDEF); - __m512 vscaled1x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x1x0123456789ABCDEF); - __m512 vscaled2x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x2x0123456789ABCDEF); - __m512 vscaled3x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x3x0123456789ABCDEF); + __m512 vscaled0x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x0123456789ABCDEF); + __m512 vscaled1x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc1x0123456789ABCDEF); + __m512 vscaled2x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc2x0123456789ABCDEF); + __m512 vscaled3x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc3x0123456789ABCDEF); vscaled0x0123456789ABCDEF = _mm512_mul_ps(vscaled0x0123456789ABCDEF, vscale); vscaled1x0123456789ABCDEF = _mm512_mul_ps(vscaled1x0123456789ABCDEF, vscale); @@ -152,30 +151,30 @@ void xnn_qs8_gemm_minmax_fp32_ukernel_4x16c4__avx512vnni_prfm( vscaled2x0123456789ABCDEF = _mm512_min_ps(vscaled2x0123456789ABCDEF, voutput_max_less_zero_point); vscaled3x0123456789ABCDEF = _mm512_min_ps(vscaled3x0123456789ABCDEF, voutput_max_less_zero_point); - vacc0x0x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled0x0123456789ABCDEF); - vacc0x1x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled1x0123456789ABCDEF); - vacc0x2x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled2x0123456789ABCDEF); - vacc0x3x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled3x0123456789ABCDEF); + vacc0x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled0x0123456789ABCDEF); + vacc1x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled1x0123456789ABCDEF); + vacc2x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled2x0123456789ABCDEF); + vacc3x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled3x0123456789ABCDEF); - __m256i vacc0x0x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0x0123456789ABCDEF, 1)); - __m256i vacc0x1x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x1x0123456789ABCDEF, 1)); - __m256i vacc0x2x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x2x0123456789ABCDEF, 1)); - __m256i vacc0x3x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x3x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x3x0123456789ABCDEF, 1)); + __m256i vacc0x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); + __m256i vacc1x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); + __m256i vacc2x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc2x0123456789ABCDEF, 1)); + __m256i vacc3x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc3x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc3x0123456789ABCDEF, 1)); - vacc0x0x012389AB4567CDEF = _mm256_adds_epi16(vacc0x0x012389AB4567CDEF, voutput_zero_point); - vacc0x1x012389AB4567CDEF = _mm256_adds_epi16(vacc0x1x012389AB4567CDEF, voutput_zero_point); - vacc0x2x012389AB4567CDEF = _mm256_adds_epi16(vacc0x2x012389AB4567CDEF, voutput_zero_point); - vacc0x3x012389AB4567CDEF = _mm256_adds_epi16(vacc0x3x012389AB4567CDEF, voutput_zero_point); + vacc0x012389AB4567CDEF = _mm256_adds_epi16(vacc0x012389AB4567CDEF, voutput_zero_point); + vacc1x012389AB4567CDEF = _mm256_adds_epi16(vacc1x012389AB4567CDEF, voutput_zero_point); + vacc2x012389AB4567CDEF = _mm256_adds_epi16(vacc2x012389AB4567CDEF, voutput_zero_point); + vacc3x012389AB4567CDEF = _mm256_adds_epi16(vacc3x012389AB4567CDEF, voutput_zero_point); - const __m128i vout0x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x0x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x0x012389AB4567CDEF, 1)); - const __m128i vout1x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x1x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x1x012389AB4567CDEF, 1)); - const __m128i vout2x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x2x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x2x012389AB4567CDEF, 1)); - const __m128i vout3x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x3x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x3x012389AB4567CDEF, 1)); + const __m128i vout0x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x012389AB4567CDEF, 1)); + const __m128i vout1x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc1x012389AB4567CDEF), _mm256_extracti128_si256(vacc1x012389AB4567CDEF, 1)); + const __m128i vout2x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc2x012389AB4567CDEF), _mm256_extracti128_si256(vacc2x012389AB4567CDEF, 1)); + const __m128i vout3x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc3x012389AB4567CDEF), _mm256_extracti128_si256(vacc3x012389AB4567CDEF, 1)); - __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi8(vout0x012389AB4567CDEF, vshuffle_control_mask); - __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi8(vout1x012389AB4567CDEF, vshuffle_control_mask); - __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi8(vout2x012389AB4567CDEF, vshuffle_control_mask); - __m128i vout3x0123456789ABCDEF = _mm_shuffle_epi8(vout3x012389AB4567CDEF, vshuffle_control_mask); + __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi32(vout0x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi32(vout1x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi32(vout2x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout3x0123456789ABCDEF = _mm_shuffle_epi32(vout3x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); vout0x0123456789ABCDEF = _mm_max_epi8(vout0x0123456789ABCDEF, voutput_min); vout1x0123456789ABCDEF = _mm_max_epi8(vout1x0123456789ABCDEF, voutput_min); @@ -184,19 +183,17 @@ void xnn_qs8_gemm_minmax_fp32_ukernel_4x16c4__avx512vnni_prfm( if (nc >= 16) { _mm_storeu_si128((__m128i*) c0, vout0x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c1, vout1x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c2, vout2x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c3, vout3x0123456789ABCDEF); - - a0 = (const int8_t*) ((uintptr_t) a0 - kc); - a1 = (const int8_t*) ((uintptr_t) a1 - kc); - a2 = (const int8_t*) ((uintptr_t) a2 - kc); - a3 = (const int8_t*) ((uintptr_t) a3 - kc); - c0 = (int8_t*) ((uintptr_t) c0 + cn_stride); + a0 = (const int8_t*) ((uintptr_t) a0 - kc); + _mm_storeu_si128((__m128i*) c1, vout1x0123456789ABCDEF); c1 = (int8_t*) ((uintptr_t) c1 + cn_stride); + a1 = (const int8_t*) ((uintptr_t) a1 - kc); + _mm_storeu_si128((__m128i*) c2, vout2x0123456789ABCDEF); c2 = (int8_t*) ((uintptr_t) c2 + cn_stride); + a2 = (const int8_t*) ((uintptr_t) a2 - kc); + _mm_storeu_si128((__m128i*) c3, vout3x0123456789ABCDEF); c3 = (int8_t*) ((uintptr_t) c3 + cn_stride); + a3 = (const int8_t*) ((uintptr_t) a3 - kc); nc -= 16; } else { diff --git a/src/qs8-gemm/gen/qs8-gemm-4x16c4-minmax-fp32-avx512vnni.c b/src/qs8-gemm/gen/qs8-gemm-4x16c4-minmax-fp32-avx512vnni.c index 26e9f83dc73..10691edfe46 100644 --- a/src/qs8-gemm/gen/qs8-gemm-4x16c4-minmax-fp32-avx512vnni.c +++ b/src/qs8-gemm/gen/qs8-gemm-4x16c4-minmax-fp32-avx512vnni.c @@ -65,7 +65,6 @@ void xnn_qs8_gemm_minmax_fp32_ukernel_4x16c4__avx512vnni( const __m512 voutput_max_less_zero_point = _mm512_load_ps(params->fp32_avx512vnni.output_max_less_zero_point); const __m256i voutput_zero_point = _mm256_load_si256((const __m256i*) params->fp32_avx512vnni.output_zero_point); const __m128i voutput_min = _mm_load_si128((const __m128i*) params->fp32_avx512vnni.output_min); - const __m128i vshuffle_control_mask = _mm_loadu_si128((const __m128i*) params->fp32_avx512vnni.shuffle_control_mask); do { __m512i vacc0x0x0123456789ABCDEF = _mm512_load_epi32(w); __m512i vacc1x0x0123456789ABCDEF = _mm512_setzero_epi32(); @@ -107,10 +106,10 @@ void xnn_qs8_gemm_minmax_fp32_ukernel_4x16c4__avx512vnni( w = (const int8_t*) w + 128; k -= 8 * sizeof(int8_t); } - vacc0x0x0123456789ABCDEF = _mm512_add_epi32(vacc0x0x0123456789ABCDEF, vacc1x0x0123456789ABCDEF); - vacc0x1x0123456789ABCDEF = _mm512_add_epi32(vacc0x1x0123456789ABCDEF, vacc1x1x0123456789ABCDEF); - vacc0x2x0123456789ABCDEF = _mm512_add_epi32(vacc0x2x0123456789ABCDEF, vacc1x2x0123456789ABCDEF); - vacc0x3x0123456789ABCDEF = _mm512_add_epi32(vacc0x3x0123456789ABCDEF, vacc1x3x0123456789ABCDEF); + __m512i vacc0x0123456789ABCDEF = _mm512_add_epi32(vacc0x0x0123456789ABCDEF, vacc1x0x0123456789ABCDEF); + __m512i vacc1x0123456789ABCDEF = _mm512_add_epi32(vacc0x1x0123456789ABCDEF, vacc1x1x0123456789ABCDEF); + __m512i vacc2x0123456789ABCDEF = _mm512_add_epi32(vacc0x2x0123456789ABCDEF, vacc1x2x0123456789ABCDEF); + __m512i vacc3x0123456789ABCDEF = _mm512_add_epi32(vacc0x3x0123456789ABCDEF, vacc1x3x0123456789ABCDEF); if (k != 0) { const __m512i va0x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a0)), vsign_mask); @@ -124,19 +123,19 @@ void xnn_qs8_gemm_minmax_fp32_ukernel_4x16c4__avx512vnni( const __m512i vb0123456789ABCDEF = _mm512_load_si512(w); - vacc0x0x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x0x0123456789ABCDEF, va0x0123, vb0123456789ABCDEF); - vacc0x1x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x1x0123456789ABCDEF, va1x0123, vb0123456789ABCDEF); - vacc0x2x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x2x0123456789ABCDEF, va2x0123, vb0123456789ABCDEF); - vacc0x3x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x3x0123456789ABCDEF, va3x0123, vb0123456789ABCDEF); + vacc0x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x0123456789ABCDEF, va0x0123, vb0123456789ABCDEF); + vacc1x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc1x0123456789ABCDEF, va1x0123, vb0123456789ABCDEF); + vacc2x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc2x0123456789ABCDEF, va2x0123, vb0123456789ABCDEF); + vacc3x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc3x0123456789ABCDEF, va3x0123, vb0123456789ABCDEF); w = (const int8_t*) w + 64; k -= 4 * sizeof(int8_t); } - __m512 vscaled0x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x0x0123456789ABCDEF); - __m512 vscaled1x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x1x0123456789ABCDEF); - __m512 vscaled2x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x2x0123456789ABCDEF); - __m512 vscaled3x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x3x0123456789ABCDEF); + __m512 vscaled0x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x0123456789ABCDEF); + __m512 vscaled1x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc1x0123456789ABCDEF); + __m512 vscaled2x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc2x0123456789ABCDEF); + __m512 vscaled3x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc3x0123456789ABCDEF); vscaled0x0123456789ABCDEF = _mm512_mul_ps(vscaled0x0123456789ABCDEF, vscale); vscaled1x0123456789ABCDEF = _mm512_mul_ps(vscaled1x0123456789ABCDEF, vscale); @@ -148,30 +147,30 @@ void xnn_qs8_gemm_minmax_fp32_ukernel_4x16c4__avx512vnni( vscaled2x0123456789ABCDEF = _mm512_min_ps(vscaled2x0123456789ABCDEF, voutput_max_less_zero_point); vscaled3x0123456789ABCDEF = _mm512_min_ps(vscaled3x0123456789ABCDEF, voutput_max_less_zero_point); - vacc0x0x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled0x0123456789ABCDEF); - vacc0x1x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled1x0123456789ABCDEF); - vacc0x2x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled2x0123456789ABCDEF); - vacc0x3x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled3x0123456789ABCDEF); + vacc0x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled0x0123456789ABCDEF); + vacc1x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled1x0123456789ABCDEF); + vacc2x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled2x0123456789ABCDEF); + vacc3x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled3x0123456789ABCDEF); - __m256i vacc0x0x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0x0123456789ABCDEF, 1)); - __m256i vacc0x1x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x1x0123456789ABCDEF, 1)); - __m256i vacc0x2x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x2x0123456789ABCDEF, 1)); - __m256i vacc0x3x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x3x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x3x0123456789ABCDEF, 1)); + __m256i vacc0x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); + __m256i vacc1x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); + __m256i vacc2x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc2x0123456789ABCDEF, 1)); + __m256i vacc3x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc3x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc3x0123456789ABCDEF, 1)); - vacc0x0x012389AB4567CDEF = _mm256_adds_epi16(vacc0x0x012389AB4567CDEF, voutput_zero_point); - vacc0x1x012389AB4567CDEF = _mm256_adds_epi16(vacc0x1x012389AB4567CDEF, voutput_zero_point); - vacc0x2x012389AB4567CDEF = _mm256_adds_epi16(vacc0x2x012389AB4567CDEF, voutput_zero_point); - vacc0x3x012389AB4567CDEF = _mm256_adds_epi16(vacc0x3x012389AB4567CDEF, voutput_zero_point); + vacc0x012389AB4567CDEF = _mm256_adds_epi16(vacc0x012389AB4567CDEF, voutput_zero_point); + vacc1x012389AB4567CDEF = _mm256_adds_epi16(vacc1x012389AB4567CDEF, voutput_zero_point); + vacc2x012389AB4567CDEF = _mm256_adds_epi16(vacc2x012389AB4567CDEF, voutput_zero_point); + vacc3x012389AB4567CDEF = _mm256_adds_epi16(vacc3x012389AB4567CDEF, voutput_zero_point); - const __m128i vout0x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x0x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x0x012389AB4567CDEF, 1)); - const __m128i vout1x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x1x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x1x012389AB4567CDEF, 1)); - const __m128i vout2x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x2x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x2x012389AB4567CDEF, 1)); - const __m128i vout3x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x3x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x3x012389AB4567CDEF, 1)); + const __m128i vout0x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x012389AB4567CDEF, 1)); + const __m128i vout1x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc1x012389AB4567CDEF), _mm256_extracti128_si256(vacc1x012389AB4567CDEF, 1)); + const __m128i vout2x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc2x012389AB4567CDEF), _mm256_extracti128_si256(vacc2x012389AB4567CDEF, 1)); + const __m128i vout3x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc3x012389AB4567CDEF), _mm256_extracti128_si256(vacc3x012389AB4567CDEF, 1)); - __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi8(vout0x012389AB4567CDEF, vshuffle_control_mask); - __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi8(vout1x012389AB4567CDEF, vshuffle_control_mask); - __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi8(vout2x012389AB4567CDEF, vshuffle_control_mask); - __m128i vout3x0123456789ABCDEF = _mm_shuffle_epi8(vout3x012389AB4567CDEF, vshuffle_control_mask); + __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi32(vout0x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi32(vout1x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi32(vout2x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout3x0123456789ABCDEF = _mm_shuffle_epi32(vout3x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); vout0x0123456789ABCDEF = _mm_max_epi8(vout0x0123456789ABCDEF, voutput_min); vout1x0123456789ABCDEF = _mm_max_epi8(vout1x0123456789ABCDEF, voutput_min); @@ -180,19 +179,17 @@ void xnn_qs8_gemm_minmax_fp32_ukernel_4x16c4__avx512vnni( if (nc >= 16) { _mm_storeu_si128((__m128i*) c0, vout0x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c1, vout1x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c2, vout2x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c3, vout3x0123456789ABCDEF); - - a0 = (const int8_t*) ((uintptr_t) a0 - kc); - a1 = (const int8_t*) ((uintptr_t) a1 - kc); - a2 = (const int8_t*) ((uintptr_t) a2 - kc); - a3 = (const int8_t*) ((uintptr_t) a3 - kc); - c0 = (int8_t*) ((uintptr_t) c0 + cn_stride); + a0 = (const int8_t*) ((uintptr_t) a0 - kc); + _mm_storeu_si128((__m128i*) c1, vout1x0123456789ABCDEF); c1 = (int8_t*) ((uintptr_t) c1 + cn_stride); + a1 = (const int8_t*) ((uintptr_t) a1 - kc); + _mm_storeu_si128((__m128i*) c2, vout2x0123456789ABCDEF); c2 = (int8_t*) ((uintptr_t) c2 + cn_stride); + a2 = (const int8_t*) ((uintptr_t) a2 - kc); + _mm_storeu_si128((__m128i*) c3, vout3x0123456789ABCDEF); c3 = (int8_t*) ((uintptr_t) c3 + cn_stride); + a3 = (const int8_t*) ((uintptr_t) a3 - kc); nc -= 16; } else { diff --git a/src/qs8-gemm/gen/qs8-gemm-4x16c8-minmax-fp32-avx512skx-prfm.c b/src/qs8-gemm/gen/qs8-gemm-4x16c8-minmax-fp32-avx512skx-prfm.c index 370df3ea60f..ec42da582cf 100644 --- a/src/qs8-gemm/gen/qs8-gemm-4x16c8-minmax-fp32-avx512skx-prfm.c +++ b/src/qs8-gemm/gen/qs8-gemm-4x16c8-minmax-fp32-avx512skx-prfm.c @@ -64,7 +64,6 @@ void xnn_qs8_gemm_minmax_fp32_ukernel_4x16c8__avx512skx_prfm( const __m512 voutput_max_less_zero_point = _mm512_load_ps(params->fp32_avx512.output_max_less_zero_point); const __m256i voutput_zero_point = _mm256_load_si256((const __m256i*) params->fp32_avx512.output_zero_point); const __m128i voutput_min = _mm_load_si128((const __m128i*) params->fp32_avx512.output_min); - const __m128i vshuffle_control_mask = _mm_set_epi8(15, 14, 13, 12, 7, 6, 5, 4, 11, 10, 9, 8, 3, 2, 1, 0); do { __m512i vacc0x0123 = _mm512_maskz_expandloadu_epi32(vbias_mask, w); @@ -169,25 +168,25 @@ void xnn_qs8_gemm_minmax_fp32_ukernel_4x16c8__avx512skx_prfm( vacc2x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled2x0123456789ABCDEF); vacc3x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled3x0123456789ABCDEF); - __m256i vacc0x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); - __m256i vacc1x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); - __m256i vacc2x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc2x0123456789ABCDEF, 1)); - __m256i vacc3x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc3x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc3x0123456789ABCDEF, 1)); + __m256i vacc0x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); + __m256i vacc1x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); + __m256i vacc2x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc2x0123456789ABCDEF, 1)); + __m256i vacc3x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc3x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc3x0123456789ABCDEF, 1)); - vacc0x0123456789AB4567CDEF = _mm256_adds_epi16(vacc0x0123456789AB4567CDEF, voutput_zero_point); - vacc1x0123456789AB4567CDEF = _mm256_adds_epi16(vacc1x0123456789AB4567CDEF, voutput_zero_point); - vacc2x0123456789AB4567CDEF = _mm256_adds_epi16(vacc2x0123456789AB4567CDEF, voutput_zero_point); - vacc3x0123456789AB4567CDEF = _mm256_adds_epi16(vacc3x0123456789AB4567CDEF, voutput_zero_point); + vacc0x012389AB4567CDEF = _mm256_adds_epi16(vacc0x012389AB4567CDEF, voutput_zero_point); + vacc1x012389AB4567CDEF = _mm256_adds_epi16(vacc1x012389AB4567CDEF, voutput_zero_point); + vacc2x012389AB4567CDEF = _mm256_adds_epi16(vacc2x012389AB4567CDEF, voutput_zero_point); + vacc3x012389AB4567CDEF = _mm256_adds_epi16(vacc3x012389AB4567CDEF, voutput_zero_point); - const __m128i vout0x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc0x0123456789AB4567CDEF, 1)); - const __m128i vout1x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc1x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc1x0123456789AB4567CDEF, 1)); - const __m128i vout2x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc2x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc2x0123456789AB4567CDEF, 1)); - const __m128i vout3x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc3x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc3x0123456789AB4567CDEF, 1)); + const __m128i vout0x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x012389AB4567CDEF, 1)); + const __m128i vout1x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc1x012389AB4567CDEF), _mm256_extracti128_si256(vacc1x012389AB4567CDEF, 1)); + const __m128i vout2x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc2x012389AB4567CDEF), _mm256_extracti128_si256(vacc2x012389AB4567CDEF, 1)); + const __m128i vout3x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc3x012389AB4567CDEF), _mm256_extracti128_si256(vacc3x012389AB4567CDEF, 1)); - __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi8(vout0x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi8(vout1x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi8(vout2x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout3x0123456789ABCDEF = _mm_shuffle_epi8(vout3x0123456789AB4567CDEF, vshuffle_control_mask); + __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi32(vout0x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi32(vout1x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi32(vout2x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout3x0123456789ABCDEF = _mm_shuffle_epi32(vout3x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); vout0x0123456789ABCDEF = _mm_max_epi8(vout0x0123456789ABCDEF, voutput_min); vout1x0123456789ABCDEF = _mm_max_epi8(vout1x0123456789ABCDEF, voutput_min); @@ -196,19 +195,17 @@ void xnn_qs8_gemm_minmax_fp32_ukernel_4x16c8__avx512skx_prfm( if (nc >= 16) { _mm_storeu_si128((__m128i*) c0, vout0x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c1, vout1x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c2, vout2x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c3, vout3x0123456789ABCDEF); - - a0 = (const int8_t*) ((uintptr_t) a0 - kc); - a1 = (const int8_t*) ((uintptr_t) a1 - kc); - a2 = (const int8_t*) ((uintptr_t) a2 - kc); - a3 = (const int8_t*) ((uintptr_t) a3 - kc); - c0 = (int8_t*) ((uintptr_t) c0 + cn_stride); + a0 = (const int8_t*) ((uintptr_t) a0 - kc); + _mm_storeu_si128((__m128i*) c1, vout1x0123456789ABCDEF); c1 = (int8_t*) ((uintptr_t) c1 + cn_stride); + a1 = (const int8_t*) ((uintptr_t) a1 - kc); + _mm_storeu_si128((__m128i*) c2, vout2x0123456789ABCDEF); c2 = (int8_t*) ((uintptr_t) c2 + cn_stride); + a2 = (const int8_t*) ((uintptr_t) a2 - kc); + _mm_storeu_si128((__m128i*) c3, vout3x0123456789ABCDEF); c3 = (int8_t*) ((uintptr_t) c3 + cn_stride); + a3 = (const int8_t*) ((uintptr_t) a3 - kc); nc -= 16; } else { diff --git a/src/qs8-gemm/gen/qs8-gemm-4x16c8-minmax-fp32-avx512skx.c b/src/qs8-gemm/gen/qs8-gemm-4x16c8-minmax-fp32-avx512skx.c index 2931468c65a..44b1270c4e8 100644 --- a/src/qs8-gemm/gen/qs8-gemm-4x16c8-minmax-fp32-avx512skx.c +++ b/src/qs8-gemm/gen/qs8-gemm-4x16c8-minmax-fp32-avx512skx.c @@ -63,7 +63,6 @@ void xnn_qs8_gemm_minmax_fp32_ukernel_4x16c8__avx512skx( const __m512 voutput_max_less_zero_point = _mm512_load_ps(params->fp32_avx512.output_max_less_zero_point); const __m256i voutput_zero_point = _mm256_load_si256((const __m256i*) params->fp32_avx512.output_zero_point); const __m128i voutput_min = _mm_load_si128((const __m128i*) params->fp32_avx512.output_min); - const __m128i vshuffle_control_mask = _mm_set_epi8(15, 14, 13, 12, 7, 6, 5, 4, 11, 10, 9, 8, 3, 2, 1, 0); do { __m512i vacc0x0123 = _mm512_maskz_expandloadu_epi32(vbias_mask, w); @@ -166,25 +165,25 @@ void xnn_qs8_gemm_minmax_fp32_ukernel_4x16c8__avx512skx( vacc2x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled2x0123456789ABCDEF); vacc3x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled3x0123456789ABCDEF); - __m256i vacc0x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); - __m256i vacc1x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); - __m256i vacc2x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc2x0123456789ABCDEF, 1)); - __m256i vacc3x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc3x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc3x0123456789ABCDEF, 1)); + __m256i vacc0x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); + __m256i vacc1x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); + __m256i vacc2x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc2x0123456789ABCDEF, 1)); + __m256i vacc3x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc3x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc3x0123456789ABCDEF, 1)); - vacc0x0123456789AB4567CDEF = _mm256_adds_epi16(vacc0x0123456789AB4567CDEF, voutput_zero_point); - vacc1x0123456789AB4567CDEF = _mm256_adds_epi16(vacc1x0123456789AB4567CDEF, voutput_zero_point); - vacc2x0123456789AB4567CDEF = _mm256_adds_epi16(vacc2x0123456789AB4567CDEF, voutput_zero_point); - vacc3x0123456789AB4567CDEF = _mm256_adds_epi16(vacc3x0123456789AB4567CDEF, voutput_zero_point); + vacc0x012389AB4567CDEF = _mm256_adds_epi16(vacc0x012389AB4567CDEF, voutput_zero_point); + vacc1x012389AB4567CDEF = _mm256_adds_epi16(vacc1x012389AB4567CDEF, voutput_zero_point); + vacc2x012389AB4567CDEF = _mm256_adds_epi16(vacc2x012389AB4567CDEF, voutput_zero_point); + vacc3x012389AB4567CDEF = _mm256_adds_epi16(vacc3x012389AB4567CDEF, voutput_zero_point); - const __m128i vout0x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc0x0123456789AB4567CDEF, 1)); - const __m128i vout1x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc1x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc1x0123456789AB4567CDEF, 1)); - const __m128i vout2x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc2x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc2x0123456789AB4567CDEF, 1)); - const __m128i vout3x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc3x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc3x0123456789AB4567CDEF, 1)); + const __m128i vout0x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x012389AB4567CDEF, 1)); + const __m128i vout1x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc1x012389AB4567CDEF), _mm256_extracti128_si256(vacc1x012389AB4567CDEF, 1)); + const __m128i vout2x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc2x012389AB4567CDEF), _mm256_extracti128_si256(vacc2x012389AB4567CDEF, 1)); + const __m128i vout3x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc3x012389AB4567CDEF), _mm256_extracti128_si256(vacc3x012389AB4567CDEF, 1)); - __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi8(vout0x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi8(vout1x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi8(vout2x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout3x0123456789ABCDEF = _mm_shuffle_epi8(vout3x0123456789AB4567CDEF, vshuffle_control_mask); + __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi32(vout0x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi32(vout1x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi32(vout2x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout3x0123456789ABCDEF = _mm_shuffle_epi32(vout3x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); vout0x0123456789ABCDEF = _mm_max_epi8(vout0x0123456789ABCDEF, voutput_min); vout1x0123456789ABCDEF = _mm_max_epi8(vout1x0123456789ABCDEF, voutput_min); @@ -193,19 +192,17 @@ void xnn_qs8_gemm_minmax_fp32_ukernel_4x16c8__avx512skx( if (nc >= 16) { _mm_storeu_si128((__m128i*) c0, vout0x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c1, vout1x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c2, vout2x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c3, vout3x0123456789ABCDEF); - - a0 = (const int8_t*) ((uintptr_t) a0 - kc); - a1 = (const int8_t*) ((uintptr_t) a1 - kc); - a2 = (const int8_t*) ((uintptr_t) a2 - kc); - a3 = (const int8_t*) ((uintptr_t) a3 - kc); - c0 = (int8_t*) ((uintptr_t) c0 + cn_stride); + a0 = (const int8_t*) ((uintptr_t) a0 - kc); + _mm_storeu_si128((__m128i*) c1, vout1x0123456789ABCDEF); c1 = (int8_t*) ((uintptr_t) c1 + cn_stride); + a1 = (const int8_t*) ((uintptr_t) a1 - kc); + _mm_storeu_si128((__m128i*) c2, vout2x0123456789ABCDEF); c2 = (int8_t*) ((uintptr_t) c2 + cn_stride); + a2 = (const int8_t*) ((uintptr_t) a2 - kc); + _mm_storeu_si128((__m128i*) c3, vout3x0123456789ABCDEF); c3 = (int8_t*) ((uintptr_t) c3 + cn_stride); + a3 = (const int8_t*) ((uintptr_t) a3 - kc); nc -= 16; } else { diff --git a/src/qs8-gemm/gen/qs8-gemm-4x16c8-minmax-fp32-avx512vnni-prfm.c b/src/qs8-gemm/gen/qs8-gemm-4x16c8-minmax-fp32-avx512vnni-prfm.c index 986eea96631..44fc41ea9d4 100644 --- a/src/qs8-gemm/gen/qs8-gemm-4x16c8-minmax-fp32-avx512vnni-prfm.c +++ b/src/qs8-gemm/gen/qs8-gemm-4x16c8-minmax-fp32-avx512vnni-prfm.c @@ -66,7 +66,6 @@ void xnn_qs8_gemm_minmax_fp32_ukernel_4x16c8__avx512vnni_prfm( const __m512 voutput_max_less_zero_point = _mm512_load_ps(params->fp32_avx512vnni.output_max_less_zero_point); const __m256i voutput_zero_point = _mm256_load_si256((const __m256i*) params->fp32_avx512vnni.output_zero_point); const __m128i voutput_min = _mm_load_si128((const __m128i*) params->fp32_avx512vnni.output_min); - const __m128i vshuffle_control_mask = _mm_loadu_si128((const __m128i*) params->fp32_avx512vnni.shuffle_control_mask); do { __m512i vacc0x01234567 = _mm512_cvtepu32_epi64(_mm256_load_si256((const __m256i*) w)); __m512i vacc0x89ABCDEF = _mm512_cvtepu32_epi64(_mm256_load_si256((const __m256i*) ((const int32_t*) w + 8))); @@ -186,25 +185,25 @@ void xnn_qs8_gemm_minmax_fp32_ukernel_4x16c8__avx512vnni_prfm( vacc2x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled2x0123456789ABCDEF); vacc3x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled3x0123456789ABCDEF); - __m256i vacc0x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); - __m256i vacc1x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); - __m256i vacc2x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc2x0123456789ABCDEF, 1)); - __m256i vacc3x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc3x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc3x0123456789ABCDEF, 1)); + __m256i vacc0x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); + __m256i vacc1x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); + __m256i vacc2x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc2x0123456789ABCDEF, 1)); + __m256i vacc3x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc3x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc3x0123456789ABCDEF, 1)); - vacc0x0123456789AB4567CDEF = _mm256_adds_epi16(vacc0x0123456789AB4567CDEF, voutput_zero_point); - vacc1x0123456789AB4567CDEF = _mm256_adds_epi16(vacc1x0123456789AB4567CDEF, voutput_zero_point); - vacc2x0123456789AB4567CDEF = _mm256_adds_epi16(vacc2x0123456789AB4567CDEF, voutput_zero_point); - vacc3x0123456789AB4567CDEF = _mm256_adds_epi16(vacc3x0123456789AB4567CDEF, voutput_zero_point); + vacc0x012389AB4567CDEF = _mm256_adds_epi16(vacc0x012389AB4567CDEF, voutput_zero_point); + vacc1x012389AB4567CDEF = _mm256_adds_epi16(vacc1x012389AB4567CDEF, voutput_zero_point); + vacc2x012389AB4567CDEF = _mm256_adds_epi16(vacc2x012389AB4567CDEF, voutput_zero_point); + vacc3x012389AB4567CDEF = _mm256_adds_epi16(vacc3x012389AB4567CDEF, voutput_zero_point); - const __m128i vout0x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc0x0123456789AB4567CDEF, 1)); - const __m128i vout1x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc1x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc1x0123456789AB4567CDEF, 1)); - const __m128i vout2x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc2x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc2x0123456789AB4567CDEF, 1)); - const __m128i vout3x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc3x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc3x0123456789AB4567CDEF, 1)); + const __m128i vout0x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x012389AB4567CDEF, 1)); + const __m128i vout1x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc1x012389AB4567CDEF), _mm256_extracti128_si256(vacc1x012389AB4567CDEF, 1)); + const __m128i vout2x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc2x012389AB4567CDEF), _mm256_extracti128_si256(vacc2x012389AB4567CDEF, 1)); + const __m128i vout3x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc3x012389AB4567CDEF), _mm256_extracti128_si256(vacc3x012389AB4567CDEF, 1)); - __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi8(vout0x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi8(vout1x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi8(vout2x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout3x0123456789ABCDEF = _mm_shuffle_epi8(vout3x0123456789AB4567CDEF, vshuffle_control_mask); + __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi32(vout0x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi32(vout1x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi32(vout2x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout3x0123456789ABCDEF = _mm_shuffle_epi32(vout3x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); vout0x0123456789ABCDEF = _mm_max_epi8(vout0x0123456789ABCDEF, voutput_min); vout1x0123456789ABCDEF = _mm_max_epi8(vout1x0123456789ABCDEF, voutput_min); @@ -213,19 +212,17 @@ void xnn_qs8_gemm_minmax_fp32_ukernel_4x16c8__avx512vnni_prfm( if (nc >= 16) { _mm_storeu_si128((__m128i*) c0, vout0x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c1, vout1x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c2, vout2x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c3, vout3x0123456789ABCDEF); - - a0 = (const int8_t*) ((uintptr_t) a0 - kc); - a1 = (const int8_t*) ((uintptr_t) a1 - kc); - a2 = (const int8_t*) ((uintptr_t) a2 - kc); - a3 = (const int8_t*) ((uintptr_t) a3 - kc); - c0 = (int8_t*) ((uintptr_t) c0 + cn_stride); + a0 = (const int8_t*) ((uintptr_t) a0 - kc); + _mm_storeu_si128((__m128i*) c1, vout1x0123456789ABCDEF); c1 = (int8_t*) ((uintptr_t) c1 + cn_stride); + a1 = (const int8_t*) ((uintptr_t) a1 - kc); + _mm_storeu_si128((__m128i*) c2, vout2x0123456789ABCDEF); c2 = (int8_t*) ((uintptr_t) c2 + cn_stride); + a2 = (const int8_t*) ((uintptr_t) a2 - kc); + _mm_storeu_si128((__m128i*) c3, vout3x0123456789ABCDEF); c3 = (int8_t*) ((uintptr_t) c3 + cn_stride); + a3 = (const int8_t*) ((uintptr_t) a3 - kc); nc -= 16; } else { diff --git a/src/qs8-gemm/gen/qs8-gemm-4x16c8-minmax-fp32-avx512vnni.c b/src/qs8-gemm/gen/qs8-gemm-4x16c8-minmax-fp32-avx512vnni.c index 8973c78558d..37984791706 100644 --- a/src/qs8-gemm/gen/qs8-gemm-4x16c8-minmax-fp32-avx512vnni.c +++ b/src/qs8-gemm/gen/qs8-gemm-4x16c8-minmax-fp32-avx512vnni.c @@ -65,7 +65,6 @@ void xnn_qs8_gemm_minmax_fp32_ukernel_4x16c8__avx512vnni( const __m512 voutput_max_less_zero_point = _mm512_load_ps(params->fp32_avx512vnni.output_max_less_zero_point); const __m256i voutput_zero_point = _mm256_load_si256((const __m256i*) params->fp32_avx512vnni.output_zero_point); const __m128i voutput_min = _mm_load_si128((const __m128i*) params->fp32_avx512vnni.output_min); - const __m128i vshuffle_control_mask = _mm_loadu_si128((const __m128i*) params->fp32_avx512vnni.shuffle_control_mask); do { __m512i vacc0x01234567 = _mm512_cvtepu32_epi64(_mm256_load_si256((const __m256i*) w)); __m512i vacc0x89ABCDEF = _mm512_cvtepu32_epi64(_mm256_load_si256((const __m256i*) ((const int32_t*) w + 8))); @@ -179,25 +178,25 @@ void xnn_qs8_gemm_minmax_fp32_ukernel_4x16c8__avx512vnni( vacc2x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled2x0123456789ABCDEF); vacc3x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled3x0123456789ABCDEF); - __m256i vacc0x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); - __m256i vacc1x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); - __m256i vacc2x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc2x0123456789ABCDEF, 1)); - __m256i vacc3x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc3x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc3x0123456789ABCDEF, 1)); + __m256i vacc0x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); + __m256i vacc1x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); + __m256i vacc2x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc2x0123456789ABCDEF, 1)); + __m256i vacc3x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc3x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc3x0123456789ABCDEF, 1)); - vacc0x0123456789AB4567CDEF = _mm256_adds_epi16(vacc0x0123456789AB4567CDEF, voutput_zero_point); - vacc1x0123456789AB4567CDEF = _mm256_adds_epi16(vacc1x0123456789AB4567CDEF, voutput_zero_point); - vacc2x0123456789AB4567CDEF = _mm256_adds_epi16(vacc2x0123456789AB4567CDEF, voutput_zero_point); - vacc3x0123456789AB4567CDEF = _mm256_adds_epi16(vacc3x0123456789AB4567CDEF, voutput_zero_point); + vacc0x012389AB4567CDEF = _mm256_adds_epi16(vacc0x012389AB4567CDEF, voutput_zero_point); + vacc1x012389AB4567CDEF = _mm256_adds_epi16(vacc1x012389AB4567CDEF, voutput_zero_point); + vacc2x012389AB4567CDEF = _mm256_adds_epi16(vacc2x012389AB4567CDEF, voutput_zero_point); + vacc3x012389AB4567CDEF = _mm256_adds_epi16(vacc3x012389AB4567CDEF, voutput_zero_point); - const __m128i vout0x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc0x0123456789AB4567CDEF, 1)); - const __m128i vout1x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc1x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc1x0123456789AB4567CDEF, 1)); - const __m128i vout2x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc2x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc2x0123456789AB4567CDEF, 1)); - const __m128i vout3x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc3x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc3x0123456789AB4567CDEF, 1)); + const __m128i vout0x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x012389AB4567CDEF, 1)); + const __m128i vout1x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc1x012389AB4567CDEF), _mm256_extracti128_si256(vacc1x012389AB4567CDEF, 1)); + const __m128i vout2x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc2x012389AB4567CDEF), _mm256_extracti128_si256(vacc2x012389AB4567CDEF, 1)); + const __m128i vout3x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc3x012389AB4567CDEF), _mm256_extracti128_si256(vacc3x012389AB4567CDEF, 1)); - __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi8(vout0x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi8(vout1x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi8(vout2x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout3x0123456789ABCDEF = _mm_shuffle_epi8(vout3x0123456789AB4567CDEF, vshuffle_control_mask); + __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi32(vout0x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi32(vout1x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi32(vout2x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout3x0123456789ABCDEF = _mm_shuffle_epi32(vout3x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); vout0x0123456789ABCDEF = _mm_max_epi8(vout0x0123456789ABCDEF, voutput_min); vout1x0123456789ABCDEF = _mm_max_epi8(vout1x0123456789ABCDEF, voutput_min); @@ -206,19 +205,17 @@ void xnn_qs8_gemm_minmax_fp32_ukernel_4x16c8__avx512vnni( if (nc >= 16) { _mm_storeu_si128((__m128i*) c0, vout0x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c1, vout1x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c2, vout2x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c3, vout3x0123456789ABCDEF); - - a0 = (const int8_t*) ((uintptr_t) a0 - kc); - a1 = (const int8_t*) ((uintptr_t) a1 - kc); - a2 = (const int8_t*) ((uintptr_t) a2 - kc); - a3 = (const int8_t*) ((uintptr_t) a3 - kc); - c0 = (int8_t*) ((uintptr_t) c0 + cn_stride); + a0 = (const int8_t*) ((uintptr_t) a0 - kc); + _mm_storeu_si128((__m128i*) c1, vout1x0123456789ABCDEF); c1 = (int8_t*) ((uintptr_t) c1 + cn_stride); + a1 = (const int8_t*) ((uintptr_t) a1 - kc); + _mm_storeu_si128((__m128i*) c2, vout2x0123456789ABCDEF); c2 = (int8_t*) ((uintptr_t) c2 + cn_stride); + a2 = (const int8_t*) ((uintptr_t) a2 - kc); + _mm_storeu_si128((__m128i*) c3, vout3x0123456789ABCDEF); c3 = (int8_t*) ((uintptr_t) c3 + cn_stride); + a3 = (const int8_t*) ((uintptr_t) a3 - kc); nc -= 16; } else { diff --git a/src/qs8-gemm/gen/qs8-gemm-5x16c4-minmax-fp32-avx512vnni-prfm.c b/src/qs8-gemm/gen/qs8-gemm-5x16c4-minmax-fp32-avx512vnni-prfm.c index eec90854e83..26d4344d6ce 100644 --- a/src/qs8-gemm/gen/qs8-gemm-5x16c4-minmax-fp32-avx512vnni-prfm.c +++ b/src/qs8-gemm/gen/qs8-gemm-5x16c4-minmax-fp32-avx512vnni-prfm.c @@ -72,7 +72,6 @@ void xnn_qs8_gemm_minmax_fp32_ukernel_5x16c4__avx512vnni_prfm( const __m512 voutput_max_less_zero_point = _mm512_load_ps(params->fp32_avx512vnni.output_max_less_zero_point); const __m256i voutput_zero_point = _mm256_load_si256((const __m256i*) params->fp32_avx512vnni.output_zero_point); const __m128i voutput_min = _mm_load_si128((const __m128i*) params->fp32_avx512vnni.output_min); - const __m128i vshuffle_control_mask = _mm_loadu_si128((const __m128i*) params->fp32_avx512vnni.shuffle_control_mask); do { __m512i vacc0x0x0123456789ABCDEF = _mm512_load_epi32(w); __m512i vacc1x0x0123456789ABCDEF = _mm512_setzero_epi32(); @@ -123,11 +122,11 @@ void xnn_qs8_gemm_minmax_fp32_ukernel_5x16c4__avx512vnni_prfm( w = (const int8_t*) w + 128; k -= 8 * sizeof(int8_t); } - vacc0x0x0123456789ABCDEF = _mm512_add_epi32(vacc0x0x0123456789ABCDEF, vacc1x0x0123456789ABCDEF); - vacc0x1x0123456789ABCDEF = _mm512_add_epi32(vacc0x1x0123456789ABCDEF, vacc1x1x0123456789ABCDEF); - vacc0x2x0123456789ABCDEF = _mm512_add_epi32(vacc0x2x0123456789ABCDEF, vacc1x2x0123456789ABCDEF); - vacc0x3x0123456789ABCDEF = _mm512_add_epi32(vacc0x3x0123456789ABCDEF, vacc1x3x0123456789ABCDEF); - vacc0x4x0123456789ABCDEF = _mm512_add_epi32(vacc0x4x0123456789ABCDEF, vacc1x4x0123456789ABCDEF); + __m512i vacc0x0123456789ABCDEF = _mm512_add_epi32(vacc0x0x0123456789ABCDEF, vacc1x0x0123456789ABCDEF); + __m512i vacc1x0123456789ABCDEF = _mm512_add_epi32(vacc0x1x0123456789ABCDEF, vacc1x1x0123456789ABCDEF); + __m512i vacc2x0123456789ABCDEF = _mm512_add_epi32(vacc0x2x0123456789ABCDEF, vacc1x2x0123456789ABCDEF); + __m512i vacc3x0123456789ABCDEF = _mm512_add_epi32(vacc0x3x0123456789ABCDEF, vacc1x3x0123456789ABCDEF); + __m512i vacc4x0123456789ABCDEF = _mm512_add_epi32(vacc0x4x0123456789ABCDEF, vacc1x4x0123456789ABCDEF); if (k != 0) { const __m512i va0x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a0)), vsign_mask); @@ -144,21 +143,21 @@ void xnn_qs8_gemm_minmax_fp32_ukernel_5x16c4__avx512vnni_prfm( const __m512i vb0123456789ABCDEF = _mm512_load_si512(w); xnn_prefetch_to_l1((const int8_t*) w + 960); - vacc0x0x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x0x0123456789ABCDEF, va0x0123, vb0123456789ABCDEF); - vacc0x1x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x1x0123456789ABCDEF, va1x0123, vb0123456789ABCDEF); - vacc0x2x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x2x0123456789ABCDEF, va2x0123, vb0123456789ABCDEF); - vacc0x3x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x3x0123456789ABCDEF, va3x0123, vb0123456789ABCDEF); - vacc0x4x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x4x0123456789ABCDEF, va4x0123, vb0123456789ABCDEF); + vacc0x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x0123456789ABCDEF, va0x0123, vb0123456789ABCDEF); + vacc1x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc1x0123456789ABCDEF, va1x0123, vb0123456789ABCDEF); + vacc2x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc2x0123456789ABCDEF, va2x0123, vb0123456789ABCDEF); + vacc3x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc3x0123456789ABCDEF, va3x0123, vb0123456789ABCDEF); + vacc4x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc4x0123456789ABCDEF, va4x0123, vb0123456789ABCDEF); w = (const int8_t*) w + 64; k -= 4 * sizeof(int8_t); } - __m512 vscaled0x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x0x0123456789ABCDEF); - __m512 vscaled1x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x1x0123456789ABCDEF); - __m512 vscaled2x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x2x0123456789ABCDEF); - __m512 vscaled3x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x3x0123456789ABCDEF); - __m512 vscaled4x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x4x0123456789ABCDEF); + __m512 vscaled0x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x0123456789ABCDEF); + __m512 vscaled1x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc1x0123456789ABCDEF); + __m512 vscaled2x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc2x0123456789ABCDEF); + __m512 vscaled3x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc3x0123456789ABCDEF); + __m512 vscaled4x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc4x0123456789ABCDEF); vscaled0x0123456789ABCDEF = _mm512_mul_ps(vscaled0x0123456789ABCDEF, vscale); vscaled1x0123456789ABCDEF = _mm512_mul_ps(vscaled1x0123456789ABCDEF, vscale); @@ -172,35 +171,35 @@ void xnn_qs8_gemm_minmax_fp32_ukernel_5x16c4__avx512vnni_prfm( vscaled3x0123456789ABCDEF = _mm512_min_ps(vscaled3x0123456789ABCDEF, voutput_max_less_zero_point); vscaled4x0123456789ABCDEF = _mm512_min_ps(vscaled4x0123456789ABCDEF, voutput_max_less_zero_point); - vacc0x0x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled0x0123456789ABCDEF); - vacc0x1x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled1x0123456789ABCDEF); - vacc0x2x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled2x0123456789ABCDEF); - vacc0x3x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled3x0123456789ABCDEF); - vacc0x4x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled4x0123456789ABCDEF); - - __m256i vacc0x0x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0x0123456789ABCDEF, 1)); - __m256i vacc0x1x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x1x0123456789ABCDEF, 1)); - __m256i vacc0x2x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x2x0123456789ABCDEF, 1)); - __m256i vacc0x3x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x3x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x3x0123456789ABCDEF, 1)); - __m256i vacc0x4x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x4x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x4x0123456789ABCDEF, 1)); - - vacc0x0x012389AB4567CDEF = _mm256_adds_epi16(vacc0x0x012389AB4567CDEF, voutput_zero_point); - vacc0x1x012389AB4567CDEF = _mm256_adds_epi16(vacc0x1x012389AB4567CDEF, voutput_zero_point); - vacc0x2x012389AB4567CDEF = _mm256_adds_epi16(vacc0x2x012389AB4567CDEF, voutput_zero_point); - vacc0x3x012389AB4567CDEF = _mm256_adds_epi16(vacc0x3x012389AB4567CDEF, voutput_zero_point); - vacc0x4x012389AB4567CDEF = _mm256_adds_epi16(vacc0x4x012389AB4567CDEF, voutput_zero_point); - - const __m128i vout0x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x0x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x0x012389AB4567CDEF, 1)); - const __m128i vout1x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x1x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x1x012389AB4567CDEF, 1)); - const __m128i vout2x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x2x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x2x012389AB4567CDEF, 1)); - const __m128i vout3x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x3x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x3x012389AB4567CDEF, 1)); - const __m128i vout4x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x4x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x4x012389AB4567CDEF, 1)); - - __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi8(vout0x012389AB4567CDEF, vshuffle_control_mask); - __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi8(vout1x012389AB4567CDEF, vshuffle_control_mask); - __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi8(vout2x012389AB4567CDEF, vshuffle_control_mask); - __m128i vout3x0123456789ABCDEF = _mm_shuffle_epi8(vout3x012389AB4567CDEF, vshuffle_control_mask); - __m128i vout4x0123456789ABCDEF = _mm_shuffle_epi8(vout4x012389AB4567CDEF, vshuffle_control_mask); + vacc0x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled0x0123456789ABCDEF); + vacc1x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled1x0123456789ABCDEF); + vacc2x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled2x0123456789ABCDEF); + vacc3x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled3x0123456789ABCDEF); + vacc4x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled4x0123456789ABCDEF); + + __m256i vacc0x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); + __m256i vacc1x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); + __m256i vacc2x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc2x0123456789ABCDEF, 1)); + __m256i vacc3x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc3x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc3x0123456789ABCDEF, 1)); + __m256i vacc4x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc4x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc4x0123456789ABCDEF, 1)); + + vacc0x012389AB4567CDEF = _mm256_adds_epi16(vacc0x012389AB4567CDEF, voutput_zero_point); + vacc1x012389AB4567CDEF = _mm256_adds_epi16(vacc1x012389AB4567CDEF, voutput_zero_point); + vacc2x012389AB4567CDEF = _mm256_adds_epi16(vacc2x012389AB4567CDEF, voutput_zero_point); + vacc3x012389AB4567CDEF = _mm256_adds_epi16(vacc3x012389AB4567CDEF, voutput_zero_point); + vacc4x012389AB4567CDEF = _mm256_adds_epi16(vacc4x012389AB4567CDEF, voutput_zero_point); + + const __m128i vout0x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x012389AB4567CDEF, 1)); + const __m128i vout1x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc1x012389AB4567CDEF), _mm256_extracti128_si256(vacc1x012389AB4567CDEF, 1)); + const __m128i vout2x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc2x012389AB4567CDEF), _mm256_extracti128_si256(vacc2x012389AB4567CDEF, 1)); + const __m128i vout3x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc3x012389AB4567CDEF), _mm256_extracti128_si256(vacc3x012389AB4567CDEF, 1)); + const __m128i vout4x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc4x012389AB4567CDEF), _mm256_extracti128_si256(vacc4x012389AB4567CDEF, 1)); + + __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi32(vout0x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi32(vout1x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi32(vout2x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout3x0123456789ABCDEF = _mm_shuffle_epi32(vout3x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout4x0123456789ABCDEF = _mm_shuffle_epi32(vout4x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); vout0x0123456789ABCDEF = _mm_max_epi8(vout0x0123456789ABCDEF, voutput_min); vout1x0123456789ABCDEF = _mm_max_epi8(vout1x0123456789ABCDEF, voutput_min); @@ -210,22 +209,20 @@ void xnn_qs8_gemm_minmax_fp32_ukernel_5x16c4__avx512vnni_prfm( if (nc >= 16) { _mm_storeu_si128((__m128i*) c0, vout0x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c1, vout1x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c2, vout2x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c3, vout3x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c4, vout4x0123456789ABCDEF); - - a0 = (const int8_t*) ((uintptr_t) a0 - kc); - a1 = (const int8_t*) ((uintptr_t) a1 - kc); - a2 = (const int8_t*) ((uintptr_t) a2 - kc); - a3 = (const int8_t*) ((uintptr_t) a3 - kc); - a4 = (const int8_t*) ((uintptr_t) a4 - kc); - c0 = (int8_t*) ((uintptr_t) c0 + cn_stride); + a0 = (const int8_t*) ((uintptr_t) a0 - kc); + _mm_storeu_si128((__m128i*) c1, vout1x0123456789ABCDEF); c1 = (int8_t*) ((uintptr_t) c1 + cn_stride); + a1 = (const int8_t*) ((uintptr_t) a1 - kc); + _mm_storeu_si128((__m128i*) c2, vout2x0123456789ABCDEF); c2 = (int8_t*) ((uintptr_t) c2 + cn_stride); + a2 = (const int8_t*) ((uintptr_t) a2 - kc); + _mm_storeu_si128((__m128i*) c3, vout3x0123456789ABCDEF); c3 = (int8_t*) ((uintptr_t) c3 + cn_stride); + a3 = (const int8_t*) ((uintptr_t) a3 - kc); + _mm_storeu_si128((__m128i*) c4, vout4x0123456789ABCDEF); c4 = (int8_t*) ((uintptr_t) c4 + cn_stride); + a4 = (const int8_t*) ((uintptr_t) a4 - kc); nc -= 16; } else { diff --git a/src/qs8-gemm/gen/qs8-gemm-5x16c4-minmax-fp32-avx512vnni.c b/src/qs8-gemm/gen/qs8-gemm-5x16c4-minmax-fp32-avx512vnni.c index 4a22981dc93..17ef303526a 100644 --- a/src/qs8-gemm/gen/qs8-gemm-5x16c4-minmax-fp32-avx512vnni.c +++ b/src/qs8-gemm/gen/qs8-gemm-5x16c4-minmax-fp32-avx512vnni.c @@ -71,7 +71,6 @@ void xnn_qs8_gemm_minmax_fp32_ukernel_5x16c4__avx512vnni( const __m512 voutput_max_less_zero_point = _mm512_load_ps(params->fp32_avx512vnni.output_max_less_zero_point); const __m256i voutput_zero_point = _mm256_load_si256((const __m256i*) params->fp32_avx512vnni.output_zero_point); const __m128i voutput_min = _mm_load_si128((const __m128i*) params->fp32_avx512vnni.output_min); - const __m128i vshuffle_control_mask = _mm_loadu_si128((const __m128i*) params->fp32_avx512vnni.shuffle_control_mask); do { __m512i vacc0x0x0123456789ABCDEF = _mm512_load_epi32(w); __m512i vacc1x0x0123456789ABCDEF = _mm512_setzero_epi32(); @@ -120,11 +119,11 @@ void xnn_qs8_gemm_minmax_fp32_ukernel_5x16c4__avx512vnni( w = (const int8_t*) w + 128; k -= 8 * sizeof(int8_t); } - vacc0x0x0123456789ABCDEF = _mm512_add_epi32(vacc0x0x0123456789ABCDEF, vacc1x0x0123456789ABCDEF); - vacc0x1x0123456789ABCDEF = _mm512_add_epi32(vacc0x1x0123456789ABCDEF, vacc1x1x0123456789ABCDEF); - vacc0x2x0123456789ABCDEF = _mm512_add_epi32(vacc0x2x0123456789ABCDEF, vacc1x2x0123456789ABCDEF); - vacc0x3x0123456789ABCDEF = _mm512_add_epi32(vacc0x3x0123456789ABCDEF, vacc1x3x0123456789ABCDEF); - vacc0x4x0123456789ABCDEF = _mm512_add_epi32(vacc0x4x0123456789ABCDEF, vacc1x4x0123456789ABCDEF); + __m512i vacc0x0123456789ABCDEF = _mm512_add_epi32(vacc0x0x0123456789ABCDEF, vacc1x0x0123456789ABCDEF); + __m512i vacc1x0123456789ABCDEF = _mm512_add_epi32(vacc0x1x0123456789ABCDEF, vacc1x1x0123456789ABCDEF); + __m512i vacc2x0123456789ABCDEF = _mm512_add_epi32(vacc0x2x0123456789ABCDEF, vacc1x2x0123456789ABCDEF); + __m512i vacc3x0123456789ABCDEF = _mm512_add_epi32(vacc0x3x0123456789ABCDEF, vacc1x3x0123456789ABCDEF); + __m512i vacc4x0123456789ABCDEF = _mm512_add_epi32(vacc0x4x0123456789ABCDEF, vacc1x4x0123456789ABCDEF); if (k != 0) { const __m512i va0x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a0)), vsign_mask); @@ -140,21 +139,21 @@ void xnn_qs8_gemm_minmax_fp32_ukernel_5x16c4__avx512vnni( const __m512i vb0123456789ABCDEF = _mm512_load_si512(w); - vacc0x0x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x0x0123456789ABCDEF, va0x0123, vb0123456789ABCDEF); - vacc0x1x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x1x0123456789ABCDEF, va1x0123, vb0123456789ABCDEF); - vacc0x2x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x2x0123456789ABCDEF, va2x0123, vb0123456789ABCDEF); - vacc0x3x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x3x0123456789ABCDEF, va3x0123, vb0123456789ABCDEF); - vacc0x4x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x4x0123456789ABCDEF, va4x0123, vb0123456789ABCDEF); + vacc0x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x0123456789ABCDEF, va0x0123, vb0123456789ABCDEF); + vacc1x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc1x0123456789ABCDEF, va1x0123, vb0123456789ABCDEF); + vacc2x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc2x0123456789ABCDEF, va2x0123, vb0123456789ABCDEF); + vacc3x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc3x0123456789ABCDEF, va3x0123, vb0123456789ABCDEF); + vacc4x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc4x0123456789ABCDEF, va4x0123, vb0123456789ABCDEF); w = (const int8_t*) w + 64; k -= 4 * sizeof(int8_t); } - __m512 vscaled0x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x0x0123456789ABCDEF); - __m512 vscaled1x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x1x0123456789ABCDEF); - __m512 vscaled2x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x2x0123456789ABCDEF); - __m512 vscaled3x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x3x0123456789ABCDEF); - __m512 vscaled4x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x4x0123456789ABCDEF); + __m512 vscaled0x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x0123456789ABCDEF); + __m512 vscaled1x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc1x0123456789ABCDEF); + __m512 vscaled2x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc2x0123456789ABCDEF); + __m512 vscaled3x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc3x0123456789ABCDEF); + __m512 vscaled4x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc4x0123456789ABCDEF); vscaled0x0123456789ABCDEF = _mm512_mul_ps(vscaled0x0123456789ABCDEF, vscale); vscaled1x0123456789ABCDEF = _mm512_mul_ps(vscaled1x0123456789ABCDEF, vscale); @@ -168,35 +167,35 @@ void xnn_qs8_gemm_minmax_fp32_ukernel_5x16c4__avx512vnni( vscaled3x0123456789ABCDEF = _mm512_min_ps(vscaled3x0123456789ABCDEF, voutput_max_less_zero_point); vscaled4x0123456789ABCDEF = _mm512_min_ps(vscaled4x0123456789ABCDEF, voutput_max_less_zero_point); - vacc0x0x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled0x0123456789ABCDEF); - vacc0x1x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled1x0123456789ABCDEF); - vacc0x2x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled2x0123456789ABCDEF); - vacc0x3x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled3x0123456789ABCDEF); - vacc0x4x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled4x0123456789ABCDEF); - - __m256i vacc0x0x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0x0123456789ABCDEF, 1)); - __m256i vacc0x1x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x1x0123456789ABCDEF, 1)); - __m256i vacc0x2x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x2x0123456789ABCDEF, 1)); - __m256i vacc0x3x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x3x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x3x0123456789ABCDEF, 1)); - __m256i vacc0x4x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x4x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x4x0123456789ABCDEF, 1)); - - vacc0x0x012389AB4567CDEF = _mm256_adds_epi16(vacc0x0x012389AB4567CDEF, voutput_zero_point); - vacc0x1x012389AB4567CDEF = _mm256_adds_epi16(vacc0x1x012389AB4567CDEF, voutput_zero_point); - vacc0x2x012389AB4567CDEF = _mm256_adds_epi16(vacc0x2x012389AB4567CDEF, voutput_zero_point); - vacc0x3x012389AB4567CDEF = _mm256_adds_epi16(vacc0x3x012389AB4567CDEF, voutput_zero_point); - vacc0x4x012389AB4567CDEF = _mm256_adds_epi16(vacc0x4x012389AB4567CDEF, voutput_zero_point); - - const __m128i vout0x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x0x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x0x012389AB4567CDEF, 1)); - const __m128i vout1x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x1x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x1x012389AB4567CDEF, 1)); - const __m128i vout2x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x2x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x2x012389AB4567CDEF, 1)); - const __m128i vout3x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x3x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x3x012389AB4567CDEF, 1)); - const __m128i vout4x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x4x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x4x012389AB4567CDEF, 1)); - - __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi8(vout0x012389AB4567CDEF, vshuffle_control_mask); - __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi8(vout1x012389AB4567CDEF, vshuffle_control_mask); - __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi8(vout2x012389AB4567CDEF, vshuffle_control_mask); - __m128i vout3x0123456789ABCDEF = _mm_shuffle_epi8(vout3x012389AB4567CDEF, vshuffle_control_mask); - __m128i vout4x0123456789ABCDEF = _mm_shuffle_epi8(vout4x012389AB4567CDEF, vshuffle_control_mask); + vacc0x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled0x0123456789ABCDEF); + vacc1x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled1x0123456789ABCDEF); + vacc2x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled2x0123456789ABCDEF); + vacc3x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled3x0123456789ABCDEF); + vacc4x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled4x0123456789ABCDEF); + + __m256i vacc0x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); + __m256i vacc1x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); + __m256i vacc2x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc2x0123456789ABCDEF, 1)); + __m256i vacc3x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc3x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc3x0123456789ABCDEF, 1)); + __m256i vacc4x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc4x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc4x0123456789ABCDEF, 1)); + + vacc0x012389AB4567CDEF = _mm256_adds_epi16(vacc0x012389AB4567CDEF, voutput_zero_point); + vacc1x012389AB4567CDEF = _mm256_adds_epi16(vacc1x012389AB4567CDEF, voutput_zero_point); + vacc2x012389AB4567CDEF = _mm256_adds_epi16(vacc2x012389AB4567CDEF, voutput_zero_point); + vacc3x012389AB4567CDEF = _mm256_adds_epi16(vacc3x012389AB4567CDEF, voutput_zero_point); + vacc4x012389AB4567CDEF = _mm256_adds_epi16(vacc4x012389AB4567CDEF, voutput_zero_point); + + const __m128i vout0x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x012389AB4567CDEF, 1)); + const __m128i vout1x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc1x012389AB4567CDEF), _mm256_extracti128_si256(vacc1x012389AB4567CDEF, 1)); + const __m128i vout2x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc2x012389AB4567CDEF), _mm256_extracti128_si256(vacc2x012389AB4567CDEF, 1)); + const __m128i vout3x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc3x012389AB4567CDEF), _mm256_extracti128_si256(vacc3x012389AB4567CDEF, 1)); + const __m128i vout4x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc4x012389AB4567CDEF), _mm256_extracti128_si256(vacc4x012389AB4567CDEF, 1)); + + __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi32(vout0x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi32(vout1x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi32(vout2x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout3x0123456789ABCDEF = _mm_shuffle_epi32(vout3x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout4x0123456789ABCDEF = _mm_shuffle_epi32(vout4x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); vout0x0123456789ABCDEF = _mm_max_epi8(vout0x0123456789ABCDEF, voutput_min); vout1x0123456789ABCDEF = _mm_max_epi8(vout1x0123456789ABCDEF, voutput_min); @@ -206,22 +205,20 @@ void xnn_qs8_gemm_minmax_fp32_ukernel_5x16c4__avx512vnni( if (nc >= 16) { _mm_storeu_si128((__m128i*) c0, vout0x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c1, vout1x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c2, vout2x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c3, vout3x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c4, vout4x0123456789ABCDEF); - - a0 = (const int8_t*) ((uintptr_t) a0 - kc); - a1 = (const int8_t*) ((uintptr_t) a1 - kc); - a2 = (const int8_t*) ((uintptr_t) a2 - kc); - a3 = (const int8_t*) ((uintptr_t) a3 - kc); - a4 = (const int8_t*) ((uintptr_t) a4 - kc); - c0 = (int8_t*) ((uintptr_t) c0 + cn_stride); + a0 = (const int8_t*) ((uintptr_t) a0 - kc); + _mm_storeu_si128((__m128i*) c1, vout1x0123456789ABCDEF); c1 = (int8_t*) ((uintptr_t) c1 + cn_stride); + a1 = (const int8_t*) ((uintptr_t) a1 - kc); + _mm_storeu_si128((__m128i*) c2, vout2x0123456789ABCDEF); c2 = (int8_t*) ((uintptr_t) c2 + cn_stride); + a2 = (const int8_t*) ((uintptr_t) a2 - kc); + _mm_storeu_si128((__m128i*) c3, vout3x0123456789ABCDEF); c3 = (int8_t*) ((uintptr_t) c3 + cn_stride); + a3 = (const int8_t*) ((uintptr_t) a3 - kc); + _mm_storeu_si128((__m128i*) c4, vout4x0123456789ABCDEF); c4 = (int8_t*) ((uintptr_t) c4 + cn_stride); + a4 = (const int8_t*) ((uintptr_t) a4 - kc); nc -= 16; } else { diff --git a/src/qs8-gemm/gen/qs8-gemm-5x16c8-minmax-fp32-avx512skx-prfm.c b/src/qs8-gemm/gen/qs8-gemm-5x16c8-minmax-fp32-avx512skx-prfm.c index 324a94bab0d..20979bee527 100644 --- a/src/qs8-gemm/gen/qs8-gemm-5x16c8-minmax-fp32-avx512skx-prfm.c +++ b/src/qs8-gemm/gen/qs8-gemm-5x16c8-minmax-fp32-avx512skx-prfm.c @@ -70,7 +70,6 @@ void xnn_qs8_gemm_minmax_fp32_ukernel_5x16c8__avx512skx_prfm( const __m512 voutput_max_less_zero_point = _mm512_load_ps(params->fp32_avx512.output_max_less_zero_point); const __m256i voutput_zero_point = _mm256_load_si256((const __m256i*) params->fp32_avx512.output_zero_point); const __m128i voutput_min = _mm_load_si128((const __m128i*) params->fp32_avx512.output_min); - const __m128i vshuffle_control_mask = _mm_set_epi8(15, 14, 13, 12, 7, 6, 5, 4, 11, 10, 9, 8, 3, 2, 1, 0); do { __m512i vacc0x0123 = _mm512_maskz_expandloadu_epi32(vbias_mask, w); @@ -193,29 +192,29 @@ void xnn_qs8_gemm_minmax_fp32_ukernel_5x16c8__avx512skx_prfm( vacc3x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled3x0123456789ABCDEF); vacc4x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled4x0123456789ABCDEF); - __m256i vacc0x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); - __m256i vacc1x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); - __m256i vacc2x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc2x0123456789ABCDEF, 1)); - __m256i vacc3x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc3x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc3x0123456789ABCDEF, 1)); - __m256i vacc4x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc4x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc4x0123456789ABCDEF, 1)); - - vacc0x0123456789AB4567CDEF = _mm256_adds_epi16(vacc0x0123456789AB4567CDEF, voutput_zero_point); - vacc1x0123456789AB4567CDEF = _mm256_adds_epi16(vacc1x0123456789AB4567CDEF, voutput_zero_point); - vacc2x0123456789AB4567CDEF = _mm256_adds_epi16(vacc2x0123456789AB4567CDEF, voutput_zero_point); - vacc3x0123456789AB4567CDEF = _mm256_adds_epi16(vacc3x0123456789AB4567CDEF, voutput_zero_point); - vacc4x0123456789AB4567CDEF = _mm256_adds_epi16(vacc4x0123456789AB4567CDEF, voutput_zero_point); - - const __m128i vout0x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc0x0123456789AB4567CDEF, 1)); - const __m128i vout1x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc1x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc1x0123456789AB4567CDEF, 1)); - const __m128i vout2x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc2x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc2x0123456789AB4567CDEF, 1)); - const __m128i vout3x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc3x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc3x0123456789AB4567CDEF, 1)); - const __m128i vout4x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc4x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc4x0123456789AB4567CDEF, 1)); - - __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi8(vout0x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi8(vout1x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi8(vout2x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout3x0123456789ABCDEF = _mm_shuffle_epi8(vout3x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout4x0123456789ABCDEF = _mm_shuffle_epi8(vout4x0123456789AB4567CDEF, vshuffle_control_mask); + __m256i vacc0x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); + __m256i vacc1x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); + __m256i vacc2x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc2x0123456789ABCDEF, 1)); + __m256i vacc3x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc3x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc3x0123456789ABCDEF, 1)); + __m256i vacc4x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc4x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc4x0123456789ABCDEF, 1)); + + vacc0x012389AB4567CDEF = _mm256_adds_epi16(vacc0x012389AB4567CDEF, voutput_zero_point); + vacc1x012389AB4567CDEF = _mm256_adds_epi16(vacc1x012389AB4567CDEF, voutput_zero_point); + vacc2x012389AB4567CDEF = _mm256_adds_epi16(vacc2x012389AB4567CDEF, voutput_zero_point); + vacc3x012389AB4567CDEF = _mm256_adds_epi16(vacc3x012389AB4567CDEF, voutput_zero_point); + vacc4x012389AB4567CDEF = _mm256_adds_epi16(vacc4x012389AB4567CDEF, voutput_zero_point); + + const __m128i vout0x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x012389AB4567CDEF, 1)); + const __m128i vout1x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc1x012389AB4567CDEF), _mm256_extracti128_si256(vacc1x012389AB4567CDEF, 1)); + const __m128i vout2x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc2x012389AB4567CDEF), _mm256_extracti128_si256(vacc2x012389AB4567CDEF, 1)); + const __m128i vout3x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc3x012389AB4567CDEF), _mm256_extracti128_si256(vacc3x012389AB4567CDEF, 1)); + const __m128i vout4x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc4x012389AB4567CDEF), _mm256_extracti128_si256(vacc4x012389AB4567CDEF, 1)); + + __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi32(vout0x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi32(vout1x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi32(vout2x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout3x0123456789ABCDEF = _mm_shuffle_epi32(vout3x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout4x0123456789ABCDEF = _mm_shuffle_epi32(vout4x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); vout0x0123456789ABCDEF = _mm_max_epi8(vout0x0123456789ABCDEF, voutput_min); vout1x0123456789ABCDEF = _mm_max_epi8(vout1x0123456789ABCDEF, voutput_min); @@ -225,22 +224,20 @@ void xnn_qs8_gemm_minmax_fp32_ukernel_5x16c8__avx512skx_prfm( if (nc >= 16) { _mm_storeu_si128((__m128i*) c0, vout0x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c1, vout1x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c2, vout2x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c3, vout3x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c4, vout4x0123456789ABCDEF); - - a0 = (const int8_t*) ((uintptr_t) a0 - kc); - a1 = (const int8_t*) ((uintptr_t) a1 - kc); - a2 = (const int8_t*) ((uintptr_t) a2 - kc); - a3 = (const int8_t*) ((uintptr_t) a3 - kc); - a4 = (const int8_t*) ((uintptr_t) a4 - kc); - c0 = (int8_t*) ((uintptr_t) c0 + cn_stride); + a0 = (const int8_t*) ((uintptr_t) a0 - kc); + _mm_storeu_si128((__m128i*) c1, vout1x0123456789ABCDEF); c1 = (int8_t*) ((uintptr_t) c1 + cn_stride); + a1 = (const int8_t*) ((uintptr_t) a1 - kc); + _mm_storeu_si128((__m128i*) c2, vout2x0123456789ABCDEF); c2 = (int8_t*) ((uintptr_t) c2 + cn_stride); + a2 = (const int8_t*) ((uintptr_t) a2 - kc); + _mm_storeu_si128((__m128i*) c3, vout3x0123456789ABCDEF); c3 = (int8_t*) ((uintptr_t) c3 + cn_stride); + a3 = (const int8_t*) ((uintptr_t) a3 - kc); + _mm_storeu_si128((__m128i*) c4, vout4x0123456789ABCDEF); c4 = (int8_t*) ((uintptr_t) c4 + cn_stride); + a4 = (const int8_t*) ((uintptr_t) a4 - kc); nc -= 16; } else { diff --git a/src/qs8-gemm/gen/qs8-gemm-5x16c8-minmax-fp32-avx512skx.c b/src/qs8-gemm/gen/qs8-gemm-5x16c8-minmax-fp32-avx512skx.c index 2ab5e0fab1a..5ac3bc3e4c4 100644 --- a/src/qs8-gemm/gen/qs8-gemm-5x16c8-minmax-fp32-avx512skx.c +++ b/src/qs8-gemm/gen/qs8-gemm-5x16c8-minmax-fp32-avx512skx.c @@ -69,7 +69,6 @@ void xnn_qs8_gemm_minmax_fp32_ukernel_5x16c8__avx512skx( const __m512 voutput_max_less_zero_point = _mm512_load_ps(params->fp32_avx512.output_max_less_zero_point); const __m256i voutput_zero_point = _mm256_load_si256((const __m256i*) params->fp32_avx512.output_zero_point); const __m128i voutput_min = _mm_load_si128((const __m128i*) params->fp32_avx512.output_min); - const __m128i vshuffle_control_mask = _mm_set_epi8(15, 14, 13, 12, 7, 6, 5, 4, 11, 10, 9, 8, 3, 2, 1, 0); do { __m512i vacc0x0123 = _mm512_maskz_expandloadu_epi32(vbias_mask, w); @@ -190,29 +189,29 @@ void xnn_qs8_gemm_minmax_fp32_ukernel_5x16c8__avx512skx( vacc3x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled3x0123456789ABCDEF); vacc4x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled4x0123456789ABCDEF); - __m256i vacc0x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); - __m256i vacc1x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); - __m256i vacc2x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc2x0123456789ABCDEF, 1)); - __m256i vacc3x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc3x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc3x0123456789ABCDEF, 1)); - __m256i vacc4x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc4x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc4x0123456789ABCDEF, 1)); - - vacc0x0123456789AB4567CDEF = _mm256_adds_epi16(vacc0x0123456789AB4567CDEF, voutput_zero_point); - vacc1x0123456789AB4567CDEF = _mm256_adds_epi16(vacc1x0123456789AB4567CDEF, voutput_zero_point); - vacc2x0123456789AB4567CDEF = _mm256_adds_epi16(vacc2x0123456789AB4567CDEF, voutput_zero_point); - vacc3x0123456789AB4567CDEF = _mm256_adds_epi16(vacc3x0123456789AB4567CDEF, voutput_zero_point); - vacc4x0123456789AB4567CDEF = _mm256_adds_epi16(vacc4x0123456789AB4567CDEF, voutput_zero_point); - - const __m128i vout0x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc0x0123456789AB4567CDEF, 1)); - const __m128i vout1x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc1x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc1x0123456789AB4567CDEF, 1)); - const __m128i vout2x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc2x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc2x0123456789AB4567CDEF, 1)); - const __m128i vout3x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc3x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc3x0123456789AB4567CDEF, 1)); - const __m128i vout4x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc4x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc4x0123456789AB4567CDEF, 1)); - - __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi8(vout0x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi8(vout1x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi8(vout2x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout3x0123456789ABCDEF = _mm_shuffle_epi8(vout3x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout4x0123456789ABCDEF = _mm_shuffle_epi8(vout4x0123456789AB4567CDEF, vshuffle_control_mask); + __m256i vacc0x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); + __m256i vacc1x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); + __m256i vacc2x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc2x0123456789ABCDEF, 1)); + __m256i vacc3x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc3x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc3x0123456789ABCDEF, 1)); + __m256i vacc4x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc4x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc4x0123456789ABCDEF, 1)); + + vacc0x012389AB4567CDEF = _mm256_adds_epi16(vacc0x012389AB4567CDEF, voutput_zero_point); + vacc1x012389AB4567CDEF = _mm256_adds_epi16(vacc1x012389AB4567CDEF, voutput_zero_point); + vacc2x012389AB4567CDEF = _mm256_adds_epi16(vacc2x012389AB4567CDEF, voutput_zero_point); + vacc3x012389AB4567CDEF = _mm256_adds_epi16(vacc3x012389AB4567CDEF, voutput_zero_point); + vacc4x012389AB4567CDEF = _mm256_adds_epi16(vacc4x012389AB4567CDEF, voutput_zero_point); + + const __m128i vout0x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x012389AB4567CDEF, 1)); + const __m128i vout1x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc1x012389AB4567CDEF), _mm256_extracti128_si256(vacc1x012389AB4567CDEF, 1)); + const __m128i vout2x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc2x012389AB4567CDEF), _mm256_extracti128_si256(vacc2x012389AB4567CDEF, 1)); + const __m128i vout3x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc3x012389AB4567CDEF), _mm256_extracti128_si256(vacc3x012389AB4567CDEF, 1)); + const __m128i vout4x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc4x012389AB4567CDEF), _mm256_extracti128_si256(vacc4x012389AB4567CDEF, 1)); + + __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi32(vout0x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi32(vout1x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi32(vout2x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout3x0123456789ABCDEF = _mm_shuffle_epi32(vout3x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout4x0123456789ABCDEF = _mm_shuffle_epi32(vout4x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); vout0x0123456789ABCDEF = _mm_max_epi8(vout0x0123456789ABCDEF, voutput_min); vout1x0123456789ABCDEF = _mm_max_epi8(vout1x0123456789ABCDEF, voutput_min); @@ -222,22 +221,20 @@ void xnn_qs8_gemm_minmax_fp32_ukernel_5x16c8__avx512skx( if (nc >= 16) { _mm_storeu_si128((__m128i*) c0, vout0x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c1, vout1x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c2, vout2x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c3, vout3x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c4, vout4x0123456789ABCDEF); - - a0 = (const int8_t*) ((uintptr_t) a0 - kc); - a1 = (const int8_t*) ((uintptr_t) a1 - kc); - a2 = (const int8_t*) ((uintptr_t) a2 - kc); - a3 = (const int8_t*) ((uintptr_t) a3 - kc); - a4 = (const int8_t*) ((uintptr_t) a4 - kc); - c0 = (int8_t*) ((uintptr_t) c0 + cn_stride); + a0 = (const int8_t*) ((uintptr_t) a0 - kc); + _mm_storeu_si128((__m128i*) c1, vout1x0123456789ABCDEF); c1 = (int8_t*) ((uintptr_t) c1 + cn_stride); + a1 = (const int8_t*) ((uintptr_t) a1 - kc); + _mm_storeu_si128((__m128i*) c2, vout2x0123456789ABCDEF); c2 = (int8_t*) ((uintptr_t) c2 + cn_stride); + a2 = (const int8_t*) ((uintptr_t) a2 - kc); + _mm_storeu_si128((__m128i*) c3, vout3x0123456789ABCDEF); c3 = (int8_t*) ((uintptr_t) c3 + cn_stride); + a3 = (const int8_t*) ((uintptr_t) a3 - kc); + _mm_storeu_si128((__m128i*) c4, vout4x0123456789ABCDEF); c4 = (int8_t*) ((uintptr_t) c4 + cn_stride); + a4 = (const int8_t*) ((uintptr_t) a4 - kc); nc -= 16; } else { diff --git a/src/qs8-gemm/gen/qs8-gemm-5x16c8-minmax-fp32-avx512vnni-prfm.c b/src/qs8-gemm/gen/qs8-gemm-5x16c8-minmax-fp32-avx512vnni-prfm.c index b71325f09aa..4e03b237858 100644 --- a/src/qs8-gemm/gen/qs8-gemm-5x16c8-minmax-fp32-avx512vnni-prfm.c +++ b/src/qs8-gemm/gen/qs8-gemm-5x16c8-minmax-fp32-avx512vnni-prfm.c @@ -72,7 +72,6 @@ void xnn_qs8_gemm_minmax_fp32_ukernel_5x16c8__avx512vnni_prfm( const __m512 voutput_max_less_zero_point = _mm512_load_ps(params->fp32_avx512vnni.output_max_less_zero_point); const __m256i voutput_zero_point = _mm256_load_si256((const __m256i*) params->fp32_avx512vnni.output_zero_point); const __m128i voutput_min = _mm_load_si128((const __m128i*) params->fp32_avx512vnni.output_min); - const __m128i vshuffle_control_mask = _mm_loadu_si128((const __m128i*) params->fp32_avx512vnni.shuffle_control_mask); do { __m512i vacc0x01234567 = _mm512_cvtepu32_epi64(_mm256_load_si256((const __m256i*) w)); __m512i vacc0x89ABCDEF = _mm512_cvtepu32_epi64(_mm256_load_si256((const __m256i*) ((const int32_t*) w + 8))); @@ -212,29 +211,29 @@ void xnn_qs8_gemm_minmax_fp32_ukernel_5x16c8__avx512vnni_prfm( vacc3x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled3x0123456789ABCDEF); vacc4x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled4x0123456789ABCDEF); - __m256i vacc0x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); - __m256i vacc1x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); - __m256i vacc2x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc2x0123456789ABCDEF, 1)); - __m256i vacc3x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc3x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc3x0123456789ABCDEF, 1)); - __m256i vacc4x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc4x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc4x0123456789ABCDEF, 1)); + __m256i vacc0x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); + __m256i vacc1x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); + __m256i vacc2x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc2x0123456789ABCDEF, 1)); + __m256i vacc3x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc3x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc3x0123456789ABCDEF, 1)); + __m256i vacc4x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc4x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc4x0123456789ABCDEF, 1)); - vacc0x0123456789AB4567CDEF = _mm256_adds_epi16(vacc0x0123456789AB4567CDEF, voutput_zero_point); - vacc1x0123456789AB4567CDEF = _mm256_adds_epi16(vacc1x0123456789AB4567CDEF, voutput_zero_point); - vacc2x0123456789AB4567CDEF = _mm256_adds_epi16(vacc2x0123456789AB4567CDEF, voutput_zero_point); - vacc3x0123456789AB4567CDEF = _mm256_adds_epi16(vacc3x0123456789AB4567CDEF, voutput_zero_point); - vacc4x0123456789AB4567CDEF = _mm256_adds_epi16(vacc4x0123456789AB4567CDEF, voutput_zero_point); + vacc0x012389AB4567CDEF = _mm256_adds_epi16(vacc0x012389AB4567CDEF, voutput_zero_point); + vacc1x012389AB4567CDEF = _mm256_adds_epi16(vacc1x012389AB4567CDEF, voutput_zero_point); + vacc2x012389AB4567CDEF = _mm256_adds_epi16(vacc2x012389AB4567CDEF, voutput_zero_point); + vacc3x012389AB4567CDEF = _mm256_adds_epi16(vacc3x012389AB4567CDEF, voutput_zero_point); + vacc4x012389AB4567CDEF = _mm256_adds_epi16(vacc4x012389AB4567CDEF, voutput_zero_point); - const __m128i vout0x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc0x0123456789AB4567CDEF, 1)); - const __m128i vout1x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc1x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc1x0123456789AB4567CDEF, 1)); - const __m128i vout2x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc2x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc2x0123456789AB4567CDEF, 1)); - const __m128i vout3x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc3x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc3x0123456789AB4567CDEF, 1)); - const __m128i vout4x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc4x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc4x0123456789AB4567CDEF, 1)); + const __m128i vout0x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x012389AB4567CDEF, 1)); + const __m128i vout1x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc1x012389AB4567CDEF), _mm256_extracti128_si256(vacc1x012389AB4567CDEF, 1)); + const __m128i vout2x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc2x012389AB4567CDEF), _mm256_extracti128_si256(vacc2x012389AB4567CDEF, 1)); + const __m128i vout3x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc3x012389AB4567CDEF), _mm256_extracti128_si256(vacc3x012389AB4567CDEF, 1)); + const __m128i vout4x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc4x012389AB4567CDEF), _mm256_extracti128_si256(vacc4x012389AB4567CDEF, 1)); - __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi8(vout0x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi8(vout1x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi8(vout2x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout3x0123456789ABCDEF = _mm_shuffle_epi8(vout3x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout4x0123456789ABCDEF = _mm_shuffle_epi8(vout4x0123456789AB4567CDEF, vshuffle_control_mask); + __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi32(vout0x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi32(vout1x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi32(vout2x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout3x0123456789ABCDEF = _mm_shuffle_epi32(vout3x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout4x0123456789ABCDEF = _mm_shuffle_epi32(vout4x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); vout0x0123456789ABCDEF = _mm_max_epi8(vout0x0123456789ABCDEF, voutput_min); vout1x0123456789ABCDEF = _mm_max_epi8(vout1x0123456789ABCDEF, voutput_min); @@ -244,22 +243,20 @@ void xnn_qs8_gemm_minmax_fp32_ukernel_5x16c8__avx512vnni_prfm( if (nc >= 16) { _mm_storeu_si128((__m128i*) c0, vout0x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c1, vout1x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c2, vout2x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c3, vout3x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c4, vout4x0123456789ABCDEF); - - a0 = (const int8_t*) ((uintptr_t) a0 - kc); - a1 = (const int8_t*) ((uintptr_t) a1 - kc); - a2 = (const int8_t*) ((uintptr_t) a2 - kc); - a3 = (const int8_t*) ((uintptr_t) a3 - kc); - a4 = (const int8_t*) ((uintptr_t) a4 - kc); - c0 = (int8_t*) ((uintptr_t) c0 + cn_stride); + a0 = (const int8_t*) ((uintptr_t) a0 - kc); + _mm_storeu_si128((__m128i*) c1, vout1x0123456789ABCDEF); c1 = (int8_t*) ((uintptr_t) c1 + cn_stride); + a1 = (const int8_t*) ((uintptr_t) a1 - kc); + _mm_storeu_si128((__m128i*) c2, vout2x0123456789ABCDEF); c2 = (int8_t*) ((uintptr_t) c2 + cn_stride); + a2 = (const int8_t*) ((uintptr_t) a2 - kc); + _mm_storeu_si128((__m128i*) c3, vout3x0123456789ABCDEF); c3 = (int8_t*) ((uintptr_t) c3 + cn_stride); + a3 = (const int8_t*) ((uintptr_t) a3 - kc); + _mm_storeu_si128((__m128i*) c4, vout4x0123456789ABCDEF); c4 = (int8_t*) ((uintptr_t) c4 + cn_stride); + a4 = (const int8_t*) ((uintptr_t) a4 - kc); nc -= 16; } else { diff --git a/src/qs8-gemm/gen/qs8-gemm-5x16c8-minmax-fp32-avx512vnni.c b/src/qs8-gemm/gen/qs8-gemm-5x16c8-minmax-fp32-avx512vnni.c index d79a79198c7..dac1fd92b6c 100644 --- a/src/qs8-gemm/gen/qs8-gemm-5x16c8-minmax-fp32-avx512vnni.c +++ b/src/qs8-gemm/gen/qs8-gemm-5x16c8-minmax-fp32-avx512vnni.c @@ -71,7 +71,6 @@ void xnn_qs8_gemm_minmax_fp32_ukernel_5x16c8__avx512vnni( const __m512 voutput_max_less_zero_point = _mm512_load_ps(params->fp32_avx512vnni.output_max_less_zero_point); const __m256i voutput_zero_point = _mm256_load_si256((const __m256i*) params->fp32_avx512vnni.output_zero_point); const __m128i voutput_min = _mm_load_si128((const __m128i*) params->fp32_avx512vnni.output_min); - const __m128i vshuffle_control_mask = _mm_loadu_si128((const __m128i*) params->fp32_avx512vnni.shuffle_control_mask); do { __m512i vacc0x01234567 = _mm512_cvtepu32_epi64(_mm256_load_si256((const __m256i*) w)); __m512i vacc0x89ABCDEF = _mm512_cvtepu32_epi64(_mm256_load_si256((const __m256i*) ((const int32_t*) w + 8))); @@ -205,29 +204,29 @@ void xnn_qs8_gemm_minmax_fp32_ukernel_5x16c8__avx512vnni( vacc3x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled3x0123456789ABCDEF); vacc4x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled4x0123456789ABCDEF); - __m256i vacc0x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); - __m256i vacc1x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); - __m256i vacc2x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc2x0123456789ABCDEF, 1)); - __m256i vacc3x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc3x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc3x0123456789ABCDEF, 1)); - __m256i vacc4x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc4x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc4x0123456789ABCDEF, 1)); + __m256i vacc0x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); + __m256i vacc1x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); + __m256i vacc2x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc2x0123456789ABCDEF, 1)); + __m256i vacc3x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc3x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc3x0123456789ABCDEF, 1)); + __m256i vacc4x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc4x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc4x0123456789ABCDEF, 1)); - vacc0x0123456789AB4567CDEF = _mm256_adds_epi16(vacc0x0123456789AB4567CDEF, voutput_zero_point); - vacc1x0123456789AB4567CDEF = _mm256_adds_epi16(vacc1x0123456789AB4567CDEF, voutput_zero_point); - vacc2x0123456789AB4567CDEF = _mm256_adds_epi16(vacc2x0123456789AB4567CDEF, voutput_zero_point); - vacc3x0123456789AB4567CDEF = _mm256_adds_epi16(vacc3x0123456789AB4567CDEF, voutput_zero_point); - vacc4x0123456789AB4567CDEF = _mm256_adds_epi16(vacc4x0123456789AB4567CDEF, voutput_zero_point); + vacc0x012389AB4567CDEF = _mm256_adds_epi16(vacc0x012389AB4567CDEF, voutput_zero_point); + vacc1x012389AB4567CDEF = _mm256_adds_epi16(vacc1x012389AB4567CDEF, voutput_zero_point); + vacc2x012389AB4567CDEF = _mm256_adds_epi16(vacc2x012389AB4567CDEF, voutput_zero_point); + vacc3x012389AB4567CDEF = _mm256_adds_epi16(vacc3x012389AB4567CDEF, voutput_zero_point); + vacc4x012389AB4567CDEF = _mm256_adds_epi16(vacc4x012389AB4567CDEF, voutput_zero_point); - const __m128i vout0x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc0x0123456789AB4567CDEF, 1)); - const __m128i vout1x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc1x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc1x0123456789AB4567CDEF, 1)); - const __m128i vout2x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc2x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc2x0123456789AB4567CDEF, 1)); - const __m128i vout3x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc3x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc3x0123456789AB4567CDEF, 1)); - const __m128i vout4x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc4x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc4x0123456789AB4567CDEF, 1)); + const __m128i vout0x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x012389AB4567CDEF, 1)); + const __m128i vout1x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc1x012389AB4567CDEF), _mm256_extracti128_si256(vacc1x012389AB4567CDEF, 1)); + const __m128i vout2x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc2x012389AB4567CDEF), _mm256_extracti128_si256(vacc2x012389AB4567CDEF, 1)); + const __m128i vout3x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc3x012389AB4567CDEF), _mm256_extracti128_si256(vacc3x012389AB4567CDEF, 1)); + const __m128i vout4x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc4x012389AB4567CDEF), _mm256_extracti128_si256(vacc4x012389AB4567CDEF, 1)); - __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi8(vout0x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi8(vout1x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi8(vout2x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout3x0123456789ABCDEF = _mm_shuffle_epi8(vout3x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout4x0123456789ABCDEF = _mm_shuffle_epi8(vout4x0123456789AB4567CDEF, vshuffle_control_mask); + __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi32(vout0x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi32(vout1x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi32(vout2x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout3x0123456789ABCDEF = _mm_shuffle_epi32(vout3x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout4x0123456789ABCDEF = _mm_shuffle_epi32(vout4x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); vout0x0123456789ABCDEF = _mm_max_epi8(vout0x0123456789ABCDEF, voutput_min); vout1x0123456789ABCDEF = _mm_max_epi8(vout1x0123456789ABCDEF, voutput_min); @@ -237,22 +236,20 @@ void xnn_qs8_gemm_minmax_fp32_ukernel_5x16c8__avx512vnni( if (nc >= 16) { _mm_storeu_si128((__m128i*) c0, vout0x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c1, vout1x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c2, vout2x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c3, vout3x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c4, vout4x0123456789ABCDEF); - - a0 = (const int8_t*) ((uintptr_t) a0 - kc); - a1 = (const int8_t*) ((uintptr_t) a1 - kc); - a2 = (const int8_t*) ((uintptr_t) a2 - kc); - a3 = (const int8_t*) ((uintptr_t) a3 - kc); - a4 = (const int8_t*) ((uintptr_t) a4 - kc); - c0 = (int8_t*) ((uintptr_t) c0 + cn_stride); + a0 = (const int8_t*) ((uintptr_t) a0 - kc); + _mm_storeu_si128((__m128i*) c1, vout1x0123456789ABCDEF); c1 = (int8_t*) ((uintptr_t) c1 + cn_stride); + a1 = (const int8_t*) ((uintptr_t) a1 - kc); + _mm_storeu_si128((__m128i*) c2, vout2x0123456789ABCDEF); c2 = (int8_t*) ((uintptr_t) c2 + cn_stride); + a2 = (const int8_t*) ((uintptr_t) a2 - kc); + _mm_storeu_si128((__m128i*) c3, vout3x0123456789ABCDEF); c3 = (int8_t*) ((uintptr_t) c3 + cn_stride); + a3 = (const int8_t*) ((uintptr_t) a3 - kc); + _mm_storeu_si128((__m128i*) c4, vout4x0123456789ABCDEF); c4 = (int8_t*) ((uintptr_t) c4 + cn_stride); + a4 = (const int8_t*) ((uintptr_t) a4 - kc); nc -= 16; } else { diff --git a/src/qs8-gemm/gen/qs8-gemm-6x16c4-minmax-fp32-avx512vnni-prfm.c b/src/qs8-gemm/gen/qs8-gemm-6x16c4-minmax-fp32-avx512vnni-prfm.c index ab80ace8e7b..13bfe8a7a83 100644 --- a/src/qs8-gemm/gen/qs8-gemm-6x16c4-minmax-fp32-avx512vnni-prfm.c +++ b/src/qs8-gemm/gen/qs8-gemm-6x16c4-minmax-fp32-avx512vnni-prfm.c @@ -78,7 +78,6 @@ void xnn_qs8_gemm_minmax_fp32_ukernel_6x16c4__avx512vnni_prfm( const __m512 voutput_max_less_zero_point = _mm512_load_ps(params->fp32_avx512vnni.output_max_less_zero_point); const __m256i voutput_zero_point = _mm256_load_si256((const __m256i*) params->fp32_avx512vnni.output_zero_point); const __m128i voutput_min = _mm_load_si128((const __m128i*) params->fp32_avx512vnni.output_min); - const __m128i vshuffle_control_mask = _mm_loadu_si128((const __m128i*) params->fp32_avx512vnni.shuffle_control_mask); do { __m512i vacc0x0x0123456789ABCDEF = _mm512_load_epi32(w); __m512i vacc1x0x0123456789ABCDEF = _mm512_setzero_epi32(); @@ -136,12 +135,12 @@ void xnn_qs8_gemm_minmax_fp32_ukernel_6x16c4__avx512vnni_prfm( w = (const int8_t*) w + 128; k -= 8 * sizeof(int8_t); } - vacc0x0x0123456789ABCDEF = _mm512_add_epi32(vacc0x0x0123456789ABCDEF, vacc1x0x0123456789ABCDEF); - vacc0x1x0123456789ABCDEF = _mm512_add_epi32(vacc0x1x0123456789ABCDEF, vacc1x1x0123456789ABCDEF); - vacc0x2x0123456789ABCDEF = _mm512_add_epi32(vacc0x2x0123456789ABCDEF, vacc1x2x0123456789ABCDEF); - vacc0x3x0123456789ABCDEF = _mm512_add_epi32(vacc0x3x0123456789ABCDEF, vacc1x3x0123456789ABCDEF); - vacc0x4x0123456789ABCDEF = _mm512_add_epi32(vacc0x4x0123456789ABCDEF, vacc1x4x0123456789ABCDEF); - vacc0x5x0123456789ABCDEF = _mm512_add_epi32(vacc0x5x0123456789ABCDEF, vacc1x5x0123456789ABCDEF); + __m512i vacc0x0123456789ABCDEF = _mm512_add_epi32(vacc0x0x0123456789ABCDEF, vacc1x0x0123456789ABCDEF); + __m512i vacc1x0123456789ABCDEF = _mm512_add_epi32(vacc0x1x0123456789ABCDEF, vacc1x1x0123456789ABCDEF); + __m512i vacc2x0123456789ABCDEF = _mm512_add_epi32(vacc0x2x0123456789ABCDEF, vacc1x2x0123456789ABCDEF); + __m512i vacc3x0123456789ABCDEF = _mm512_add_epi32(vacc0x3x0123456789ABCDEF, vacc1x3x0123456789ABCDEF); + __m512i vacc4x0123456789ABCDEF = _mm512_add_epi32(vacc0x4x0123456789ABCDEF, vacc1x4x0123456789ABCDEF); + __m512i vacc5x0123456789ABCDEF = _mm512_add_epi32(vacc0x5x0123456789ABCDEF, vacc1x5x0123456789ABCDEF); if (k != 0) { const __m512i va0x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a0)), vsign_mask); @@ -160,23 +159,23 @@ void xnn_qs8_gemm_minmax_fp32_ukernel_6x16c4__avx512vnni_prfm( const __m512i vb0123456789ABCDEF = _mm512_load_si512(w); xnn_prefetch_to_l1((const int8_t*) w + 960); - vacc0x0x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x0x0123456789ABCDEF, va0x0123, vb0123456789ABCDEF); - vacc0x1x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x1x0123456789ABCDEF, va1x0123, vb0123456789ABCDEF); - vacc0x2x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x2x0123456789ABCDEF, va2x0123, vb0123456789ABCDEF); - vacc0x3x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x3x0123456789ABCDEF, va3x0123, vb0123456789ABCDEF); - vacc0x4x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x4x0123456789ABCDEF, va4x0123, vb0123456789ABCDEF); - vacc0x5x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x5x0123456789ABCDEF, va5x0123, vb0123456789ABCDEF); + vacc0x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x0123456789ABCDEF, va0x0123, vb0123456789ABCDEF); + vacc1x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc1x0123456789ABCDEF, va1x0123, vb0123456789ABCDEF); + vacc2x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc2x0123456789ABCDEF, va2x0123, vb0123456789ABCDEF); + vacc3x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc3x0123456789ABCDEF, va3x0123, vb0123456789ABCDEF); + vacc4x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc4x0123456789ABCDEF, va4x0123, vb0123456789ABCDEF); + vacc5x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc5x0123456789ABCDEF, va5x0123, vb0123456789ABCDEF); w = (const int8_t*) w + 64; k -= 4 * sizeof(int8_t); } - __m512 vscaled0x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x0x0123456789ABCDEF); - __m512 vscaled1x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x1x0123456789ABCDEF); - __m512 vscaled2x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x2x0123456789ABCDEF); - __m512 vscaled3x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x3x0123456789ABCDEF); - __m512 vscaled4x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x4x0123456789ABCDEF); - __m512 vscaled5x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x5x0123456789ABCDEF); + __m512 vscaled0x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x0123456789ABCDEF); + __m512 vscaled1x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc1x0123456789ABCDEF); + __m512 vscaled2x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc2x0123456789ABCDEF); + __m512 vscaled3x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc3x0123456789ABCDEF); + __m512 vscaled4x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc4x0123456789ABCDEF); + __m512 vscaled5x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc5x0123456789ABCDEF); vscaled0x0123456789ABCDEF = _mm512_mul_ps(vscaled0x0123456789ABCDEF, vscale); vscaled1x0123456789ABCDEF = _mm512_mul_ps(vscaled1x0123456789ABCDEF, vscale); @@ -192,40 +191,40 @@ void xnn_qs8_gemm_minmax_fp32_ukernel_6x16c4__avx512vnni_prfm( vscaled4x0123456789ABCDEF = _mm512_min_ps(vscaled4x0123456789ABCDEF, voutput_max_less_zero_point); vscaled5x0123456789ABCDEF = _mm512_min_ps(vscaled5x0123456789ABCDEF, voutput_max_less_zero_point); - vacc0x0x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled0x0123456789ABCDEF); - vacc0x1x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled1x0123456789ABCDEF); - vacc0x2x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled2x0123456789ABCDEF); - vacc0x3x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled3x0123456789ABCDEF); - vacc0x4x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled4x0123456789ABCDEF); - vacc0x5x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled5x0123456789ABCDEF); + vacc0x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled0x0123456789ABCDEF); + vacc1x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled1x0123456789ABCDEF); + vacc2x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled2x0123456789ABCDEF); + vacc3x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled3x0123456789ABCDEF); + vacc4x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled4x0123456789ABCDEF); + vacc5x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled5x0123456789ABCDEF); - __m256i vacc0x0x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0x0123456789ABCDEF, 1)); - __m256i vacc0x1x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x1x0123456789ABCDEF, 1)); - __m256i vacc0x2x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x2x0123456789ABCDEF, 1)); - __m256i vacc0x3x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x3x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x3x0123456789ABCDEF, 1)); - __m256i vacc0x4x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x4x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x4x0123456789ABCDEF, 1)); - __m256i vacc0x5x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x5x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x5x0123456789ABCDEF, 1)); + __m256i vacc0x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); + __m256i vacc1x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); + __m256i vacc2x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc2x0123456789ABCDEF, 1)); + __m256i vacc3x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc3x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc3x0123456789ABCDEF, 1)); + __m256i vacc4x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc4x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc4x0123456789ABCDEF, 1)); + __m256i vacc5x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc5x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc5x0123456789ABCDEF, 1)); - vacc0x0x012389AB4567CDEF = _mm256_adds_epi16(vacc0x0x012389AB4567CDEF, voutput_zero_point); - vacc0x1x012389AB4567CDEF = _mm256_adds_epi16(vacc0x1x012389AB4567CDEF, voutput_zero_point); - vacc0x2x012389AB4567CDEF = _mm256_adds_epi16(vacc0x2x012389AB4567CDEF, voutput_zero_point); - vacc0x3x012389AB4567CDEF = _mm256_adds_epi16(vacc0x3x012389AB4567CDEF, voutput_zero_point); - vacc0x4x012389AB4567CDEF = _mm256_adds_epi16(vacc0x4x012389AB4567CDEF, voutput_zero_point); - vacc0x5x012389AB4567CDEF = _mm256_adds_epi16(vacc0x5x012389AB4567CDEF, voutput_zero_point); + vacc0x012389AB4567CDEF = _mm256_adds_epi16(vacc0x012389AB4567CDEF, voutput_zero_point); + vacc1x012389AB4567CDEF = _mm256_adds_epi16(vacc1x012389AB4567CDEF, voutput_zero_point); + vacc2x012389AB4567CDEF = _mm256_adds_epi16(vacc2x012389AB4567CDEF, voutput_zero_point); + vacc3x012389AB4567CDEF = _mm256_adds_epi16(vacc3x012389AB4567CDEF, voutput_zero_point); + vacc4x012389AB4567CDEF = _mm256_adds_epi16(vacc4x012389AB4567CDEF, voutput_zero_point); + vacc5x012389AB4567CDEF = _mm256_adds_epi16(vacc5x012389AB4567CDEF, voutput_zero_point); - const __m128i vout0x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x0x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x0x012389AB4567CDEF, 1)); - const __m128i vout1x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x1x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x1x012389AB4567CDEF, 1)); - const __m128i vout2x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x2x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x2x012389AB4567CDEF, 1)); - const __m128i vout3x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x3x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x3x012389AB4567CDEF, 1)); - const __m128i vout4x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x4x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x4x012389AB4567CDEF, 1)); - const __m128i vout5x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x5x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x5x012389AB4567CDEF, 1)); + const __m128i vout0x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x012389AB4567CDEF, 1)); + const __m128i vout1x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc1x012389AB4567CDEF), _mm256_extracti128_si256(vacc1x012389AB4567CDEF, 1)); + const __m128i vout2x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc2x012389AB4567CDEF), _mm256_extracti128_si256(vacc2x012389AB4567CDEF, 1)); + const __m128i vout3x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc3x012389AB4567CDEF), _mm256_extracti128_si256(vacc3x012389AB4567CDEF, 1)); + const __m128i vout4x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc4x012389AB4567CDEF), _mm256_extracti128_si256(vacc4x012389AB4567CDEF, 1)); + const __m128i vout5x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc5x012389AB4567CDEF), _mm256_extracti128_si256(vacc5x012389AB4567CDEF, 1)); - __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi8(vout0x012389AB4567CDEF, vshuffle_control_mask); - __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi8(vout1x012389AB4567CDEF, vshuffle_control_mask); - __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi8(vout2x012389AB4567CDEF, vshuffle_control_mask); - __m128i vout3x0123456789ABCDEF = _mm_shuffle_epi8(vout3x012389AB4567CDEF, vshuffle_control_mask); - __m128i vout4x0123456789ABCDEF = _mm_shuffle_epi8(vout4x012389AB4567CDEF, vshuffle_control_mask); - __m128i vout5x0123456789ABCDEF = _mm_shuffle_epi8(vout5x012389AB4567CDEF, vshuffle_control_mask); + __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi32(vout0x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi32(vout1x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi32(vout2x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout3x0123456789ABCDEF = _mm_shuffle_epi32(vout3x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout4x0123456789ABCDEF = _mm_shuffle_epi32(vout4x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout5x0123456789ABCDEF = _mm_shuffle_epi32(vout5x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); vout0x0123456789ABCDEF = _mm_max_epi8(vout0x0123456789ABCDEF, voutput_min); vout1x0123456789ABCDEF = _mm_max_epi8(vout1x0123456789ABCDEF, voutput_min); @@ -236,25 +235,23 @@ void xnn_qs8_gemm_minmax_fp32_ukernel_6x16c4__avx512vnni_prfm( if (nc >= 16) { _mm_storeu_si128((__m128i*) c0, vout0x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c1, vout1x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c2, vout2x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c3, vout3x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c4, vout4x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c5, vout5x0123456789ABCDEF); - - a0 = (const int8_t*) ((uintptr_t) a0 - kc); - a1 = (const int8_t*) ((uintptr_t) a1 - kc); - a2 = (const int8_t*) ((uintptr_t) a2 - kc); - a3 = (const int8_t*) ((uintptr_t) a3 - kc); - a4 = (const int8_t*) ((uintptr_t) a4 - kc); - a5 = (const int8_t*) ((uintptr_t) a5 - kc); - c0 = (int8_t*) ((uintptr_t) c0 + cn_stride); + a0 = (const int8_t*) ((uintptr_t) a0 - kc); + _mm_storeu_si128((__m128i*) c1, vout1x0123456789ABCDEF); c1 = (int8_t*) ((uintptr_t) c1 + cn_stride); + a1 = (const int8_t*) ((uintptr_t) a1 - kc); + _mm_storeu_si128((__m128i*) c2, vout2x0123456789ABCDEF); c2 = (int8_t*) ((uintptr_t) c2 + cn_stride); + a2 = (const int8_t*) ((uintptr_t) a2 - kc); + _mm_storeu_si128((__m128i*) c3, vout3x0123456789ABCDEF); c3 = (int8_t*) ((uintptr_t) c3 + cn_stride); + a3 = (const int8_t*) ((uintptr_t) a3 - kc); + _mm_storeu_si128((__m128i*) c4, vout4x0123456789ABCDEF); c4 = (int8_t*) ((uintptr_t) c4 + cn_stride); + a4 = (const int8_t*) ((uintptr_t) a4 - kc); + _mm_storeu_si128((__m128i*) c5, vout5x0123456789ABCDEF); c5 = (int8_t*) ((uintptr_t) c5 + cn_stride); + a5 = (const int8_t*) ((uintptr_t) a5 - kc); nc -= 16; } else { diff --git a/src/qs8-gemm/gen/qs8-gemm-6x16c4-minmax-fp32-avx512vnni.c b/src/qs8-gemm/gen/qs8-gemm-6x16c4-minmax-fp32-avx512vnni.c index fcca0fc02d6..825d1d31690 100644 --- a/src/qs8-gemm/gen/qs8-gemm-6x16c4-minmax-fp32-avx512vnni.c +++ b/src/qs8-gemm/gen/qs8-gemm-6x16c4-minmax-fp32-avx512vnni.c @@ -77,7 +77,6 @@ void xnn_qs8_gemm_minmax_fp32_ukernel_6x16c4__avx512vnni( const __m512 voutput_max_less_zero_point = _mm512_load_ps(params->fp32_avx512vnni.output_max_less_zero_point); const __m256i voutput_zero_point = _mm256_load_si256((const __m256i*) params->fp32_avx512vnni.output_zero_point); const __m128i voutput_min = _mm_load_si128((const __m128i*) params->fp32_avx512vnni.output_min); - const __m128i vshuffle_control_mask = _mm_loadu_si128((const __m128i*) params->fp32_avx512vnni.shuffle_control_mask); do { __m512i vacc0x0x0123456789ABCDEF = _mm512_load_epi32(w); __m512i vacc1x0x0123456789ABCDEF = _mm512_setzero_epi32(); @@ -133,12 +132,12 @@ void xnn_qs8_gemm_minmax_fp32_ukernel_6x16c4__avx512vnni( w = (const int8_t*) w + 128; k -= 8 * sizeof(int8_t); } - vacc0x0x0123456789ABCDEF = _mm512_add_epi32(vacc0x0x0123456789ABCDEF, vacc1x0x0123456789ABCDEF); - vacc0x1x0123456789ABCDEF = _mm512_add_epi32(vacc0x1x0123456789ABCDEF, vacc1x1x0123456789ABCDEF); - vacc0x2x0123456789ABCDEF = _mm512_add_epi32(vacc0x2x0123456789ABCDEF, vacc1x2x0123456789ABCDEF); - vacc0x3x0123456789ABCDEF = _mm512_add_epi32(vacc0x3x0123456789ABCDEF, vacc1x3x0123456789ABCDEF); - vacc0x4x0123456789ABCDEF = _mm512_add_epi32(vacc0x4x0123456789ABCDEF, vacc1x4x0123456789ABCDEF); - vacc0x5x0123456789ABCDEF = _mm512_add_epi32(vacc0x5x0123456789ABCDEF, vacc1x5x0123456789ABCDEF); + __m512i vacc0x0123456789ABCDEF = _mm512_add_epi32(vacc0x0x0123456789ABCDEF, vacc1x0x0123456789ABCDEF); + __m512i vacc1x0123456789ABCDEF = _mm512_add_epi32(vacc0x1x0123456789ABCDEF, vacc1x1x0123456789ABCDEF); + __m512i vacc2x0123456789ABCDEF = _mm512_add_epi32(vacc0x2x0123456789ABCDEF, vacc1x2x0123456789ABCDEF); + __m512i vacc3x0123456789ABCDEF = _mm512_add_epi32(vacc0x3x0123456789ABCDEF, vacc1x3x0123456789ABCDEF); + __m512i vacc4x0123456789ABCDEF = _mm512_add_epi32(vacc0x4x0123456789ABCDEF, vacc1x4x0123456789ABCDEF); + __m512i vacc5x0123456789ABCDEF = _mm512_add_epi32(vacc0x5x0123456789ABCDEF, vacc1x5x0123456789ABCDEF); if (k != 0) { const __m512i va0x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a0)), vsign_mask); @@ -156,23 +155,23 @@ void xnn_qs8_gemm_minmax_fp32_ukernel_6x16c4__avx512vnni( const __m512i vb0123456789ABCDEF = _mm512_load_si512(w); - vacc0x0x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x0x0123456789ABCDEF, va0x0123, vb0123456789ABCDEF); - vacc0x1x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x1x0123456789ABCDEF, va1x0123, vb0123456789ABCDEF); - vacc0x2x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x2x0123456789ABCDEF, va2x0123, vb0123456789ABCDEF); - vacc0x3x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x3x0123456789ABCDEF, va3x0123, vb0123456789ABCDEF); - vacc0x4x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x4x0123456789ABCDEF, va4x0123, vb0123456789ABCDEF); - vacc0x5x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x5x0123456789ABCDEF, va5x0123, vb0123456789ABCDEF); + vacc0x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x0123456789ABCDEF, va0x0123, vb0123456789ABCDEF); + vacc1x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc1x0123456789ABCDEF, va1x0123, vb0123456789ABCDEF); + vacc2x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc2x0123456789ABCDEF, va2x0123, vb0123456789ABCDEF); + vacc3x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc3x0123456789ABCDEF, va3x0123, vb0123456789ABCDEF); + vacc4x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc4x0123456789ABCDEF, va4x0123, vb0123456789ABCDEF); + vacc5x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc5x0123456789ABCDEF, va5x0123, vb0123456789ABCDEF); w = (const int8_t*) w + 64; k -= 4 * sizeof(int8_t); } - __m512 vscaled0x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x0x0123456789ABCDEF); - __m512 vscaled1x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x1x0123456789ABCDEF); - __m512 vscaled2x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x2x0123456789ABCDEF); - __m512 vscaled3x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x3x0123456789ABCDEF); - __m512 vscaled4x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x4x0123456789ABCDEF); - __m512 vscaled5x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x5x0123456789ABCDEF); + __m512 vscaled0x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x0123456789ABCDEF); + __m512 vscaled1x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc1x0123456789ABCDEF); + __m512 vscaled2x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc2x0123456789ABCDEF); + __m512 vscaled3x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc3x0123456789ABCDEF); + __m512 vscaled4x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc4x0123456789ABCDEF); + __m512 vscaled5x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc5x0123456789ABCDEF); vscaled0x0123456789ABCDEF = _mm512_mul_ps(vscaled0x0123456789ABCDEF, vscale); vscaled1x0123456789ABCDEF = _mm512_mul_ps(vscaled1x0123456789ABCDEF, vscale); @@ -188,40 +187,40 @@ void xnn_qs8_gemm_minmax_fp32_ukernel_6x16c4__avx512vnni( vscaled4x0123456789ABCDEF = _mm512_min_ps(vscaled4x0123456789ABCDEF, voutput_max_less_zero_point); vscaled5x0123456789ABCDEF = _mm512_min_ps(vscaled5x0123456789ABCDEF, voutput_max_less_zero_point); - vacc0x0x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled0x0123456789ABCDEF); - vacc0x1x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled1x0123456789ABCDEF); - vacc0x2x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled2x0123456789ABCDEF); - vacc0x3x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled3x0123456789ABCDEF); - vacc0x4x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled4x0123456789ABCDEF); - vacc0x5x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled5x0123456789ABCDEF); + vacc0x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled0x0123456789ABCDEF); + vacc1x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled1x0123456789ABCDEF); + vacc2x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled2x0123456789ABCDEF); + vacc3x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled3x0123456789ABCDEF); + vacc4x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled4x0123456789ABCDEF); + vacc5x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled5x0123456789ABCDEF); - __m256i vacc0x0x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0x0123456789ABCDEF, 1)); - __m256i vacc0x1x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x1x0123456789ABCDEF, 1)); - __m256i vacc0x2x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x2x0123456789ABCDEF, 1)); - __m256i vacc0x3x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x3x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x3x0123456789ABCDEF, 1)); - __m256i vacc0x4x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x4x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x4x0123456789ABCDEF, 1)); - __m256i vacc0x5x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x5x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x5x0123456789ABCDEF, 1)); + __m256i vacc0x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); + __m256i vacc1x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); + __m256i vacc2x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc2x0123456789ABCDEF, 1)); + __m256i vacc3x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc3x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc3x0123456789ABCDEF, 1)); + __m256i vacc4x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc4x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc4x0123456789ABCDEF, 1)); + __m256i vacc5x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc5x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc5x0123456789ABCDEF, 1)); - vacc0x0x012389AB4567CDEF = _mm256_adds_epi16(vacc0x0x012389AB4567CDEF, voutput_zero_point); - vacc0x1x012389AB4567CDEF = _mm256_adds_epi16(vacc0x1x012389AB4567CDEF, voutput_zero_point); - vacc0x2x012389AB4567CDEF = _mm256_adds_epi16(vacc0x2x012389AB4567CDEF, voutput_zero_point); - vacc0x3x012389AB4567CDEF = _mm256_adds_epi16(vacc0x3x012389AB4567CDEF, voutput_zero_point); - vacc0x4x012389AB4567CDEF = _mm256_adds_epi16(vacc0x4x012389AB4567CDEF, voutput_zero_point); - vacc0x5x012389AB4567CDEF = _mm256_adds_epi16(vacc0x5x012389AB4567CDEF, voutput_zero_point); + vacc0x012389AB4567CDEF = _mm256_adds_epi16(vacc0x012389AB4567CDEF, voutput_zero_point); + vacc1x012389AB4567CDEF = _mm256_adds_epi16(vacc1x012389AB4567CDEF, voutput_zero_point); + vacc2x012389AB4567CDEF = _mm256_adds_epi16(vacc2x012389AB4567CDEF, voutput_zero_point); + vacc3x012389AB4567CDEF = _mm256_adds_epi16(vacc3x012389AB4567CDEF, voutput_zero_point); + vacc4x012389AB4567CDEF = _mm256_adds_epi16(vacc4x012389AB4567CDEF, voutput_zero_point); + vacc5x012389AB4567CDEF = _mm256_adds_epi16(vacc5x012389AB4567CDEF, voutput_zero_point); - const __m128i vout0x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x0x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x0x012389AB4567CDEF, 1)); - const __m128i vout1x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x1x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x1x012389AB4567CDEF, 1)); - const __m128i vout2x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x2x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x2x012389AB4567CDEF, 1)); - const __m128i vout3x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x3x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x3x012389AB4567CDEF, 1)); - const __m128i vout4x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x4x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x4x012389AB4567CDEF, 1)); - const __m128i vout5x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x5x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x5x012389AB4567CDEF, 1)); + const __m128i vout0x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x012389AB4567CDEF, 1)); + const __m128i vout1x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc1x012389AB4567CDEF), _mm256_extracti128_si256(vacc1x012389AB4567CDEF, 1)); + const __m128i vout2x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc2x012389AB4567CDEF), _mm256_extracti128_si256(vacc2x012389AB4567CDEF, 1)); + const __m128i vout3x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc3x012389AB4567CDEF), _mm256_extracti128_si256(vacc3x012389AB4567CDEF, 1)); + const __m128i vout4x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc4x012389AB4567CDEF), _mm256_extracti128_si256(vacc4x012389AB4567CDEF, 1)); + const __m128i vout5x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc5x012389AB4567CDEF), _mm256_extracti128_si256(vacc5x012389AB4567CDEF, 1)); - __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi8(vout0x012389AB4567CDEF, vshuffle_control_mask); - __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi8(vout1x012389AB4567CDEF, vshuffle_control_mask); - __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi8(vout2x012389AB4567CDEF, vshuffle_control_mask); - __m128i vout3x0123456789ABCDEF = _mm_shuffle_epi8(vout3x012389AB4567CDEF, vshuffle_control_mask); - __m128i vout4x0123456789ABCDEF = _mm_shuffle_epi8(vout4x012389AB4567CDEF, vshuffle_control_mask); - __m128i vout5x0123456789ABCDEF = _mm_shuffle_epi8(vout5x012389AB4567CDEF, vshuffle_control_mask); + __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi32(vout0x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi32(vout1x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi32(vout2x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout3x0123456789ABCDEF = _mm_shuffle_epi32(vout3x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout4x0123456789ABCDEF = _mm_shuffle_epi32(vout4x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout5x0123456789ABCDEF = _mm_shuffle_epi32(vout5x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); vout0x0123456789ABCDEF = _mm_max_epi8(vout0x0123456789ABCDEF, voutput_min); vout1x0123456789ABCDEF = _mm_max_epi8(vout1x0123456789ABCDEF, voutput_min); @@ -232,25 +231,23 @@ void xnn_qs8_gemm_minmax_fp32_ukernel_6x16c4__avx512vnni( if (nc >= 16) { _mm_storeu_si128((__m128i*) c0, vout0x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c1, vout1x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c2, vout2x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c3, vout3x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c4, vout4x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c5, vout5x0123456789ABCDEF); - - a0 = (const int8_t*) ((uintptr_t) a0 - kc); - a1 = (const int8_t*) ((uintptr_t) a1 - kc); - a2 = (const int8_t*) ((uintptr_t) a2 - kc); - a3 = (const int8_t*) ((uintptr_t) a3 - kc); - a4 = (const int8_t*) ((uintptr_t) a4 - kc); - a5 = (const int8_t*) ((uintptr_t) a5 - kc); - c0 = (int8_t*) ((uintptr_t) c0 + cn_stride); + a0 = (const int8_t*) ((uintptr_t) a0 - kc); + _mm_storeu_si128((__m128i*) c1, vout1x0123456789ABCDEF); c1 = (int8_t*) ((uintptr_t) c1 + cn_stride); + a1 = (const int8_t*) ((uintptr_t) a1 - kc); + _mm_storeu_si128((__m128i*) c2, vout2x0123456789ABCDEF); c2 = (int8_t*) ((uintptr_t) c2 + cn_stride); + a2 = (const int8_t*) ((uintptr_t) a2 - kc); + _mm_storeu_si128((__m128i*) c3, vout3x0123456789ABCDEF); c3 = (int8_t*) ((uintptr_t) c3 + cn_stride); + a3 = (const int8_t*) ((uintptr_t) a3 - kc); + _mm_storeu_si128((__m128i*) c4, vout4x0123456789ABCDEF); c4 = (int8_t*) ((uintptr_t) c4 + cn_stride); + a4 = (const int8_t*) ((uintptr_t) a4 - kc); + _mm_storeu_si128((__m128i*) c5, vout5x0123456789ABCDEF); c5 = (int8_t*) ((uintptr_t) c5 + cn_stride); + a5 = (const int8_t*) ((uintptr_t) a5 - kc); nc -= 16; } else { diff --git a/src/qs8-gemm/gen/qs8-gemm-6x16c8-minmax-fp32-avx512skx-prfm.c b/src/qs8-gemm/gen/qs8-gemm-6x16c8-minmax-fp32-avx512skx-prfm.c index 86eef67d0aa..0149a8eaa58 100644 --- a/src/qs8-gemm/gen/qs8-gemm-6x16c8-minmax-fp32-avx512skx-prfm.c +++ b/src/qs8-gemm/gen/qs8-gemm-6x16c8-minmax-fp32-avx512skx-prfm.c @@ -76,7 +76,6 @@ void xnn_qs8_gemm_minmax_fp32_ukernel_6x16c8__avx512skx_prfm( const __m512 voutput_max_less_zero_point = _mm512_load_ps(params->fp32_avx512.output_max_less_zero_point); const __m256i voutput_zero_point = _mm256_load_si256((const __m256i*) params->fp32_avx512.output_zero_point); const __m128i voutput_min = _mm_load_si128((const __m128i*) params->fp32_avx512.output_min); - const __m128i vshuffle_control_mask = _mm_set_epi8(15, 14, 13, 12, 7, 6, 5, 4, 11, 10, 9, 8, 3, 2, 1, 0); do { __m512i vacc0x0123 = _mm512_maskz_expandloadu_epi32(vbias_mask, w); @@ -217,33 +216,33 @@ void xnn_qs8_gemm_minmax_fp32_ukernel_6x16c8__avx512skx_prfm( vacc4x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled4x0123456789ABCDEF); vacc5x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled5x0123456789ABCDEF); - __m256i vacc0x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); - __m256i vacc1x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); - __m256i vacc2x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc2x0123456789ABCDEF, 1)); - __m256i vacc3x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc3x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc3x0123456789ABCDEF, 1)); - __m256i vacc4x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc4x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc4x0123456789ABCDEF, 1)); - __m256i vacc5x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc5x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc5x0123456789ABCDEF, 1)); - - vacc0x0123456789AB4567CDEF = _mm256_adds_epi16(vacc0x0123456789AB4567CDEF, voutput_zero_point); - vacc1x0123456789AB4567CDEF = _mm256_adds_epi16(vacc1x0123456789AB4567CDEF, voutput_zero_point); - vacc2x0123456789AB4567CDEF = _mm256_adds_epi16(vacc2x0123456789AB4567CDEF, voutput_zero_point); - vacc3x0123456789AB4567CDEF = _mm256_adds_epi16(vacc3x0123456789AB4567CDEF, voutput_zero_point); - vacc4x0123456789AB4567CDEF = _mm256_adds_epi16(vacc4x0123456789AB4567CDEF, voutput_zero_point); - vacc5x0123456789AB4567CDEF = _mm256_adds_epi16(vacc5x0123456789AB4567CDEF, voutput_zero_point); - - const __m128i vout0x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc0x0123456789AB4567CDEF, 1)); - const __m128i vout1x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc1x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc1x0123456789AB4567CDEF, 1)); - const __m128i vout2x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc2x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc2x0123456789AB4567CDEF, 1)); - const __m128i vout3x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc3x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc3x0123456789AB4567CDEF, 1)); - const __m128i vout4x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc4x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc4x0123456789AB4567CDEF, 1)); - const __m128i vout5x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc5x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc5x0123456789AB4567CDEF, 1)); - - __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi8(vout0x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi8(vout1x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi8(vout2x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout3x0123456789ABCDEF = _mm_shuffle_epi8(vout3x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout4x0123456789ABCDEF = _mm_shuffle_epi8(vout4x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout5x0123456789ABCDEF = _mm_shuffle_epi8(vout5x0123456789AB4567CDEF, vshuffle_control_mask); + __m256i vacc0x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); + __m256i vacc1x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); + __m256i vacc2x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc2x0123456789ABCDEF, 1)); + __m256i vacc3x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc3x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc3x0123456789ABCDEF, 1)); + __m256i vacc4x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc4x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc4x0123456789ABCDEF, 1)); + __m256i vacc5x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc5x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc5x0123456789ABCDEF, 1)); + + vacc0x012389AB4567CDEF = _mm256_adds_epi16(vacc0x012389AB4567CDEF, voutput_zero_point); + vacc1x012389AB4567CDEF = _mm256_adds_epi16(vacc1x012389AB4567CDEF, voutput_zero_point); + vacc2x012389AB4567CDEF = _mm256_adds_epi16(vacc2x012389AB4567CDEF, voutput_zero_point); + vacc3x012389AB4567CDEF = _mm256_adds_epi16(vacc3x012389AB4567CDEF, voutput_zero_point); + vacc4x012389AB4567CDEF = _mm256_adds_epi16(vacc4x012389AB4567CDEF, voutput_zero_point); + vacc5x012389AB4567CDEF = _mm256_adds_epi16(vacc5x012389AB4567CDEF, voutput_zero_point); + + const __m128i vout0x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x012389AB4567CDEF, 1)); + const __m128i vout1x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc1x012389AB4567CDEF), _mm256_extracti128_si256(vacc1x012389AB4567CDEF, 1)); + const __m128i vout2x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc2x012389AB4567CDEF), _mm256_extracti128_si256(vacc2x012389AB4567CDEF, 1)); + const __m128i vout3x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc3x012389AB4567CDEF), _mm256_extracti128_si256(vacc3x012389AB4567CDEF, 1)); + const __m128i vout4x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc4x012389AB4567CDEF), _mm256_extracti128_si256(vacc4x012389AB4567CDEF, 1)); + const __m128i vout5x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc5x012389AB4567CDEF), _mm256_extracti128_si256(vacc5x012389AB4567CDEF, 1)); + + __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi32(vout0x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi32(vout1x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi32(vout2x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout3x0123456789ABCDEF = _mm_shuffle_epi32(vout3x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout4x0123456789ABCDEF = _mm_shuffle_epi32(vout4x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout5x0123456789ABCDEF = _mm_shuffle_epi32(vout5x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); vout0x0123456789ABCDEF = _mm_max_epi8(vout0x0123456789ABCDEF, voutput_min); vout1x0123456789ABCDEF = _mm_max_epi8(vout1x0123456789ABCDEF, voutput_min); @@ -254,25 +253,23 @@ void xnn_qs8_gemm_minmax_fp32_ukernel_6x16c8__avx512skx_prfm( if (nc >= 16) { _mm_storeu_si128((__m128i*) c0, vout0x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c1, vout1x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c2, vout2x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c3, vout3x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c4, vout4x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c5, vout5x0123456789ABCDEF); - - a0 = (const int8_t*) ((uintptr_t) a0 - kc); - a1 = (const int8_t*) ((uintptr_t) a1 - kc); - a2 = (const int8_t*) ((uintptr_t) a2 - kc); - a3 = (const int8_t*) ((uintptr_t) a3 - kc); - a4 = (const int8_t*) ((uintptr_t) a4 - kc); - a5 = (const int8_t*) ((uintptr_t) a5 - kc); - c0 = (int8_t*) ((uintptr_t) c0 + cn_stride); + a0 = (const int8_t*) ((uintptr_t) a0 - kc); + _mm_storeu_si128((__m128i*) c1, vout1x0123456789ABCDEF); c1 = (int8_t*) ((uintptr_t) c1 + cn_stride); + a1 = (const int8_t*) ((uintptr_t) a1 - kc); + _mm_storeu_si128((__m128i*) c2, vout2x0123456789ABCDEF); c2 = (int8_t*) ((uintptr_t) c2 + cn_stride); + a2 = (const int8_t*) ((uintptr_t) a2 - kc); + _mm_storeu_si128((__m128i*) c3, vout3x0123456789ABCDEF); c3 = (int8_t*) ((uintptr_t) c3 + cn_stride); + a3 = (const int8_t*) ((uintptr_t) a3 - kc); + _mm_storeu_si128((__m128i*) c4, vout4x0123456789ABCDEF); c4 = (int8_t*) ((uintptr_t) c4 + cn_stride); + a4 = (const int8_t*) ((uintptr_t) a4 - kc); + _mm_storeu_si128((__m128i*) c5, vout5x0123456789ABCDEF); c5 = (int8_t*) ((uintptr_t) c5 + cn_stride); + a5 = (const int8_t*) ((uintptr_t) a5 - kc); nc -= 16; } else { diff --git a/src/qs8-gemm/gen/qs8-gemm-6x16c8-minmax-fp32-avx512skx.c b/src/qs8-gemm/gen/qs8-gemm-6x16c8-minmax-fp32-avx512skx.c index 6dfc7313a96..6cec5bbb974 100644 --- a/src/qs8-gemm/gen/qs8-gemm-6x16c8-minmax-fp32-avx512skx.c +++ b/src/qs8-gemm/gen/qs8-gemm-6x16c8-minmax-fp32-avx512skx.c @@ -75,7 +75,6 @@ void xnn_qs8_gemm_minmax_fp32_ukernel_6x16c8__avx512skx( const __m512 voutput_max_less_zero_point = _mm512_load_ps(params->fp32_avx512.output_max_less_zero_point); const __m256i voutput_zero_point = _mm256_load_si256((const __m256i*) params->fp32_avx512.output_zero_point); const __m128i voutput_min = _mm_load_si128((const __m128i*) params->fp32_avx512.output_min); - const __m128i vshuffle_control_mask = _mm_set_epi8(15, 14, 13, 12, 7, 6, 5, 4, 11, 10, 9, 8, 3, 2, 1, 0); do { __m512i vacc0x0123 = _mm512_maskz_expandloadu_epi32(vbias_mask, w); @@ -214,33 +213,33 @@ void xnn_qs8_gemm_minmax_fp32_ukernel_6x16c8__avx512skx( vacc4x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled4x0123456789ABCDEF); vacc5x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled5x0123456789ABCDEF); - __m256i vacc0x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); - __m256i vacc1x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); - __m256i vacc2x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc2x0123456789ABCDEF, 1)); - __m256i vacc3x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc3x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc3x0123456789ABCDEF, 1)); - __m256i vacc4x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc4x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc4x0123456789ABCDEF, 1)); - __m256i vacc5x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc5x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc5x0123456789ABCDEF, 1)); - - vacc0x0123456789AB4567CDEF = _mm256_adds_epi16(vacc0x0123456789AB4567CDEF, voutput_zero_point); - vacc1x0123456789AB4567CDEF = _mm256_adds_epi16(vacc1x0123456789AB4567CDEF, voutput_zero_point); - vacc2x0123456789AB4567CDEF = _mm256_adds_epi16(vacc2x0123456789AB4567CDEF, voutput_zero_point); - vacc3x0123456789AB4567CDEF = _mm256_adds_epi16(vacc3x0123456789AB4567CDEF, voutput_zero_point); - vacc4x0123456789AB4567CDEF = _mm256_adds_epi16(vacc4x0123456789AB4567CDEF, voutput_zero_point); - vacc5x0123456789AB4567CDEF = _mm256_adds_epi16(vacc5x0123456789AB4567CDEF, voutput_zero_point); - - const __m128i vout0x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc0x0123456789AB4567CDEF, 1)); - const __m128i vout1x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc1x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc1x0123456789AB4567CDEF, 1)); - const __m128i vout2x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc2x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc2x0123456789AB4567CDEF, 1)); - const __m128i vout3x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc3x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc3x0123456789AB4567CDEF, 1)); - const __m128i vout4x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc4x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc4x0123456789AB4567CDEF, 1)); - const __m128i vout5x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc5x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc5x0123456789AB4567CDEF, 1)); - - __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi8(vout0x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi8(vout1x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi8(vout2x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout3x0123456789ABCDEF = _mm_shuffle_epi8(vout3x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout4x0123456789ABCDEF = _mm_shuffle_epi8(vout4x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout5x0123456789ABCDEF = _mm_shuffle_epi8(vout5x0123456789AB4567CDEF, vshuffle_control_mask); + __m256i vacc0x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); + __m256i vacc1x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); + __m256i vacc2x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc2x0123456789ABCDEF, 1)); + __m256i vacc3x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc3x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc3x0123456789ABCDEF, 1)); + __m256i vacc4x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc4x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc4x0123456789ABCDEF, 1)); + __m256i vacc5x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc5x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc5x0123456789ABCDEF, 1)); + + vacc0x012389AB4567CDEF = _mm256_adds_epi16(vacc0x012389AB4567CDEF, voutput_zero_point); + vacc1x012389AB4567CDEF = _mm256_adds_epi16(vacc1x012389AB4567CDEF, voutput_zero_point); + vacc2x012389AB4567CDEF = _mm256_adds_epi16(vacc2x012389AB4567CDEF, voutput_zero_point); + vacc3x012389AB4567CDEF = _mm256_adds_epi16(vacc3x012389AB4567CDEF, voutput_zero_point); + vacc4x012389AB4567CDEF = _mm256_adds_epi16(vacc4x012389AB4567CDEF, voutput_zero_point); + vacc5x012389AB4567CDEF = _mm256_adds_epi16(vacc5x012389AB4567CDEF, voutput_zero_point); + + const __m128i vout0x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x012389AB4567CDEF, 1)); + const __m128i vout1x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc1x012389AB4567CDEF), _mm256_extracti128_si256(vacc1x012389AB4567CDEF, 1)); + const __m128i vout2x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc2x012389AB4567CDEF), _mm256_extracti128_si256(vacc2x012389AB4567CDEF, 1)); + const __m128i vout3x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc3x012389AB4567CDEF), _mm256_extracti128_si256(vacc3x012389AB4567CDEF, 1)); + const __m128i vout4x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc4x012389AB4567CDEF), _mm256_extracti128_si256(vacc4x012389AB4567CDEF, 1)); + const __m128i vout5x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc5x012389AB4567CDEF), _mm256_extracti128_si256(vacc5x012389AB4567CDEF, 1)); + + __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi32(vout0x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi32(vout1x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi32(vout2x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout3x0123456789ABCDEF = _mm_shuffle_epi32(vout3x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout4x0123456789ABCDEF = _mm_shuffle_epi32(vout4x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout5x0123456789ABCDEF = _mm_shuffle_epi32(vout5x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); vout0x0123456789ABCDEF = _mm_max_epi8(vout0x0123456789ABCDEF, voutput_min); vout1x0123456789ABCDEF = _mm_max_epi8(vout1x0123456789ABCDEF, voutput_min); @@ -251,25 +250,23 @@ void xnn_qs8_gemm_minmax_fp32_ukernel_6x16c8__avx512skx( if (nc >= 16) { _mm_storeu_si128((__m128i*) c0, vout0x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c1, vout1x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c2, vout2x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c3, vout3x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c4, vout4x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c5, vout5x0123456789ABCDEF); - - a0 = (const int8_t*) ((uintptr_t) a0 - kc); - a1 = (const int8_t*) ((uintptr_t) a1 - kc); - a2 = (const int8_t*) ((uintptr_t) a2 - kc); - a3 = (const int8_t*) ((uintptr_t) a3 - kc); - a4 = (const int8_t*) ((uintptr_t) a4 - kc); - a5 = (const int8_t*) ((uintptr_t) a5 - kc); - c0 = (int8_t*) ((uintptr_t) c0 + cn_stride); + a0 = (const int8_t*) ((uintptr_t) a0 - kc); + _mm_storeu_si128((__m128i*) c1, vout1x0123456789ABCDEF); c1 = (int8_t*) ((uintptr_t) c1 + cn_stride); + a1 = (const int8_t*) ((uintptr_t) a1 - kc); + _mm_storeu_si128((__m128i*) c2, vout2x0123456789ABCDEF); c2 = (int8_t*) ((uintptr_t) c2 + cn_stride); + a2 = (const int8_t*) ((uintptr_t) a2 - kc); + _mm_storeu_si128((__m128i*) c3, vout3x0123456789ABCDEF); c3 = (int8_t*) ((uintptr_t) c3 + cn_stride); + a3 = (const int8_t*) ((uintptr_t) a3 - kc); + _mm_storeu_si128((__m128i*) c4, vout4x0123456789ABCDEF); c4 = (int8_t*) ((uintptr_t) c4 + cn_stride); + a4 = (const int8_t*) ((uintptr_t) a4 - kc); + _mm_storeu_si128((__m128i*) c5, vout5x0123456789ABCDEF); c5 = (int8_t*) ((uintptr_t) c5 + cn_stride); + a5 = (const int8_t*) ((uintptr_t) a5 - kc); nc -= 16; } else { diff --git a/src/qs8-gemm/gen/qs8-gemm-6x16c8-minmax-fp32-avx512vnni-prfm.c b/src/qs8-gemm/gen/qs8-gemm-6x16c8-minmax-fp32-avx512vnni-prfm.c index 75afa259288..13ec8576ca1 100644 --- a/src/qs8-gemm/gen/qs8-gemm-6x16c8-minmax-fp32-avx512vnni-prfm.c +++ b/src/qs8-gemm/gen/qs8-gemm-6x16c8-minmax-fp32-avx512vnni-prfm.c @@ -78,7 +78,6 @@ void xnn_qs8_gemm_minmax_fp32_ukernel_6x16c8__avx512vnni_prfm( const __m512 voutput_max_less_zero_point = _mm512_load_ps(params->fp32_avx512vnni.output_max_less_zero_point); const __m256i voutput_zero_point = _mm256_load_si256((const __m256i*) params->fp32_avx512vnni.output_zero_point); const __m128i voutput_min = _mm_load_si128((const __m128i*) params->fp32_avx512vnni.output_min); - const __m128i vshuffle_control_mask = _mm_loadu_si128((const __m128i*) params->fp32_avx512vnni.shuffle_control_mask); do { __m512i vacc0x01234567 = _mm512_cvtepu32_epi64(_mm256_load_si256((const __m256i*) w)); __m512i vacc0x89ABCDEF = _mm512_cvtepu32_epi64(_mm256_load_si256((const __m256i*) ((const int32_t*) w + 8))); @@ -238,33 +237,33 @@ void xnn_qs8_gemm_minmax_fp32_ukernel_6x16c8__avx512vnni_prfm( vacc4x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled4x0123456789ABCDEF); vacc5x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled5x0123456789ABCDEF); - __m256i vacc0x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); - __m256i vacc1x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); - __m256i vacc2x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc2x0123456789ABCDEF, 1)); - __m256i vacc3x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc3x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc3x0123456789ABCDEF, 1)); - __m256i vacc4x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc4x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc4x0123456789ABCDEF, 1)); - __m256i vacc5x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc5x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc5x0123456789ABCDEF, 1)); + __m256i vacc0x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); + __m256i vacc1x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); + __m256i vacc2x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc2x0123456789ABCDEF, 1)); + __m256i vacc3x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc3x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc3x0123456789ABCDEF, 1)); + __m256i vacc4x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc4x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc4x0123456789ABCDEF, 1)); + __m256i vacc5x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc5x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc5x0123456789ABCDEF, 1)); - vacc0x0123456789AB4567CDEF = _mm256_adds_epi16(vacc0x0123456789AB4567CDEF, voutput_zero_point); - vacc1x0123456789AB4567CDEF = _mm256_adds_epi16(vacc1x0123456789AB4567CDEF, voutput_zero_point); - vacc2x0123456789AB4567CDEF = _mm256_adds_epi16(vacc2x0123456789AB4567CDEF, voutput_zero_point); - vacc3x0123456789AB4567CDEF = _mm256_adds_epi16(vacc3x0123456789AB4567CDEF, voutput_zero_point); - vacc4x0123456789AB4567CDEF = _mm256_adds_epi16(vacc4x0123456789AB4567CDEF, voutput_zero_point); - vacc5x0123456789AB4567CDEF = _mm256_adds_epi16(vacc5x0123456789AB4567CDEF, voutput_zero_point); + vacc0x012389AB4567CDEF = _mm256_adds_epi16(vacc0x012389AB4567CDEF, voutput_zero_point); + vacc1x012389AB4567CDEF = _mm256_adds_epi16(vacc1x012389AB4567CDEF, voutput_zero_point); + vacc2x012389AB4567CDEF = _mm256_adds_epi16(vacc2x012389AB4567CDEF, voutput_zero_point); + vacc3x012389AB4567CDEF = _mm256_adds_epi16(vacc3x012389AB4567CDEF, voutput_zero_point); + vacc4x012389AB4567CDEF = _mm256_adds_epi16(vacc4x012389AB4567CDEF, voutput_zero_point); + vacc5x012389AB4567CDEF = _mm256_adds_epi16(vacc5x012389AB4567CDEF, voutput_zero_point); - const __m128i vout0x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc0x0123456789AB4567CDEF, 1)); - const __m128i vout1x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc1x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc1x0123456789AB4567CDEF, 1)); - const __m128i vout2x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc2x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc2x0123456789AB4567CDEF, 1)); - const __m128i vout3x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc3x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc3x0123456789AB4567CDEF, 1)); - const __m128i vout4x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc4x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc4x0123456789AB4567CDEF, 1)); - const __m128i vout5x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc5x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc5x0123456789AB4567CDEF, 1)); + const __m128i vout0x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x012389AB4567CDEF, 1)); + const __m128i vout1x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc1x012389AB4567CDEF), _mm256_extracti128_si256(vacc1x012389AB4567CDEF, 1)); + const __m128i vout2x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc2x012389AB4567CDEF), _mm256_extracti128_si256(vacc2x012389AB4567CDEF, 1)); + const __m128i vout3x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc3x012389AB4567CDEF), _mm256_extracti128_si256(vacc3x012389AB4567CDEF, 1)); + const __m128i vout4x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc4x012389AB4567CDEF), _mm256_extracti128_si256(vacc4x012389AB4567CDEF, 1)); + const __m128i vout5x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc5x012389AB4567CDEF), _mm256_extracti128_si256(vacc5x012389AB4567CDEF, 1)); - __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi8(vout0x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi8(vout1x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi8(vout2x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout3x0123456789ABCDEF = _mm_shuffle_epi8(vout3x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout4x0123456789ABCDEF = _mm_shuffle_epi8(vout4x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout5x0123456789ABCDEF = _mm_shuffle_epi8(vout5x0123456789AB4567CDEF, vshuffle_control_mask); + __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi32(vout0x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi32(vout1x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi32(vout2x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout3x0123456789ABCDEF = _mm_shuffle_epi32(vout3x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout4x0123456789ABCDEF = _mm_shuffle_epi32(vout4x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout5x0123456789ABCDEF = _mm_shuffle_epi32(vout5x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); vout0x0123456789ABCDEF = _mm_max_epi8(vout0x0123456789ABCDEF, voutput_min); vout1x0123456789ABCDEF = _mm_max_epi8(vout1x0123456789ABCDEF, voutput_min); @@ -275,25 +274,23 @@ void xnn_qs8_gemm_minmax_fp32_ukernel_6x16c8__avx512vnni_prfm( if (nc >= 16) { _mm_storeu_si128((__m128i*) c0, vout0x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c1, vout1x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c2, vout2x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c3, vout3x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c4, vout4x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c5, vout5x0123456789ABCDEF); - - a0 = (const int8_t*) ((uintptr_t) a0 - kc); - a1 = (const int8_t*) ((uintptr_t) a1 - kc); - a2 = (const int8_t*) ((uintptr_t) a2 - kc); - a3 = (const int8_t*) ((uintptr_t) a3 - kc); - a4 = (const int8_t*) ((uintptr_t) a4 - kc); - a5 = (const int8_t*) ((uintptr_t) a5 - kc); - c0 = (int8_t*) ((uintptr_t) c0 + cn_stride); + a0 = (const int8_t*) ((uintptr_t) a0 - kc); + _mm_storeu_si128((__m128i*) c1, vout1x0123456789ABCDEF); c1 = (int8_t*) ((uintptr_t) c1 + cn_stride); + a1 = (const int8_t*) ((uintptr_t) a1 - kc); + _mm_storeu_si128((__m128i*) c2, vout2x0123456789ABCDEF); c2 = (int8_t*) ((uintptr_t) c2 + cn_stride); + a2 = (const int8_t*) ((uintptr_t) a2 - kc); + _mm_storeu_si128((__m128i*) c3, vout3x0123456789ABCDEF); c3 = (int8_t*) ((uintptr_t) c3 + cn_stride); + a3 = (const int8_t*) ((uintptr_t) a3 - kc); + _mm_storeu_si128((__m128i*) c4, vout4x0123456789ABCDEF); c4 = (int8_t*) ((uintptr_t) c4 + cn_stride); + a4 = (const int8_t*) ((uintptr_t) a4 - kc); + _mm_storeu_si128((__m128i*) c5, vout5x0123456789ABCDEF); c5 = (int8_t*) ((uintptr_t) c5 + cn_stride); + a5 = (const int8_t*) ((uintptr_t) a5 - kc); nc -= 16; } else { diff --git a/src/qs8-gemm/gen/qs8-gemm-6x16c8-minmax-fp32-avx512vnni.c b/src/qs8-gemm/gen/qs8-gemm-6x16c8-minmax-fp32-avx512vnni.c index 7a47a45f92f..544b4e4ef5b 100644 --- a/src/qs8-gemm/gen/qs8-gemm-6x16c8-minmax-fp32-avx512vnni.c +++ b/src/qs8-gemm/gen/qs8-gemm-6x16c8-minmax-fp32-avx512vnni.c @@ -77,7 +77,6 @@ void xnn_qs8_gemm_minmax_fp32_ukernel_6x16c8__avx512vnni( const __m512 voutput_max_less_zero_point = _mm512_load_ps(params->fp32_avx512vnni.output_max_less_zero_point); const __m256i voutput_zero_point = _mm256_load_si256((const __m256i*) params->fp32_avx512vnni.output_zero_point); const __m128i voutput_min = _mm_load_si128((const __m128i*) params->fp32_avx512vnni.output_min); - const __m128i vshuffle_control_mask = _mm_loadu_si128((const __m128i*) params->fp32_avx512vnni.shuffle_control_mask); do { __m512i vacc0x01234567 = _mm512_cvtepu32_epi64(_mm256_load_si256((const __m256i*) w)); __m512i vacc0x89ABCDEF = _mm512_cvtepu32_epi64(_mm256_load_si256((const __m256i*) ((const int32_t*) w + 8))); @@ -231,33 +230,33 @@ void xnn_qs8_gemm_minmax_fp32_ukernel_6x16c8__avx512vnni( vacc4x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled4x0123456789ABCDEF); vacc5x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled5x0123456789ABCDEF); - __m256i vacc0x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); - __m256i vacc1x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); - __m256i vacc2x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc2x0123456789ABCDEF, 1)); - __m256i vacc3x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc3x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc3x0123456789ABCDEF, 1)); - __m256i vacc4x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc4x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc4x0123456789ABCDEF, 1)); - __m256i vacc5x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc5x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc5x0123456789ABCDEF, 1)); + __m256i vacc0x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); + __m256i vacc1x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); + __m256i vacc2x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc2x0123456789ABCDEF, 1)); + __m256i vacc3x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc3x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc3x0123456789ABCDEF, 1)); + __m256i vacc4x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc4x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc4x0123456789ABCDEF, 1)); + __m256i vacc5x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc5x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc5x0123456789ABCDEF, 1)); - vacc0x0123456789AB4567CDEF = _mm256_adds_epi16(vacc0x0123456789AB4567CDEF, voutput_zero_point); - vacc1x0123456789AB4567CDEF = _mm256_adds_epi16(vacc1x0123456789AB4567CDEF, voutput_zero_point); - vacc2x0123456789AB4567CDEF = _mm256_adds_epi16(vacc2x0123456789AB4567CDEF, voutput_zero_point); - vacc3x0123456789AB4567CDEF = _mm256_adds_epi16(vacc3x0123456789AB4567CDEF, voutput_zero_point); - vacc4x0123456789AB4567CDEF = _mm256_adds_epi16(vacc4x0123456789AB4567CDEF, voutput_zero_point); - vacc5x0123456789AB4567CDEF = _mm256_adds_epi16(vacc5x0123456789AB4567CDEF, voutput_zero_point); + vacc0x012389AB4567CDEF = _mm256_adds_epi16(vacc0x012389AB4567CDEF, voutput_zero_point); + vacc1x012389AB4567CDEF = _mm256_adds_epi16(vacc1x012389AB4567CDEF, voutput_zero_point); + vacc2x012389AB4567CDEF = _mm256_adds_epi16(vacc2x012389AB4567CDEF, voutput_zero_point); + vacc3x012389AB4567CDEF = _mm256_adds_epi16(vacc3x012389AB4567CDEF, voutput_zero_point); + vacc4x012389AB4567CDEF = _mm256_adds_epi16(vacc4x012389AB4567CDEF, voutput_zero_point); + vacc5x012389AB4567CDEF = _mm256_adds_epi16(vacc5x012389AB4567CDEF, voutput_zero_point); - const __m128i vout0x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc0x0123456789AB4567CDEF, 1)); - const __m128i vout1x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc1x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc1x0123456789AB4567CDEF, 1)); - const __m128i vout2x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc2x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc2x0123456789AB4567CDEF, 1)); - const __m128i vout3x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc3x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc3x0123456789AB4567CDEF, 1)); - const __m128i vout4x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc4x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc4x0123456789AB4567CDEF, 1)); - const __m128i vout5x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc5x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc5x0123456789AB4567CDEF, 1)); + const __m128i vout0x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x012389AB4567CDEF, 1)); + const __m128i vout1x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc1x012389AB4567CDEF), _mm256_extracti128_si256(vacc1x012389AB4567CDEF, 1)); + const __m128i vout2x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc2x012389AB4567CDEF), _mm256_extracti128_si256(vacc2x012389AB4567CDEF, 1)); + const __m128i vout3x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc3x012389AB4567CDEF), _mm256_extracti128_si256(vacc3x012389AB4567CDEF, 1)); + const __m128i vout4x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc4x012389AB4567CDEF), _mm256_extracti128_si256(vacc4x012389AB4567CDEF, 1)); + const __m128i vout5x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc5x012389AB4567CDEF), _mm256_extracti128_si256(vacc5x012389AB4567CDEF, 1)); - __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi8(vout0x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi8(vout1x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi8(vout2x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout3x0123456789ABCDEF = _mm_shuffle_epi8(vout3x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout4x0123456789ABCDEF = _mm_shuffle_epi8(vout4x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout5x0123456789ABCDEF = _mm_shuffle_epi8(vout5x0123456789AB4567CDEF, vshuffle_control_mask); + __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi32(vout0x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi32(vout1x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi32(vout2x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout3x0123456789ABCDEF = _mm_shuffle_epi32(vout3x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout4x0123456789ABCDEF = _mm_shuffle_epi32(vout4x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout5x0123456789ABCDEF = _mm_shuffle_epi32(vout5x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); vout0x0123456789ABCDEF = _mm_max_epi8(vout0x0123456789ABCDEF, voutput_min); vout1x0123456789ABCDEF = _mm_max_epi8(vout1x0123456789ABCDEF, voutput_min); @@ -268,25 +267,23 @@ void xnn_qs8_gemm_minmax_fp32_ukernel_6x16c8__avx512vnni( if (nc >= 16) { _mm_storeu_si128((__m128i*) c0, vout0x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c1, vout1x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c2, vout2x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c3, vout3x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c4, vout4x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c5, vout5x0123456789ABCDEF); - - a0 = (const int8_t*) ((uintptr_t) a0 - kc); - a1 = (const int8_t*) ((uintptr_t) a1 - kc); - a2 = (const int8_t*) ((uintptr_t) a2 - kc); - a3 = (const int8_t*) ((uintptr_t) a3 - kc); - a4 = (const int8_t*) ((uintptr_t) a4 - kc); - a5 = (const int8_t*) ((uintptr_t) a5 - kc); - c0 = (int8_t*) ((uintptr_t) c0 + cn_stride); + a0 = (const int8_t*) ((uintptr_t) a0 - kc); + _mm_storeu_si128((__m128i*) c1, vout1x0123456789ABCDEF); c1 = (int8_t*) ((uintptr_t) c1 + cn_stride); + a1 = (const int8_t*) ((uintptr_t) a1 - kc); + _mm_storeu_si128((__m128i*) c2, vout2x0123456789ABCDEF); c2 = (int8_t*) ((uintptr_t) c2 + cn_stride); + a2 = (const int8_t*) ((uintptr_t) a2 - kc); + _mm_storeu_si128((__m128i*) c3, vout3x0123456789ABCDEF); c3 = (int8_t*) ((uintptr_t) c3 + cn_stride); + a3 = (const int8_t*) ((uintptr_t) a3 - kc); + _mm_storeu_si128((__m128i*) c4, vout4x0123456789ABCDEF); c4 = (int8_t*) ((uintptr_t) c4 + cn_stride); + a4 = (const int8_t*) ((uintptr_t) a4 - kc); + _mm_storeu_si128((__m128i*) c5, vout5x0123456789ABCDEF); c5 = (int8_t*) ((uintptr_t) c5 + cn_stride); + a5 = (const int8_t*) ((uintptr_t) a5 - kc); nc -= 16; } else { diff --git a/src/qs8-gemm/gen/qs8-gemm-7x16c4-minmax-fp32-avx512vnni-prfm.c b/src/qs8-gemm/gen/qs8-gemm-7x16c4-minmax-fp32-avx512vnni-prfm.c index 5648d4613c8..46a93ae24a7 100644 --- a/src/qs8-gemm/gen/qs8-gemm-7x16c4-minmax-fp32-avx512vnni-prfm.c +++ b/src/qs8-gemm/gen/qs8-gemm-7x16c4-minmax-fp32-avx512vnni-prfm.c @@ -84,7 +84,6 @@ void xnn_qs8_gemm_minmax_fp32_ukernel_7x16c4__avx512vnni_prfm( const __m512 voutput_max_less_zero_point = _mm512_load_ps(params->fp32_avx512vnni.output_max_less_zero_point); const __m256i voutput_zero_point = _mm256_load_si256((const __m256i*) params->fp32_avx512vnni.output_zero_point); const __m128i voutput_min = _mm_load_si128((const __m128i*) params->fp32_avx512vnni.output_min); - const __m128i vshuffle_control_mask = _mm_loadu_si128((const __m128i*) params->fp32_avx512vnni.shuffle_control_mask); do { __m512i vacc0x0x0123456789ABCDEF = _mm512_load_epi32(w); __m512i vacc1x0x0123456789ABCDEF = _mm512_setzero_epi32(); @@ -149,13 +148,13 @@ void xnn_qs8_gemm_minmax_fp32_ukernel_7x16c4__avx512vnni_prfm( w = (const int8_t*) w + 128; k -= 8 * sizeof(int8_t); } - vacc0x0x0123456789ABCDEF = _mm512_add_epi32(vacc0x0x0123456789ABCDEF, vacc1x0x0123456789ABCDEF); - vacc0x1x0123456789ABCDEF = _mm512_add_epi32(vacc0x1x0123456789ABCDEF, vacc1x1x0123456789ABCDEF); - vacc0x2x0123456789ABCDEF = _mm512_add_epi32(vacc0x2x0123456789ABCDEF, vacc1x2x0123456789ABCDEF); - vacc0x3x0123456789ABCDEF = _mm512_add_epi32(vacc0x3x0123456789ABCDEF, vacc1x3x0123456789ABCDEF); - vacc0x4x0123456789ABCDEF = _mm512_add_epi32(vacc0x4x0123456789ABCDEF, vacc1x4x0123456789ABCDEF); - vacc0x5x0123456789ABCDEF = _mm512_add_epi32(vacc0x5x0123456789ABCDEF, vacc1x5x0123456789ABCDEF); - vacc0x6x0123456789ABCDEF = _mm512_add_epi32(vacc0x6x0123456789ABCDEF, vacc1x6x0123456789ABCDEF); + __m512i vacc0x0123456789ABCDEF = _mm512_add_epi32(vacc0x0x0123456789ABCDEF, vacc1x0x0123456789ABCDEF); + __m512i vacc1x0123456789ABCDEF = _mm512_add_epi32(vacc0x1x0123456789ABCDEF, vacc1x1x0123456789ABCDEF); + __m512i vacc2x0123456789ABCDEF = _mm512_add_epi32(vacc0x2x0123456789ABCDEF, vacc1x2x0123456789ABCDEF); + __m512i vacc3x0123456789ABCDEF = _mm512_add_epi32(vacc0x3x0123456789ABCDEF, vacc1x3x0123456789ABCDEF); + __m512i vacc4x0123456789ABCDEF = _mm512_add_epi32(vacc0x4x0123456789ABCDEF, vacc1x4x0123456789ABCDEF); + __m512i vacc5x0123456789ABCDEF = _mm512_add_epi32(vacc0x5x0123456789ABCDEF, vacc1x5x0123456789ABCDEF); + __m512i vacc6x0123456789ABCDEF = _mm512_add_epi32(vacc0x6x0123456789ABCDEF, vacc1x6x0123456789ABCDEF); if (k != 0) { const __m512i va0x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a0)), vsign_mask); @@ -176,25 +175,25 @@ void xnn_qs8_gemm_minmax_fp32_ukernel_7x16c4__avx512vnni_prfm( const __m512i vb0123456789ABCDEF = _mm512_load_si512(w); xnn_prefetch_to_l1((const int8_t*) w + 960); - vacc0x0x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x0x0123456789ABCDEF, va0x0123, vb0123456789ABCDEF); - vacc0x1x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x1x0123456789ABCDEF, va1x0123, vb0123456789ABCDEF); - vacc0x2x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x2x0123456789ABCDEF, va2x0123, vb0123456789ABCDEF); - vacc0x3x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x3x0123456789ABCDEF, va3x0123, vb0123456789ABCDEF); - vacc0x4x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x4x0123456789ABCDEF, va4x0123, vb0123456789ABCDEF); - vacc0x5x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x5x0123456789ABCDEF, va5x0123, vb0123456789ABCDEF); - vacc0x6x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x6x0123456789ABCDEF, va6x0123, vb0123456789ABCDEF); + vacc0x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x0123456789ABCDEF, va0x0123, vb0123456789ABCDEF); + vacc1x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc1x0123456789ABCDEF, va1x0123, vb0123456789ABCDEF); + vacc2x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc2x0123456789ABCDEF, va2x0123, vb0123456789ABCDEF); + vacc3x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc3x0123456789ABCDEF, va3x0123, vb0123456789ABCDEF); + vacc4x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc4x0123456789ABCDEF, va4x0123, vb0123456789ABCDEF); + vacc5x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc5x0123456789ABCDEF, va5x0123, vb0123456789ABCDEF); + vacc6x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc6x0123456789ABCDEF, va6x0123, vb0123456789ABCDEF); w = (const int8_t*) w + 64; k -= 4 * sizeof(int8_t); } - __m512 vscaled0x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x0x0123456789ABCDEF); - __m512 vscaled1x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x1x0123456789ABCDEF); - __m512 vscaled2x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x2x0123456789ABCDEF); - __m512 vscaled3x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x3x0123456789ABCDEF); - __m512 vscaled4x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x4x0123456789ABCDEF); - __m512 vscaled5x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x5x0123456789ABCDEF); - __m512 vscaled6x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x6x0123456789ABCDEF); + __m512 vscaled0x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x0123456789ABCDEF); + __m512 vscaled1x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc1x0123456789ABCDEF); + __m512 vscaled2x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc2x0123456789ABCDEF); + __m512 vscaled3x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc3x0123456789ABCDEF); + __m512 vscaled4x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc4x0123456789ABCDEF); + __m512 vscaled5x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc5x0123456789ABCDEF); + __m512 vscaled6x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc6x0123456789ABCDEF); vscaled0x0123456789ABCDEF = _mm512_mul_ps(vscaled0x0123456789ABCDEF, vscale); vscaled1x0123456789ABCDEF = _mm512_mul_ps(vscaled1x0123456789ABCDEF, vscale); @@ -212,45 +211,45 @@ void xnn_qs8_gemm_minmax_fp32_ukernel_7x16c4__avx512vnni_prfm( vscaled5x0123456789ABCDEF = _mm512_min_ps(vscaled5x0123456789ABCDEF, voutput_max_less_zero_point); vscaled6x0123456789ABCDEF = _mm512_min_ps(vscaled6x0123456789ABCDEF, voutput_max_less_zero_point); - vacc0x0x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled0x0123456789ABCDEF); - vacc0x1x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled1x0123456789ABCDEF); - vacc0x2x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled2x0123456789ABCDEF); - vacc0x3x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled3x0123456789ABCDEF); - vacc0x4x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled4x0123456789ABCDEF); - vacc0x5x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled5x0123456789ABCDEF); - vacc0x6x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled6x0123456789ABCDEF); + vacc0x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled0x0123456789ABCDEF); + vacc1x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled1x0123456789ABCDEF); + vacc2x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled2x0123456789ABCDEF); + vacc3x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled3x0123456789ABCDEF); + vacc4x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled4x0123456789ABCDEF); + vacc5x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled5x0123456789ABCDEF); + vacc6x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled6x0123456789ABCDEF); - __m256i vacc0x0x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0x0123456789ABCDEF, 1)); - __m256i vacc0x1x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x1x0123456789ABCDEF, 1)); - __m256i vacc0x2x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x2x0123456789ABCDEF, 1)); - __m256i vacc0x3x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x3x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x3x0123456789ABCDEF, 1)); - __m256i vacc0x4x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x4x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x4x0123456789ABCDEF, 1)); - __m256i vacc0x5x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x5x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x5x0123456789ABCDEF, 1)); - __m256i vacc0x6x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x6x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x6x0123456789ABCDEF, 1)); + __m256i vacc0x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); + __m256i vacc1x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); + __m256i vacc2x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc2x0123456789ABCDEF, 1)); + __m256i vacc3x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc3x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc3x0123456789ABCDEF, 1)); + __m256i vacc4x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc4x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc4x0123456789ABCDEF, 1)); + __m256i vacc5x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc5x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc5x0123456789ABCDEF, 1)); + __m256i vacc6x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc6x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc6x0123456789ABCDEF, 1)); - vacc0x0x012389AB4567CDEF = _mm256_adds_epi16(vacc0x0x012389AB4567CDEF, voutput_zero_point); - vacc0x1x012389AB4567CDEF = _mm256_adds_epi16(vacc0x1x012389AB4567CDEF, voutput_zero_point); - vacc0x2x012389AB4567CDEF = _mm256_adds_epi16(vacc0x2x012389AB4567CDEF, voutput_zero_point); - vacc0x3x012389AB4567CDEF = _mm256_adds_epi16(vacc0x3x012389AB4567CDEF, voutput_zero_point); - vacc0x4x012389AB4567CDEF = _mm256_adds_epi16(vacc0x4x012389AB4567CDEF, voutput_zero_point); - vacc0x5x012389AB4567CDEF = _mm256_adds_epi16(vacc0x5x012389AB4567CDEF, voutput_zero_point); - vacc0x6x012389AB4567CDEF = _mm256_adds_epi16(vacc0x6x012389AB4567CDEF, voutput_zero_point); + vacc0x012389AB4567CDEF = _mm256_adds_epi16(vacc0x012389AB4567CDEF, voutput_zero_point); + vacc1x012389AB4567CDEF = _mm256_adds_epi16(vacc1x012389AB4567CDEF, voutput_zero_point); + vacc2x012389AB4567CDEF = _mm256_adds_epi16(vacc2x012389AB4567CDEF, voutput_zero_point); + vacc3x012389AB4567CDEF = _mm256_adds_epi16(vacc3x012389AB4567CDEF, voutput_zero_point); + vacc4x012389AB4567CDEF = _mm256_adds_epi16(vacc4x012389AB4567CDEF, voutput_zero_point); + vacc5x012389AB4567CDEF = _mm256_adds_epi16(vacc5x012389AB4567CDEF, voutput_zero_point); + vacc6x012389AB4567CDEF = _mm256_adds_epi16(vacc6x012389AB4567CDEF, voutput_zero_point); - const __m128i vout0x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x0x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x0x012389AB4567CDEF, 1)); - const __m128i vout1x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x1x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x1x012389AB4567CDEF, 1)); - const __m128i vout2x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x2x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x2x012389AB4567CDEF, 1)); - const __m128i vout3x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x3x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x3x012389AB4567CDEF, 1)); - const __m128i vout4x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x4x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x4x012389AB4567CDEF, 1)); - const __m128i vout5x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x5x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x5x012389AB4567CDEF, 1)); - const __m128i vout6x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x6x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x6x012389AB4567CDEF, 1)); + const __m128i vout0x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x012389AB4567CDEF, 1)); + const __m128i vout1x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc1x012389AB4567CDEF), _mm256_extracti128_si256(vacc1x012389AB4567CDEF, 1)); + const __m128i vout2x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc2x012389AB4567CDEF), _mm256_extracti128_si256(vacc2x012389AB4567CDEF, 1)); + const __m128i vout3x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc3x012389AB4567CDEF), _mm256_extracti128_si256(vacc3x012389AB4567CDEF, 1)); + const __m128i vout4x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc4x012389AB4567CDEF), _mm256_extracti128_si256(vacc4x012389AB4567CDEF, 1)); + const __m128i vout5x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc5x012389AB4567CDEF), _mm256_extracti128_si256(vacc5x012389AB4567CDEF, 1)); + const __m128i vout6x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc6x012389AB4567CDEF), _mm256_extracti128_si256(vacc6x012389AB4567CDEF, 1)); - __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi8(vout0x012389AB4567CDEF, vshuffle_control_mask); - __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi8(vout1x012389AB4567CDEF, vshuffle_control_mask); - __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi8(vout2x012389AB4567CDEF, vshuffle_control_mask); - __m128i vout3x0123456789ABCDEF = _mm_shuffle_epi8(vout3x012389AB4567CDEF, vshuffle_control_mask); - __m128i vout4x0123456789ABCDEF = _mm_shuffle_epi8(vout4x012389AB4567CDEF, vshuffle_control_mask); - __m128i vout5x0123456789ABCDEF = _mm_shuffle_epi8(vout5x012389AB4567CDEF, vshuffle_control_mask); - __m128i vout6x0123456789ABCDEF = _mm_shuffle_epi8(vout6x012389AB4567CDEF, vshuffle_control_mask); + __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi32(vout0x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi32(vout1x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi32(vout2x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout3x0123456789ABCDEF = _mm_shuffle_epi32(vout3x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout4x0123456789ABCDEF = _mm_shuffle_epi32(vout4x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout5x0123456789ABCDEF = _mm_shuffle_epi32(vout5x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout6x0123456789ABCDEF = _mm_shuffle_epi32(vout6x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); vout0x0123456789ABCDEF = _mm_max_epi8(vout0x0123456789ABCDEF, voutput_min); vout1x0123456789ABCDEF = _mm_max_epi8(vout1x0123456789ABCDEF, voutput_min); @@ -262,28 +261,26 @@ void xnn_qs8_gemm_minmax_fp32_ukernel_7x16c4__avx512vnni_prfm( if (nc >= 16) { _mm_storeu_si128((__m128i*) c0, vout0x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c1, vout1x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c2, vout2x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c3, vout3x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c4, vout4x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c5, vout5x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c6, vout6x0123456789ABCDEF); - - a0 = (const int8_t*) ((uintptr_t) a0 - kc); - a1 = (const int8_t*) ((uintptr_t) a1 - kc); - a2 = (const int8_t*) ((uintptr_t) a2 - kc); - a3 = (const int8_t*) ((uintptr_t) a3 - kc); - a4 = (const int8_t*) ((uintptr_t) a4 - kc); - a5 = (const int8_t*) ((uintptr_t) a5 - kc); - a6 = (const int8_t*) ((uintptr_t) a6 - kc); - c0 = (int8_t*) ((uintptr_t) c0 + cn_stride); + a0 = (const int8_t*) ((uintptr_t) a0 - kc); + _mm_storeu_si128((__m128i*) c1, vout1x0123456789ABCDEF); c1 = (int8_t*) ((uintptr_t) c1 + cn_stride); + a1 = (const int8_t*) ((uintptr_t) a1 - kc); + _mm_storeu_si128((__m128i*) c2, vout2x0123456789ABCDEF); c2 = (int8_t*) ((uintptr_t) c2 + cn_stride); + a2 = (const int8_t*) ((uintptr_t) a2 - kc); + _mm_storeu_si128((__m128i*) c3, vout3x0123456789ABCDEF); c3 = (int8_t*) ((uintptr_t) c3 + cn_stride); + a3 = (const int8_t*) ((uintptr_t) a3 - kc); + _mm_storeu_si128((__m128i*) c4, vout4x0123456789ABCDEF); c4 = (int8_t*) ((uintptr_t) c4 + cn_stride); + a4 = (const int8_t*) ((uintptr_t) a4 - kc); + _mm_storeu_si128((__m128i*) c5, vout5x0123456789ABCDEF); c5 = (int8_t*) ((uintptr_t) c5 + cn_stride); + a5 = (const int8_t*) ((uintptr_t) a5 - kc); + _mm_storeu_si128((__m128i*) c6, vout6x0123456789ABCDEF); c6 = (int8_t*) ((uintptr_t) c6 + cn_stride); + a6 = (const int8_t*) ((uintptr_t) a6 - kc); nc -= 16; } else { diff --git a/src/qs8-gemm/gen/qs8-gemm-7x16c4-minmax-fp32-avx512vnni.c b/src/qs8-gemm/gen/qs8-gemm-7x16c4-minmax-fp32-avx512vnni.c index 1c7ab075cea..7396846fa94 100644 --- a/src/qs8-gemm/gen/qs8-gemm-7x16c4-minmax-fp32-avx512vnni.c +++ b/src/qs8-gemm/gen/qs8-gemm-7x16c4-minmax-fp32-avx512vnni.c @@ -83,7 +83,6 @@ void xnn_qs8_gemm_minmax_fp32_ukernel_7x16c4__avx512vnni( const __m512 voutput_max_less_zero_point = _mm512_load_ps(params->fp32_avx512vnni.output_max_less_zero_point); const __m256i voutput_zero_point = _mm256_load_si256((const __m256i*) params->fp32_avx512vnni.output_zero_point); const __m128i voutput_min = _mm_load_si128((const __m128i*) params->fp32_avx512vnni.output_min); - const __m128i vshuffle_control_mask = _mm_loadu_si128((const __m128i*) params->fp32_avx512vnni.shuffle_control_mask); do { __m512i vacc0x0x0123456789ABCDEF = _mm512_load_epi32(w); __m512i vacc1x0x0123456789ABCDEF = _mm512_setzero_epi32(); @@ -146,13 +145,13 @@ void xnn_qs8_gemm_minmax_fp32_ukernel_7x16c4__avx512vnni( w = (const int8_t*) w + 128; k -= 8 * sizeof(int8_t); } - vacc0x0x0123456789ABCDEF = _mm512_add_epi32(vacc0x0x0123456789ABCDEF, vacc1x0x0123456789ABCDEF); - vacc0x1x0123456789ABCDEF = _mm512_add_epi32(vacc0x1x0123456789ABCDEF, vacc1x1x0123456789ABCDEF); - vacc0x2x0123456789ABCDEF = _mm512_add_epi32(vacc0x2x0123456789ABCDEF, vacc1x2x0123456789ABCDEF); - vacc0x3x0123456789ABCDEF = _mm512_add_epi32(vacc0x3x0123456789ABCDEF, vacc1x3x0123456789ABCDEF); - vacc0x4x0123456789ABCDEF = _mm512_add_epi32(vacc0x4x0123456789ABCDEF, vacc1x4x0123456789ABCDEF); - vacc0x5x0123456789ABCDEF = _mm512_add_epi32(vacc0x5x0123456789ABCDEF, vacc1x5x0123456789ABCDEF); - vacc0x6x0123456789ABCDEF = _mm512_add_epi32(vacc0x6x0123456789ABCDEF, vacc1x6x0123456789ABCDEF); + __m512i vacc0x0123456789ABCDEF = _mm512_add_epi32(vacc0x0x0123456789ABCDEF, vacc1x0x0123456789ABCDEF); + __m512i vacc1x0123456789ABCDEF = _mm512_add_epi32(vacc0x1x0123456789ABCDEF, vacc1x1x0123456789ABCDEF); + __m512i vacc2x0123456789ABCDEF = _mm512_add_epi32(vacc0x2x0123456789ABCDEF, vacc1x2x0123456789ABCDEF); + __m512i vacc3x0123456789ABCDEF = _mm512_add_epi32(vacc0x3x0123456789ABCDEF, vacc1x3x0123456789ABCDEF); + __m512i vacc4x0123456789ABCDEF = _mm512_add_epi32(vacc0x4x0123456789ABCDEF, vacc1x4x0123456789ABCDEF); + __m512i vacc5x0123456789ABCDEF = _mm512_add_epi32(vacc0x5x0123456789ABCDEF, vacc1x5x0123456789ABCDEF); + __m512i vacc6x0123456789ABCDEF = _mm512_add_epi32(vacc0x6x0123456789ABCDEF, vacc1x6x0123456789ABCDEF); if (k != 0) { const __m512i va0x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a0)), vsign_mask); @@ -172,25 +171,25 @@ void xnn_qs8_gemm_minmax_fp32_ukernel_7x16c4__avx512vnni( const __m512i vb0123456789ABCDEF = _mm512_load_si512(w); - vacc0x0x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x0x0123456789ABCDEF, va0x0123, vb0123456789ABCDEF); - vacc0x1x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x1x0123456789ABCDEF, va1x0123, vb0123456789ABCDEF); - vacc0x2x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x2x0123456789ABCDEF, va2x0123, vb0123456789ABCDEF); - vacc0x3x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x3x0123456789ABCDEF, va3x0123, vb0123456789ABCDEF); - vacc0x4x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x4x0123456789ABCDEF, va4x0123, vb0123456789ABCDEF); - vacc0x5x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x5x0123456789ABCDEF, va5x0123, vb0123456789ABCDEF); - vacc0x6x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x6x0123456789ABCDEF, va6x0123, vb0123456789ABCDEF); + vacc0x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x0123456789ABCDEF, va0x0123, vb0123456789ABCDEF); + vacc1x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc1x0123456789ABCDEF, va1x0123, vb0123456789ABCDEF); + vacc2x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc2x0123456789ABCDEF, va2x0123, vb0123456789ABCDEF); + vacc3x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc3x0123456789ABCDEF, va3x0123, vb0123456789ABCDEF); + vacc4x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc4x0123456789ABCDEF, va4x0123, vb0123456789ABCDEF); + vacc5x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc5x0123456789ABCDEF, va5x0123, vb0123456789ABCDEF); + vacc6x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc6x0123456789ABCDEF, va6x0123, vb0123456789ABCDEF); w = (const int8_t*) w + 64; k -= 4 * sizeof(int8_t); } - __m512 vscaled0x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x0x0123456789ABCDEF); - __m512 vscaled1x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x1x0123456789ABCDEF); - __m512 vscaled2x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x2x0123456789ABCDEF); - __m512 vscaled3x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x3x0123456789ABCDEF); - __m512 vscaled4x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x4x0123456789ABCDEF); - __m512 vscaled5x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x5x0123456789ABCDEF); - __m512 vscaled6x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x6x0123456789ABCDEF); + __m512 vscaled0x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x0123456789ABCDEF); + __m512 vscaled1x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc1x0123456789ABCDEF); + __m512 vscaled2x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc2x0123456789ABCDEF); + __m512 vscaled3x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc3x0123456789ABCDEF); + __m512 vscaled4x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc4x0123456789ABCDEF); + __m512 vscaled5x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc5x0123456789ABCDEF); + __m512 vscaled6x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc6x0123456789ABCDEF); vscaled0x0123456789ABCDEF = _mm512_mul_ps(vscaled0x0123456789ABCDEF, vscale); vscaled1x0123456789ABCDEF = _mm512_mul_ps(vscaled1x0123456789ABCDEF, vscale); @@ -208,45 +207,45 @@ void xnn_qs8_gemm_minmax_fp32_ukernel_7x16c4__avx512vnni( vscaled5x0123456789ABCDEF = _mm512_min_ps(vscaled5x0123456789ABCDEF, voutput_max_less_zero_point); vscaled6x0123456789ABCDEF = _mm512_min_ps(vscaled6x0123456789ABCDEF, voutput_max_less_zero_point); - vacc0x0x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled0x0123456789ABCDEF); - vacc0x1x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled1x0123456789ABCDEF); - vacc0x2x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled2x0123456789ABCDEF); - vacc0x3x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled3x0123456789ABCDEF); - vacc0x4x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled4x0123456789ABCDEF); - vacc0x5x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled5x0123456789ABCDEF); - vacc0x6x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled6x0123456789ABCDEF); + vacc0x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled0x0123456789ABCDEF); + vacc1x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled1x0123456789ABCDEF); + vacc2x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled2x0123456789ABCDEF); + vacc3x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled3x0123456789ABCDEF); + vacc4x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled4x0123456789ABCDEF); + vacc5x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled5x0123456789ABCDEF); + vacc6x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled6x0123456789ABCDEF); - __m256i vacc0x0x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0x0123456789ABCDEF, 1)); - __m256i vacc0x1x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x1x0123456789ABCDEF, 1)); - __m256i vacc0x2x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x2x0123456789ABCDEF, 1)); - __m256i vacc0x3x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x3x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x3x0123456789ABCDEF, 1)); - __m256i vacc0x4x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x4x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x4x0123456789ABCDEF, 1)); - __m256i vacc0x5x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x5x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x5x0123456789ABCDEF, 1)); - __m256i vacc0x6x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x6x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x6x0123456789ABCDEF, 1)); + __m256i vacc0x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); + __m256i vacc1x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); + __m256i vacc2x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc2x0123456789ABCDEF, 1)); + __m256i vacc3x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc3x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc3x0123456789ABCDEF, 1)); + __m256i vacc4x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc4x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc4x0123456789ABCDEF, 1)); + __m256i vacc5x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc5x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc5x0123456789ABCDEF, 1)); + __m256i vacc6x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc6x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc6x0123456789ABCDEF, 1)); - vacc0x0x012389AB4567CDEF = _mm256_adds_epi16(vacc0x0x012389AB4567CDEF, voutput_zero_point); - vacc0x1x012389AB4567CDEF = _mm256_adds_epi16(vacc0x1x012389AB4567CDEF, voutput_zero_point); - vacc0x2x012389AB4567CDEF = _mm256_adds_epi16(vacc0x2x012389AB4567CDEF, voutput_zero_point); - vacc0x3x012389AB4567CDEF = _mm256_adds_epi16(vacc0x3x012389AB4567CDEF, voutput_zero_point); - vacc0x4x012389AB4567CDEF = _mm256_adds_epi16(vacc0x4x012389AB4567CDEF, voutput_zero_point); - vacc0x5x012389AB4567CDEF = _mm256_adds_epi16(vacc0x5x012389AB4567CDEF, voutput_zero_point); - vacc0x6x012389AB4567CDEF = _mm256_adds_epi16(vacc0x6x012389AB4567CDEF, voutput_zero_point); + vacc0x012389AB4567CDEF = _mm256_adds_epi16(vacc0x012389AB4567CDEF, voutput_zero_point); + vacc1x012389AB4567CDEF = _mm256_adds_epi16(vacc1x012389AB4567CDEF, voutput_zero_point); + vacc2x012389AB4567CDEF = _mm256_adds_epi16(vacc2x012389AB4567CDEF, voutput_zero_point); + vacc3x012389AB4567CDEF = _mm256_adds_epi16(vacc3x012389AB4567CDEF, voutput_zero_point); + vacc4x012389AB4567CDEF = _mm256_adds_epi16(vacc4x012389AB4567CDEF, voutput_zero_point); + vacc5x012389AB4567CDEF = _mm256_adds_epi16(vacc5x012389AB4567CDEF, voutput_zero_point); + vacc6x012389AB4567CDEF = _mm256_adds_epi16(vacc6x012389AB4567CDEF, voutput_zero_point); - const __m128i vout0x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x0x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x0x012389AB4567CDEF, 1)); - const __m128i vout1x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x1x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x1x012389AB4567CDEF, 1)); - const __m128i vout2x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x2x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x2x012389AB4567CDEF, 1)); - const __m128i vout3x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x3x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x3x012389AB4567CDEF, 1)); - const __m128i vout4x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x4x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x4x012389AB4567CDEF, 1)); - const __m128i vout5x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x5x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x5x012389AB4567CDEF, 1)); - const __m128i vout6x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x6x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x6x012389AB4567CDEF, 1)); + const __m128i vout0x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x012389AB4567CDEF, 1)); + const __m128i vout1x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc1x012389AB4567CDEF), _mm256_extracti128_si256(vacc1x012389AB4567CDEF, 1)); + const __m128i vout2x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc2x012389AB4567CDEF), _mm256_extracti128_si256(vacc2x012389AB4567CDEF, 1)); + const __m128i vout3x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc3x012389AB4567CDEF), _mm256_extracti128_si256(vacc3x012389AB4567CDEF, 1)); + const __m128i vout4x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc4x012389AB4567CDEF), _mm256_extracti128_si256(vacc4x012389AB4567CDEF, 1)); + const __m128i vout5x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc5x012389AB4567CDEF), _mm256_extracti128_si256(vacc5x012389AB4567CDEF, 1)); + const __m128i vout6x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc6x012389AB4567CDEF), _mm256_extracti128_si256(vacc6x012389AB4567CDEF, 1)); - __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi8(vout0x012389AB4567CDEF, vshuffle_control_mask); - __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi8(vout1x012389AB4567CDEF, vshuffle_control_mask); - __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi8(vout2x012389AB4567CDEF, vshuffle_control_mask); - __m128i vout3x0123456789ABCDEF = _mm_shuffle_epi8(vout3x012389AB4567CDEF, vshuffle_control_mask); - __m128i vout4x0123456789ABCDEF = _mm_shuffle_epi8(vout4x012389AB4567CDEF, vshuffle_control_mask); - __m128i vout5x0123456789ABCDEF = _mm_shuffle_epi8(vout5x012389AB4567CDEF, vshuffle_control_mask); - __m128i vout6x0123456789ABCDEF = _mm_shuffle_epi8(vout6x012389AB4567CDEF, vshuffle_control_mask); + __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi32(vout0x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi32(vout1x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi32(vout2x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout3x0123456789ABCDEF = _mm_shuffle_epi32(vout3x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout4x0123456789ABCDEF = _mm_shuffle_epi32(vout4x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout5x0123456789ABCDEF = _mm_shuffle_epi32(vout5x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout6x0123456789ABCDEF = _mm_shuffle_epi32(vout6x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); vout0x0123456789ABCDEF = _mm_max_epi8(vout0x0123456789ABCDEF, voutput_min); vout1x0123456789ABCDEF = _mm_max_epi8(vout1x0123456789ABCDEF, voutput_min); @@ -258,28 +257,26 @@ void xnn_qs8_gemm_minmax_fp32_ukernel_7x16c4__avx512vnni( if (nc >= 16) { _mm_storeu_si128((__m128i*) c0, vout0x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c1, vout1x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c2, vout2x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c3, vout3x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c4, vout4x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c5, vout5x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c6, vout6x0123456789ABCDEF); - - a0 = (const int8_t*) ((uintptr_t) a0 - kc); - a1 = (const int8_t*) ((uintptr_t) a1 - kc); - a2 = (const int8_t*) ((uintptr_t) a2 - kc); - a3 = (const int8_t*) ((uintptr_t) a3 - kc); - a4 = (const int8_t*) ((uintptr_t) a4 - kc); - a5 = (const int8_t*) ((uintptr_t) a5 - kc); - a6 = (const int8_t*) ((uintptr_t) a6 - kc); - c0 = (int8_t*) ((uintptr_t) c0 + cn_stride); + a0 = (const int8_t*) ((uintptr_t) a0 - kc); + _mm_storeu_si128((__m128i*) c1, vout1x0123456789ABCDEF); c1 = (int8_t*) ((uintptr_t) c1 + cn_stride); + a1 = (const int8_t*) ((uintptr_t) a1 - kc); + _mm_storeu_si128((__m128i*) c2, vout2x0123456789ABCDEF); c2 = (int8_t*) ((uintptr_t) c2 + cn_stride); + a2 = (const int8_t*) ((uintptr_t) a2 - kc); + _mm_storeu_si128((__m128i*) c3, vout3x0123456789ABCDEF); c3 = (int8_t*) ((uintptr_t) c3 + cn_stride); + a3 = (const int8_t*) ((uintptr_t) a3 - kc); + _mm_storeu_si128((__m128i*) c4, vout4x0123456789ABCDEF); c4 = (int8_t*) ((uintptr_t) c4 + cn_stride); + a4 = (const int8_t*) ((uintptr_t) a4 - kc); + _mm_storeu_si128((__m128i*) c5, vout5x0123456789ABCDEF); c5 = (int8_t*) ((uintptr_t) c5 + cn_stride); + a5 = (const int8_t*) ((uintptr_t) a5 - kc); + _mm_storeu_si128((__m128i*) c6, vout6x0123456789ABCDEF); c6 = (int8_t*) ((uintptr_t) c6 + cn_stride); + a6 = (const int8_t*) ((uintptr_t) a6 - kc); nc -= 16; } else { diff --git a/src/qs8-gemm/gen/qs8-gemm-7x16c8-minmax-fp32-avx512skx-prfm.c b/src/qs8-gemm/gen/qs8-gemm-7x16c8-minmax-fp32-avx512skx-prfm.c index cabc05176c7..f467679ba2a 100644 --- a/src/qs8-gemm/gen/qs8-gemm-7x16c8-minmax-fp32-avx512skx-prfm.c +++ b/src/qs8-gemm/gen/qs8-gemm-7x16c8-minmax-fp32-avx512skx-prfm.c @@ -82,7 +82,6 @@ void xnn_qs8_gemm_minmax_fp32_ukernel_7x16c8__avx512skx_prfm( const __m512 voutput_max_less_zero_point = _mm512_load_ps(params->fp32_avx512.output_max_less_zero_point); const __m256i voutput_zero_point = _mm256_load_si256((const __m256i*) params->fp32_avx512.output_zero_point); const __m128i voutput_min = _mm_load_si128((const __m128i*) params->fp32_avx512.output_min); - const __m128i vshuffle_control_mask = _mm_set_epi8(15, 14, 13, 12, 7, 6, 5, 4, 11, 10, 9, 8, 3, 2, 1, 0); do { __m512i vacc0x0123 = _mm512_maskz_expandloadu_epi32(vbias_mask, w); @@ -241,37 +240,37 @@ void xnn_qs8_gemm_minmax_fp32_ukernel_7x16c8__avx512skx_prfm( vacc5x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled5x0123456789ABCDEF); vacc6x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled6x0123456789ABCDEF); - __m256i vacc0x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); - __m256i vacc1x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); - __m256i vacc2x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc2x0123456789ABCDEF, 1)); - __m256i vacc3x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc3x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc3x0123456789ABCDEF, 1)); - __m256i vacc4x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc4x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc4x0123456789ABCDEF, 1)); - __m256i vacc5x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc5x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc5x0123456789ABCDEF, 1)); - __m256i vacc6x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc6x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc6x0123456789ABCDEF, 1)); - - vacc0x0123456789AB4567CDEF = _mm256_adds_epi16(vacc0x0123456789AB4567CDEF, voutput_zero_point); - vacc1x0123456789AB4567CDEF = _mm256_adds_epi16(vacc1x0123456789AB4567CDEF, voutput_zero_point); - vacc2x0123456789AB4567CDEF = _mm256_adds_epi16(vacc2x0123456789AB4567CDEF, voutput_zero_point); - vacc3x0123456789AB4567CDEF = _mm256_adds_epi16(vacc3x0123456789AB4567CDEF, voutput_zero_point); - vacc4x0123456789AB4567CDEF = _mm256_adds_epi16(vacc4x0123456789AB4567CDEF, voutput_zero_point); - vacc5x0123456789AB4567CDEF = _mm256_adds_epi16(vacc5x0123456789AB4567CDEF, voutput_zero_point); - vacc6x0123456789AB4567CDEF = _mm256_adds_epi16(vacc6x0123456789AB4567CDEF, voutput_zero_point); - - const __m128i vout0x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc0x0123456789AB4567CDEF, 1)); - const __m128i vout1x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc1x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc1x0123456789AB4567CDEF, 1)); - const __m128i vout2x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc2x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc2x0123456789AB4567CDEF, 1)); - const __m128i vout3x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc3x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc3x0123456789AB4567CDEF, 1)); - const __m128i vout4x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc4x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc4x0123456789AB4567CDEF, 1)); - const __m128i vout5x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc5x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc5x0123456789AB4567CDEF, 1)); - const __m128i vout6x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc6x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc6x0123456789AB4567CDEF, 1)); - - __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi8(vout0x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi8(vout1x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi8(vout2x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout3x0123456789ABCDEF = _mm_shuffle_epi8(vout3x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout4x0123456789ABCDEF = _mm_shuffle_epi8(vout4x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout5x0123456789ABCDEF = _mm_shuffle_epi8(vout5x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout6x0123456789ABCDEF = _mm_shuffle_epi8(vout6x0123456789AB4567CDEF, vshuffle_control_mask); + __m256i vacc0x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); + __m256i vacc1x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); + __m256i vacc2x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc2x0123456789ABCDEF, 1)); + __m256i vacc3x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc3x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc3x0123456789ABCDEF, 1)); + __m256i vacc4x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc4x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc4x0123456789ABCDEF, 1)); + __m256i vacc5x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc5x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc5x0123456789ABCDEF, 1)); + __m256i vacc6x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc6x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc6x0123456789ABCDEF, 1)); + + vacc0x012389AB4567CDEF = _mm256_adds_epi16(vacc0x012389AB4567CDEF, voutput_zero_point); + vacc1x012389AB4567CDEF = _mm256_adds_epi16(vacc1x012389AB4567CDEF, voutput_zero_point); + vacc2x012389AB4567CDEF = _mm256_adds_epi16(vacc2x012389AB4567CDEF, voutput_zero_point); + vacc3x012389AB4567CDEF = _mm256_adds_epi16(vacc3x012389AB4567CDEF, voutput_zero_point); + vacc4x012389AB4567CDEF = _mm256_adds_epi16(vacc4x012389AB4567CDEF, voutput_zero_point); + vacc5x012389AB4567CDEF = _mm256_adds_epi16(vacc5x012389AB4567CDEF, voutput_zero_point); + vacc6x012389AB4567CDEF = _mm256_adds_epi16(vacc6x012389AB4567CDEF, voutput_zero_point); + + const __m128i vout0x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x012389AB4567CDEF, 1)); + const __m128i vout1x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc1x012389AB4567CDEF), _mm256_extracti128_si256(vacc1x012389AB4567CDEF, 1)); + const __m128i vout2x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc2x012389AB4567CDEF), _mm256_extracti128_si256(vacc2x012389AB4567CDEF, 1)); + const __m128i vout3x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc3x012389AB4567CDEF), _mm256_extracti128_si256(vacc3x012389AB4567CDEF, 1)); + const __m128i vout4x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc4x012389AB4567CDEF), _mm256_extracti128_si256(vacc4x012389AB4567CDEF, 1)); + const __m128i vout5x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc5x012389AB4567CDEF), _mm256_extracti128_si256(vacc5x012389AB4567CDEF, 1)); + const __m128i vout6x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc6x012389AB4567CDEF), _mm256_extracti128_si256(vacc6x012389AB4567CDEF, 1)); + + __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi32(vout0x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi32(vout1x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi32(vout2x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout3x0123456789ABCDEF = _mm_shuffle_epi32(vout3x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout4x0123456789ABCDEF = _mm_shuffle_epi32(vout4x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout5x0123456789ABCDEF = _mm_shuffle_epi32(vout5x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout6x0123456789ABCDEF = _mm_shuffle_epi32(vout6x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); vout0x0123456789ABCDEF = _mm_max_epi8(vout0x0123456789ABCDEF, voutput_min); vout1x0123456789ABCDEF = _mm_max_epi8(vout1x0123456789ABCDEF, voutput_min); @@ -283,28 +282,26 @@ void xnn_qs8_gemm_minmax_fp32_ukernel_7x16c8__avx512skx_prfm( if (nc >= 16) { _mm_storeu_si128((__m128i*) c0, vout0x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c1, vout1x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c2, vout2x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c3, vout3x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c4, vout4x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c5, vout5x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c6, vout6x0123456789ABCDEF); - - a0 = (const int8_t*) ((uintptr_t) a0 - kc); - a1 = (const int8_t*) ((uintptr_t) a1 - kc); - a2 = (const int8_t*) ((uintptr_t) a2 - kc); - a3 = (const int8_t*) ((uintptr_t) a3 - kc); - a4 = (const int8_t*) ((uintptr_t) a4 - kc); - a5 = (const int8_t*) ((uintptr_t) a5 - kc); - a6 = (const int8_t*) ((uintptr_t) a6 - kc); - c0 = (int8_t*) ((uintptr_t) c0 + cn_stride); + a0 = (const int8_t*) ((uintptr_t) a0 - kc); + _mm_storeu_si128((__m128i*) c1, vout1x0123456789ABCDEF); c1 = (int8_t*) ((uintptr_t) c1 + cn_stride); + a1 = (const int8_t*) ((uintptr_t) a1 - kc); + _mm_storeu_si128((__m128i*) c2, vout2x0123456789ABCDEF); c2 = (int8_t*) ((uintptr_t) c2 + cn_stride); + a2 = (const int8_t*) ((uintptr_t) a2 - kc); + _mm_storeu_si128((__m128i*) c3, vout3x0123456789ABCDEF); c3 = (int8_t*) ((uintptr_t) c3 + cn_stride); + a3 = (const int8_t*) ((uintptr_t) a3 - kc); + _mm_storeu_si128((__m128i*) c4, vout4x0123456789ABCDEF); c4 = (int8_t*) ((uintptr_t) c4 + cn_stride); + a4 = (const int8_t*) ((uintptr_t) a4 - kc); + _mm_storeu_si128((__m128i*) c5, vout5x0123456789ABCDEF); c5 = (int8_t*) ((uintptr_t) c5 + cn_stride); + a5 = (const int8_t*) ((uintptr_t) a5 - kc); + _mm_storeu_si128((__m128i*) c6, vout6x0123456789ABCDEF); c6 = (int8_t*) ((uintptr_t) c6 + cn_stride); + a6 = (const int8_t*) ((uintptr_t) a6 - kc); nc -= 16; } else { diff --git a/src/qs8-gemm/gen/qs8-gemm-7x16c8-minmax-fp32-avx512skx.c b/src/qs8-gemm/gen/qs8-gemm-7x16c8-minmax-fp32-avx512skx.c index a880bc3c20a..971bc67e060 100644 --- a/src/qs8-gemm/gen/qs8-gemm-7x16c8-minmax-fp32-avx512skx.c +++ b/src/qs8-gemm/gen/qs8-gemm-7x16c8-minmax-fp32-avx512skx.c @@ -81,7 +81,6 @@ void xnn_qs8_gemm_minmax_fp32_ukernel_7x16c8__avx512skx( const __m512 voutput_max_less_zero_point = _mm512_load_ps(params->fp32_avx512.output_max_less_zero_point); const __m256i voutput_zero_point = _mm256_load_si256((const __m256i*) params->fp32_avx512.output_zero_point); const __m128i voutput_min = _mm_load_si128((const __m128i*) params->fp32_avx512.output_min); - const __m128i vshuffle_control_mask = _mm_set_epi8(15, 14, 13, 12, 7, 6, 5, 4, 11, 10, 9, 8, 3, 2, 1, 0); do { __m512i vacc0x0123 = _mm512_maskz_expandloadu_epi32(vbias_mask, w); @@ -238,37 +237,37 @@ void xnn_qs8_gemm_minmax_fp32_ukernel_7x16c8__avx512skx( vacc5x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled5x0123456789ABCDEF); vacc6x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled6x0123456789ABCDEF); - __m256i vacc0x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); - __m256i vacc1x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); - __m256i vacc2x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc2x0123456789ABCDEF, 1)); - __m256i vacc3x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc3x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc3x0123456789ABCDEF, 1)); - __m256i vacc4x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc4x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc4x0123456789ABCDEF, 1)); - __m256i vacc5x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc5x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc5x0123456789ABCDEF, 1)); - __m256i vacc6x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc6x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc6x0123456789ABCDEF, 1)); - - vacc0x0123456789AB4567CDEF = _mm256_adds_epi16(vacc0x0123456789AB4567CDEF, voutput_zero_point); - vacc1x0123456789AB4567CDEF = _mm256_adds_epi16(vacc1x0123456789AB4567CDEF, voutput_zero_point); - vacc2x0123456789AB4567CDEF = _mm256_adds_epi16(vacc2x0123456789AB4567CDEF, voutput_zero_point); - vacc3x0123456789AB4567CDEF = _mm256_adds_epi16(vacc3x0123456789AB4567CDEF, voutput_zero_point); - vacc4x0123456789AB4567CDEF = _mm256_adds_epi16(vacc4x0123456789AB4567CDEF, voutput_zero_point); - vacc5x0123456789AB4567CDEF = _mm256_adds_epi16(vacc5x0123456789AB4567CDEF, voutput_zero_point); - vacc6x0123456789AB4567CDEF = _mm256_adds_epi16(vacc6x0123456789AB4567CDEF, voutput_zero_point); - - const __m128i vout0x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc0x0123456789AB4567CDEF, 1)); - const __m128i vout1x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc1x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc1x0123456789AB4567CDEF, 1)); - const __m128i vout2x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc2x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc2x0123456789AB4567CDEF, 1)); - const __m128i vout3x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc3x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc3x0123456789AB4567CDEF, 1)); - const __m128i vout4x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc4x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc4x0123456789AB4567CDEF, 1)); - const __m128i vout5x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc5x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc5x0123456789AB4567CDEF, 1)); - const __m128i vout6x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc6x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc6x0123456789AB4567CDEF, 1)); - - __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi8(vout0x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi8(vout1x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi8(vout2x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout3x0123456789ABCDEF = _mm_shuffle_epi8(vout3x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout4x0123456789ABCDEF = _mm_shuffle_epi8(vout4x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout5x0123456789ABCDEF = _mm_shuffle_epi8(vout5x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout6x0123456789ABCDEF = _mm_shuffle_epi8(vout6x0123456789AB4567CDEF, vshuffle_control_mask); + __m256i vacc0x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); + __m256i vacc1x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); + __m256i vacc2x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc2x0123456789ABCDEF, 1)); + __m256i vacc3x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc3x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc3x0123456789ABCDEF, 1)); + __m256i vacc4x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc4x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc4x0123456789ABCDEF, 1)); + __m256i vacc5x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc5x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc5x0123456789ABCDEF, 1)); + __m256i vacc6x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc6x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc6x0123456789ABCDEF, 1)); + + vacc0x012389AB4567CDEF = _mm256_adds_epi16(vacc0x012389AB4567CDEF, voutput_zero_point); + vacc1x012389AB4567CDEF = _mm256_adds_epi16(vacc1x012389AB4567CDEF, voutput_zero_point); + vacc2x012389AB4567CDEF = _mm256_adds_epi16(vacc2x012389AB4567CDEF, voutput_zero_point); + vacc3x012389AB4567CDEF = _mm256_adds_epi16(vacc3x012389AB4567CDEF, voutput_zero_point); + vacc4x012389AB4567CDEF = _mm256_adds_epi16(vacc4x012389AB4567CDEF, voutput_zero_point); + vacc5x012389AB4567CDEF = _mm256_adds_epi16(vacc5x012389AB4567CDEF, voutput_zero_point); + vacc6x012389AB4567CDEF = _mm256_adds_epi16(vacc6x012389AB4567CDEF, voutput_zero_point); + + const __m128i vout0x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x012389AB4567CDEF, 1)); + const __m128i vout1x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc1x012389AB4567CDEF), _mm256_extracti128_si256(vacc1x012389AB4567CDEF, 1)); + const __m128i vout2x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc2x012389AB4567CDEF), _mm256_extracti128_si256(vacc2x012389AB4567CDEF, 1)); + const __m128i vout3x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc3x012389AB4567CDEF), _mm256_extracti128_si256(vacc3x012389AB4567CDEF, 1)); + const __m128i vout4x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc4x012389AB4567CDEF), _mm256_extracti128_si256(vacc4x012389AB4567CDEF, 1)); + const __m128i vout5x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc5x012389AB4567CDEF), _mm256_extracti128_si256(vacc5x012389AB4567CDEF, 1)); + const __m128i vout6x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc6x012389AB4567CDEF), _mm256_extracti128_si256(vacc6x012389AB4567CDEF, 1)); + + __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi32(vout0x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi32(vout1x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi32(vout2x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout3x0123456789ABCDEF = _mm_shuffle_epi32(vout3x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout4x0123456789ABCDEF = _mm_shuffle_epi32(vout4x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout5x0123456789ABCDEF = _mm_shuffle_epi32(vout5x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout6x0123456789ABCDEF = _mm_shuffle_epi32(vout6x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); vout0x0123456789ABCDEF = _mm_max_epi8(vout0x0123456789ABCDEF, voutput_min); vout1x0123456789ABCDEF = _mm_max_epi8(vout1x0123456789ABCDEF, voutput_min); @@ -280,28 +279,26 @@ void xnn_qs8_gemm_minmax_fp32_ukernel_7x16c8__avx512skx( if (nc >= 16) { _mm_storeu_si128((__m128i*) c0, vout0x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c1, vout1x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c2, vout2x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c3, vout3x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c4, vout4x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c5, vout5x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c6, vout6x0123456789ABCDEF); - - a0 = (const int8_t*) ((uintptr_t) a0 - kc); - a1 = (const int8_t*) ((uintptr_t) a1 - kc); - a2 = (const int8_t*) ((uintptr_t) a2 - kc); - a3 = (const int8_t*) ((uintptr_t) a3 - kc); - a4 = (const int8_t*) ((uintptr_t) a4 - kc); - a5 = (const int8_t*) ((uintptr_t) a5 - kc); - a6 = (const int8_t*) ((uintptr_t) a6 - kc); - c0 = (int8_t*) ((uintptr_t) c0 + cn_stride); + a0 = (const int8_t*) ((uintptr_t) a0 - kc); + _mm_storeu_si128((__m128i*) c1, vout1x0123456789ABCDEF); c1 = (int8_t*) ((uintptr_t) c1 + cn_stride); + a1 = (const int8_t*) ((uintptr_t) a1 - kc); + _mm_storeu_si128((__m128i*) c2, vout2x0123456789ABCDEF); c2 = (int8_t*) ((uintptr_t) c2 + cn_stride); + a2 = (const int8_t*) ((uintptr_t) a2 - kc); + _mm_storeu_si128((__m128i*) c3, vout3x0123456789ABCDEF); c3 = (int8_t*) ((uintptr_t) c3 + cn_stride); + a3 = (const int8_t*) ((uintptr_t) a3 - kc); + _mm_storeu_si128((__m128i*) c4, vout4x0123456789ABCDEF); c4 = (int8_t*) ((uintptr_t) c4 + cn_stride); + a4 = (const int8_t*) ((uintptr_t) a4 - kc); + _mm_storeu_si128((__m128i*) c5, vout5x0123456789ABCDEF); c5 = (int8_t*) ((uintptr_t) c5 + cn_stride); + a5 = (const int8_t*) ((uintptr_t) a5 - kc); + _mm_storeu_si128((__m128i*) c6, vout6x0123456789ABCDEF); c6 = (int8_t*) ((uintptr_t) c6 + cn_stride); + a6 = (const int8_t*) ((uintptr_t) a6 - kc); nc -= 16; } else { diff --git a/src/qs8-gemm/gen/qs8-gemm-7x16c8-minmax-fp32-avx512vnni-prfm.c b/src/qs8-gemm/gen/qs8-gemm-7x16c8-minmax-fp32-avx512vnni-prfm.c index 1f49e08d321..accaaf4eea4 100644 --- a/src/qs8-gemm/gen/qs8-gemm-7x16c8-minmax-fp32-avx512vnni-prfm.c +++ b/src/qs8-gemm/gen/qs8-gemm-7x16c8-minmax-fp32-avx512vnni-prfm.c @@ -84,7 +84,6 @@ void xnn_qs8_gemm_minmax_fp32_ukernel_7x16c8__avx512vnni_prfm( const __m512 voutput_max_less_zero_point = _mm512_load_ps(params->fp32_avx512vnni.output_max_less_zero_point); const __m256i voutput_zero_point = _mm256_load_si256((const __m256i*) params->fp32_avx512vnni.output_zero_point); const __m128i voutput_min = _mm_load_si128((const __m128i*) params->fp32_avx512vnni.output_min); - const __m128i vshuffle_control_mask = _mm_loadu_si128((const __m128i*) params->fp32_avx512vnni.shuffle_control_mask); do { __m512i vacc0x01234567 = _mm512_cvtepu32_epi64(_mm256_load_si256((const __m256i*) w)); __m512i vacc0x89ABCDEF = _mm512_cvtepu32_epi64(_mm256_load_si256((const __m256i*) ((const int32_t*) w + 8))); @@ -264,37 +263,37 @@ void xnn_qs8_gemm_minmax_fp32_ukernel_7x16c8__avx512vnni_prfm( vacc5x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled5x0123456789ABCDEF); vacc6x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled6x0123456789ABCDEF); - __m256i vacc0x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); - __m256i vacc1x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); - __m256i vacc2x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc2x0123456789ABCDEF, 1)); - __m256i vacc3x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc3x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc3x0123456789ABCDEF, 1)); - __m256i vacc4x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc4x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc4x0123456789ABCDEF, 1)); - __m256i vacc5x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc5x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc5x0123456789ABCDEF, 1)); - __m256i vacc6x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc6x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc6x0123456789ABCDEF, 1)); + __m256i vacc0x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); + __m256i vacc1x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); + __m256i vacc2x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc2x0123456789ABCDEF, 1)); + __m256i vacc3x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc3x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc3x0123456789ABCDEF, 1)); + __m256i vacc4x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc4x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc4x0123456789ABCDEF, 1)); + __m256i vacc5x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc5x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc5x0123456789ABCDEF, 1)); + __m256i vacc6x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc6x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc6x0123456789ABCDEF, 1)); - vacc0x0123456789AB4567CDEF = _mm256_adds_epi16(vacc0x0123456789AB4567CDEF, voutput_zero_point); - vacc1x0123456789AB4567CDEF = _mm256_adds_epi16(vacc1x0123456789AB4567CDEF, voutput_zero_point); - vacc2x0123456789AB4567CDEF = _mm256_adds_epi16(vacc2x0123456789AB4567CDEF, voutput_zero_point); - vacc3x0123456789AB4567CDEF = _mm256_adds_epi16(vacc3x0123456789AB4567CDEF, voutput_zero_point); - vacc4x0123456789AB4567CDEF = _mm256_adds_epi16(vacc4x0123456789AB4567CDEF, voutput_zero_point); - vacc5x0123456789AB4567CDEF = _mm256_adds_epi16(vacc5x0123456789AB4567CDEF, voutput_zero_point); - vacc6x0123456789AB4567CDEF = _mm256_adds_epi16(vacc6x0123456789AB4567CDEF, voutput_zero_point); + vacc0x012389AB4567CDEF = _mm256_adds_epi16(vacc0x012389AB4567CDEF, voutput_zero_point); + vacc1x012389AB4567CDEF = _mm256_adds_epi16(vacc1x012389AB4567CDEF, voutput_zero_point); + vacc2x012389AB4567CDEF = _mm256_adds_epi16(vacc2x012389AB4567CDEF, voutput_zero_point); + vacc3x012389AB4567CDEF = _mm256_adds_epi16(vacc3x012389AB4567CDEF, voutput_zero_point); + vacc4x012389AB4567CDEF = _mm256_adds_epi16(vacc4x012389AB4567CDEF, voutput_zero_point); + vacc5x012389AB4567CDEF = _mm256_adds_epi16(vacc5x012389AB4567CDEF, voutput_zero_point); + vacc6x012389AB4567CDEF = _mm256_adds_epi16(vacc6x012389AB4567CDEF, voutput_zero_point); - const __m128i vout0x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc0x0123456789AB4567CDEF, 1)); - const __m128i vout1x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc1x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc1x0123456789AB4567CDEF, 1)); - const __m128i vout2x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc2x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc2x0123456789AB4567CDEF, 1)); - const __m128i vout3x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc3x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc3x0123456789AB4567CDEF, 1)); - const __m128i vout4x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc4x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc4x0123456789AB4567CDEF, 1)); - const __m128i vout5x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc5x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc5x0123456789AB4567CDEF, 1)); - const __m128i vout6x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc6x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc6x0123456789AB4567CDEF, 1)); + const __m128i vout0x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x012389AB4567CDEF, 1)); + const __m128i vout1x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc1x012389AB4567CDEF), _mm256_extracti128_si256(vacc1x012389AB4567CDEF, 1)); + const __m128i vout2x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc2x012389AB4567CDEF), _mm256_extracti128_si256(vacc2x012389AB4567CDEF, 1)); + const __m128i vout3x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc3x012389AB4567CDEF), _mm256_extracti128_si256(vacc3x012389AB4567CDEF, 1)); + const __m128i vout4x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc4x012389AB4567CDEF), _mm256_extracti128_si256(vacc4x012389AB4567CDEF, 1)); + const __m128i vout5x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc5x012389AB4567CDEF), _mm256_extracti128_si256(vacc5x012389AB4567CDEF, 1)); + const __m128i vout6x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc6x012389AB4567CDEF), _mm256_extracti128_si256(vacc6x012389AB4567CDEF, 1)); - __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi8(vout0x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi8(vout1x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi8(vout2x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout3x0123456789ABCDEF = _mm_shuffle_epi8(vout3x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout4x0123456789ABCDEF = _mm_shuffle_epi8(vout4x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout5x0123456789ABCDEF = _mm_shuffle_epi8(vout5x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout6x0123456789ABCDEF = _mm_shuffle_epi8(vout6x0123456789AB4567CDEF, vshuffle_control_mask); + __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi32(vout0x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi32(vout1x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi32(vout2x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout3x0123456789ABCDEF = _mm_shuffle_epi32(vout3x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout4x0123456789ABCDEF = _mm_shuffle_epi32(vout4x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout5x0123456789ABCDEF = _mm_shuffle_epi32(vout5x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout6x0123456789ABCDEF = _mm_shuffle_epi32(vout6x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); vout0x0123456789ABCDEF = _mm_max_epi8(vout0x0123456789ABCDEF, voutput_min); vout1x0123456789ABCDEF = _mm_max_epi8(vout1x0123456789ABCDEF, voutput_min); @@ -306,28 +305,26 @@ void xnn_qs8_gemm_minmax_fp32_ukernel_7x16c8__avx512vnni_prfm( if (nc >= 16) { _mm_storeu_si128((__m128i*) c0, vout0x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c1, vout1x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c2, vout2x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c3, vout3x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c4, vout4x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c5, vout5x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c6, vout6x0123456789ABCDEF); - - a0 = (const int8_t*) ((uintptr_t) a0 - kc); - a1 = (const int8_t*) ((uintptr_t) a1 - kc); - a2 = (const int8_t*) ((uintptr_t) a2 - kc); - a3 = (const int8_t*) ((uintptr_t) a3 - kc); - a4 = (const int8_t*) ((uintptr_t) a4 - kc); - a5 = (const int8_t*) ((uintptr_t) a5 - kc); - a6 = (const int8_t*) ((uintptr_t) a6 - kc); - c0 = (int8_t*) ((uintptr_t) c0 + cn_stride); + a0 = (const int8_t*) ((uintptr_t) a0 - kc); + _mm_storeu_si128((__m128i*) c1, vout1x0123456789ABCDEF); c1 = (int8_t*) ((uintptr_t) c1 + cn_stride); + a1 = (const int8_t*) ((uintptr_t) a1 - kc); + _mm_storeu_si128((__m128i*) c2, vout2x0123456789ABCDEF); c2 = (int8_t*) ((uintptr_t) c2 + cn_stride); + a2 = (const int8_t*) ((uintptr_t) a2 - kc); + _mm_storeu_si128((__m128i*) c3, vout3x0123456789ABCDEF); c3 = (int8_t*) ((uintptr_t) c3 + cn_stride); + a3 = (const int8_t*) ((uintptr_t) a3 - kc); + _mm_storeu_si128((__m128i*) c4, vout4x0123456789ABCDEF); c4 = (int8_t*) ((uintptr_t) c4 + cn_stride); + a4 = (const int8_t*) ((uintptr_t) a4 - kc); + _mm_storeu_si128((__m128i*) c5, vout5x0123456789ABCDEF); c5 = (int8_t*) ((uintptr_t) c5 + cn_stride); + a5 = (const int8_t*) ((uintptr_t) a5 - kc); + _mm_storeu_si128((__m128i*) c6, vout6x0123456789ABCDEF); c6 = (int8_t*) ((uintptr_t) c6 + cn_stride); + a6 = (const int8_t*) ((uintptr_t) a6 - kc); nc -= 16; } else { diff --git a/src/qs8-gemm/gen/qs8-gemm-7x16c8-minmax-fp32-avx512vnni.c b/src/qs8-gemm/gen/qs8-gemm-7x16c8-minmax-fp32-avx512vnni.c index b079502541b..3c5a0bc6b6d 100644 --- a/src/qs8-gemm/gen/qs8-gemm-7x16c8-minmax-fp32-avx512vnni.c +++ b/src/qs8-gemm/gen/qs8-gemm-7x16c8-minmax-fp32-avx512vnni.c @@ -83,7 +83,6 @@ void xnn_qs8_gemm_minmax_fp32_ukernel_7x16c8__avx512vnni( const __m512 voutput_max_less_zero_point = _mm512_load_ps(params->fp32_avx512vnni.output_max_less_zero_point); const __m256i voutput_zero_point = _mm256_load_si256((const __m256i*) params->fp32_avx512vnni.output_zero_point); const __m128i voutput_min = _mm_load_si128((const __m128i*) params->fp32_avx512vnni.output_min); - const __m128i vshuffle_control_mask = _mm_loadu_si128((const __m128i*) params->fp32_avx512vnni.shuffle_control_mask); do { __m512i vacc0x01234567 = _mm512_cvtepu32_epi64(_mm256_load_si256((const __m256i*) w)); __m512i vacc0x89ABCDEF = _mm512_cvtepu32_epi64(_mm256_load_si256((const __m256i*) ((const int32_t*) w + 8))); @@ -257,37 +256,37 @@ void xnn_qs8_gemm_minmax_fp32_ukernel_7x16c8__avx512vnni( vacc5x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled5x0123456789ABCDEF); vacc6x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled6x0123456789ABCDEF); - __m256i vacc0x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); - __m256i vacc1x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); - __m256i vacc2x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc2x0123456789ABCDEF, 1)); - __m256i vacc3x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc3x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc3x0123456789ABCDEF, 1)); - __m256i vacc4x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc4x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc4x0123456789ABCDEF, 1)); - __m256i vacc5x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc5x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc5x0123456789ABCDEF, 1)); - __m256i vacc6x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc6x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc6x0123456789ABCDEF, 1)); + __m256i vacc0x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); + __m256i vacc1x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); + __m256i vacc2x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc2x0123456789ABCDEF, 1)); + __m256i vacc3x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc3x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc3x0123456789ABCDEF, 1)); + __m256i vacc4x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc4x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc4x0123456789ABCDEF, 1)); + __m256i vacc5x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc5x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc5x0123456789ABCDEF, 1)); + __m256i vacc6x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc6x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc6x0123456789ABCDEF, 1)); - vacc0x0123456789AB4567CDEF = _mm256_adds_epi16(vacc0x0123456789AB4567CDEF, voutput_zero_point); - vacc1x0123456789AB4567CDEF = _mm256_adds_epi16(vacc1x0123456789AB4567CDEF, voutput_zero_point); - vacc2x0123456789AB4567CDEF = _mm256_adds_epi16(vacc2x0123456789AB4567CDEF, voutput_zero_point); - vacc3x0123456789AB4567CDEF = _mm256_adds_epi16(vacc3x0123456789AB4567CDEF, voutput_zero_point); - vacc4x0123456789AB4567CDEF = _mm256_adds_epi16(vacc4x0123456789AB4567CDEF, voutput_zero_point); - vacc5x0123456789AB4567CDEF = _mm256_adds_epi16(vacc5x0123456789AB4567CDEF, voutput_zero_point); - vacc6x0123456789AB4567CDEF = _mm256_adds_epi16(vacc6x0123456789AB4567CDEF, voutput_zero_point); + vacc0x012389AB4567CDEF = _mm256_adds_epi16(vacc0x012389AB4567CDEF, voutput_zero_point); + vacc1x012389AB4567CDEF = _mm256_adds_epi16(vacc1x012389AB4567CDEF, voutput_zero_point); + vacc2x012389AB4567CDEF = _mm256_adds_epi16(vacc2x012389AB4567CDEF, voutput_zero_point); + vacc3x012389AB4567CDEF = _mm256_adds_epi16(vacc3x012389AB4567CDEF, voutput_zero_point); + vacc4x012389AB4567CDEF = _mm256_adds_epi16(vacc4x012389AB4567CDEF, voutput_zero_point); + vacc5x012389AB4567CDEF = _mm256_adds_epi16(vacc5x012389AB4567CDEF, voutput_zero_point); + vacc6x012389AB4567CDEF = _mm256_adds_epi16(vacc6x012389AB4567CDEF, voutput_zero_point); - const __m128i vout0x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc0x0123456789AB4567CDEF, 1)); - const __m128i vout1x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc1x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc1x0123456789AB4567CDEF, 1)); - const __m128i vout2x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc2x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc2x0123456789AB4567CDEF, 1)); - const __m128i vout3x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc3x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc3x0123456789AB4567CDEF, 1)); - const __m128i vout4x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc4x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc4x0123456789AB4567CDEF, 1)); - const __m128i vout5x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc5x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc5x0123456789AB4567CDEF, 1)); - const __m128i vout6x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc6x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc6x0123456789AB4567CDEF, 1)); + const __m128i vout0x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x012389AB4567CDEF, 1)); + const __m128i vout1x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc1x012389AB4567CDEF), _mm256_extracti128_si256(vacc1x012389AB4567CDEF, 1)); + const __m128i vout2x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc2x012389AB4567CDEF), _mm256_extracti128_si256(vacc2x012389AB4567CDEF, 1)); + const __m128i vout3x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc3x012389AB4567CDEF), _mm256_extracti128_si256(vacc3x012389AB4567CDEF, 1)); + const __m128i vout4x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc4x012389AB4567CDEF), _mm256_extracti128_si256(vacc4x012389AB4567CDEF, 1)); + const __m128i vout5x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc5x012389AB4567CDEF), _mm256_extracti128_si256(vacc5x012389AB4567CDEF, 1)); + const __m128i vout6x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc6x012389AB4567CDEF), _mm256_extracti128_si256(vacc6x012389AB4567CDEF, 1)); - __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi8(vout0x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi8(vout1x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi8(vout2x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout3x0123456789ABCDEF = _mm_shuffle_epi8(vout3x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout4x0123456789ABCDEF = _mm_shuffle_epi8(vout4x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout5x0123456789ABCDEF = _mm_shuffle_epi8(vout5x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout6x0123456789ABCDEF = _mm_shuffle_epi8(vout6x0123456789AB4567CDEF, vshuffle_control_mask); + __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi32(vout0x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi32(vout1x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi32(vout2x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout3x0123456789ABCDEF = _mm_shuffle_epi32(vout3x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout4x0123456789ABCDEF = _mm_shuffle_epi32(vout4x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout5x0123456789ABCDEF = _mm_shuffle_epi32(vout5x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout6x0123456789ABCDEF = _mm_shuffle_epi32(vout6x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); vout0x0123456789ABCDEF = _mm_max_epi8(vout0x0123456789ABCDEF, voutput_min); vout1x0123456789ABCDEF = _mm_max_epi8(vout1x0123456789ABCDEF, voutput_min); @@ -299,28 +298,26 @@ void xnn_qs8_gemm_minmax_fp32_ukernel_7x16c8__avx512vnni( if (nc >= 16) { _mm_storeu_si128((__m128i*) c0, vout0x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c1, vout1x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c2, vout2x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c3, vout3x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c4, vout4x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c5, vout5x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c6, vout6x0123456789ABCDEF); - - a0 = (const int8_t*) ((uintptr_t) a0 - kc); - a1 = (const int8_t*) ((uintptr_t) a1 - kc); - a2 = (const int8_t*) ((uintptr_t) a2 - kc); - a3 = (const int8_t*) ((uintptr_t) a3 - kc); - a4 = (const int8_t*) ((uintptr_t) a4 - kc); - a5 = (const int8_t*) ((uintptr_t) a5 - kc); - a6 = (const int8_t*) ((uintptr_t) a6 - kc); - c0 = (int8_t*) ((uintptr_t) c0 + cn_stride); + a0 = (const int8_t*) ((uintptr_t) a0 - kc); + _mm_storeu_si128((__m128i*) c1, vout1x0123456789ABCDEF); c1 = (int8_t*) ((uintptr_t) c1 + cn_stride); + a1 = (const int8_t*) ((uintptr_t) a1 - kc); + _mm_storeu_si128((__m128i*) c2, vout2x0123456789ABCDEF); c2 = (int8_t*) ((uintptr_t) c2 + cn_stride); + a2 = (const int8_t*) ((uintptr_t) a2 - kc); + _mm_storeu_si128((__m128i*) c3, vout3x0123456789ABCDEF); c3 = (int8_t*) ((uintptr_t) c3 + cn_stride); + a3 = (const int8_t*) ((uintptr_t) a3 - kc); + _mm_storeu_si128((__m128i*) c4, vout4x0123456789ABCDEF); c4 = (int8_t*) ((uintptr_t) c4 + cn_stride); + a4 = (const int8_t*) ((uintptr_t) a4 - kc); + _mm_storeu_si128((__m128i*) c5, vout5x0123456789ABCDEF); c5 = (int8_t*) ((uintptr_t) c5 + cn_stride); + a5 = (const int8_t*) ((uintptr_t) a5 - kc); + _mm_storeu_si128((__m128i*) c6, vout6x0123456789ABCDEF); c6 = (int8_t*) ((uintptr_t) c6 + cn_stride); + a6 = (const int8_t*) ((uintptr_t) a6 - kc); nc -= 16; } else { diff --git a/src/qs8-gemm/gen/qs8-gemm-8x16c4-minmax-fp32-avx512vnni-prfm.c b/src/qs8-gemm/gen/qs8-gemm-8x16c4-minmax-fp32-avx512vnni-prfm.c index c4301a13087..39b039a012d 100644 --- a/src/qs8-gemm/gen/qs8-gemm-8x16c4-minmax-fp32-avx512vnni-prfm.c +++ b/src/qs8-gemm/gen/qs8-gemm-8x16c4-minmax-fp32-avx512vnni-prfm.c @@ -90,7 +90,6 @@ void xnn_qs8_gemm_minmax_fp32_ukernel_8x16c4__avx512vnni_prfm( const __m512 voutput_max_less_zero_point = _mm512_load_ps(params->fp32_avx512vnni.output_max_less_zero_point); const __m256i voutput_zero_point = _mm256_load_si256((const __m256i*) params->fp32_avx512vnni.output_zero_point); const __m128i voutput_min = _mm_load_si128((const __m128i*) params->fp32_avx512vnni.output_min); - const __m128i vshuffle_control_mask = _mm_loadu_si128((const __m128i*) params->fp32_avx512vnni.shuffle_control_mask); do { __m512i vacc0x0x0123456789ABCDEF = _mm512_load_epi32(w); __m512i vacc1x0x0123456789ABCDEF = _mm512_setzero_epi32(); @@ -162,14 +161,14 @@ void xnn_qs8_gemm_minmax_fp32_ukernel_8x16c4__avx512vnni_prfm( w = (const int8_t*) w + 128; k -= 8 * sizeof(int8_t); } - vacc0x0x0123456789ABCDEF = _mm512_add_epi32(vacc0x0x0123456789ABCDEF, vacc1x0x0123456789ABCDEF); - vacc0x1x0123456789ABCDEF = _mm512_add_epi32(vacc0x1x0123456789ABCDEF, vacc1x1x0123456789ABCDEF); - vacc0x2x0123456789ABCDEF = _mm512_add_epi32(vacc0x2x0123456789ABCDEF, vacc1x2x0123456789ABCDEF); - vacc0x3x0123456789ABCDEF = _mm512_add_epi32(vacc0x3x0123456789ABCDEF, vacc1x3x0123456789ABCDEF); - vacc0x4x0123456789ABCDEF = _mm512_add_epi32(vacc0x4x0123456789ABCDEF, vacc1x4x0123456789ABCDEF); - vacc0x5x0123456789ABCDEF = _mm512_add_epi32(vacc0x5x0123456789ABCDEF, vacc1x5x0123456789ABCDEF); - vacc0x6x0123456789ABCDEF = _mm512_add_epi32(vacc0x6x0123456789ABCDEF, vacc1x6x0123456789ABCDEF); - vacc0x7x0123456789ABCDEF = _mm512_add_epi32(vacc0x7x0123456789ABCDEF, vacc1x7x0123456789ABCDEF); + __m512i vacc0x0123456789ABCDEF = _mm512_add_epi32(vacc0x0x0123456789ABCDEF, vacc1x0x0123456789ABCDEF); + __m512i vacc1x0123456789ABCDEF = _mm512_add_epi32(vacc0x1x0123456789ABCDEF, vacc1x1x0123456789ABCDEF); + __m512i vacc2x0123456789ABCDEF = _mm512_add_epi32(vacc0x2x0123456789ABCDEF, vacc1x2x0123456789ABCDEF); + __m512i vacc3x0123456789ABCDEF = _mm512_add_epi32(vacc0x3x0123456789ABCDEF, vacc1x3x0123456789ABCDEF); + __m512i vacc4x0123456789ABCDEF = _mm512_add_epi32(vacc0x4x0123456789ABCDEF, vacc1x4x0123456789ABCDEF); + __m512i vacc5x0123456789ABCDEF = _mm512_add_epi32(vacc0x5x0123456789ABCDEF, vacc1x5x0123456789ABCDEF); + __m512i vacc6x0123456789ABCDEF = _mm512_add_epi32(vacc0x6x0123456789ABCDEF, vacc1x6x0123456789ABCDEF); + __m512i vacc7x0123456789ABCDEF = _mm512_add_epi32(vacc0x7x0123456789ABCDEF, vacc1x7x0123456789ABCDEF); if (k != 0) { const __m512i va0x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a0)), vsign_mask); @@ -192,27 +191,27 @@ void xnn_qs8_gemm_minmax_fp32_ukernel_8x16c4__avx512vnni_prfm( const __m512i vb0123456789ABCDEF = _mm512_load_si512(w); xnn_prefetch_to_l1((const int8_t*) w + 960); - vacc0x0x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x0x0123456789ABCDEF, va0x0123, vb0123456789ABCDEF); - vacc0x1x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x1x0123456789ABCDEF, va1x0123, vb0123456789ABCDEF); - vacc0x2x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x2x0123456789ABCDEF, va2x0123, vb0123456789ABCDEF); - vacc0x3x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x3x0123456789ABCDEF, va3x0123, vb0123456789ABCDEF); - vacc0x4x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x4x0123456789ABCDEF, va4x0123, vb0123456789ABCDEF); - vacc0x5x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x5x0123456789ABCDEF, va5x0123, vb0123456789ABCDEF); - vacc0x6x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x6x0123456789ABCDEF, va6x0123, vb0123456789ABCDEF); - vacc0x7x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x7x0123456789ABCDEF, va7x0123, vb0123456789ABCDEF); + vacc0x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x0123456789ABCDEF, va0x0123, vb0123456789ABCDEF); + vacc1x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc1x0123456789ABCDEF, va1x0123, vb0123456789ABCDEF); + vacc2x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc2x0123456789ABCDEF, va2x0123, vb0123456789ABCDEF); + vacc3x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc3x0123456789ABCDEF, va3x0123, vb0123456789ABCDEF); + vacc4x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc4x0123456789ABCDEF, va4x0123, vb0123456789ABCDEF); + vacc5x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc5x0123456789ABCDEF, va5x0123, vb0123456789ABCDEF); + vacc6x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc6x0123456789ABCDEF, va6x0123, vb0123456789ABCDEF); + vacc7x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc7x0123456789ABCDEF, va7x0123, vb0123456789ABCDEF); w = (const int8_t*) w + 64; k -= 4 * sizeof(int8_t); } - __m512 vscaled0x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x0x0123456789ABCDEF); - __m512 vscaled1x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x1x0123456789ABCDEF); - __m512 vscaled2x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x2x0123456789ABCDEF); - __m512 vscaled3x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x3x0123456789ABCDEF); - __m512 vscaled4x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x4x0123456789ABCDEF); - __m512 vscaled5x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x5x0123456789ABCDEF); - __m512 vscaled6x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x6x0123456789ABCDEF); - __m512 vscaled7x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x7x0123456789ABCDEF); + __m512 vscaled0x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x0123456789ABCDEF); + __m512 vscaled1x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc1x0123456789ABCDEF); + __m512 vscaled2x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc2x0123456789ABCDEF); + __m512 vscaled3x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc3x0123456789ABCDEF); + __m512 vscaled4x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc4x0123456789ABCDEF); + __m512 vscaled5x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc5x0123456789ABCDEF); + __m512 vscaled6x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc6x0123456789ABCDEF); + __m512 vscaled7x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc7x0123456789ABCDEF); vscaled0x0123456789ABCDEF = _mm512_mul_ps(vscaled0x0123456789ABCDEF, vscale); vscaled1x0123456789ABCDEF = _mm512_mul_ps(vscaled1x0123456789ABCDEF, vscale); @@ -232,50 +231,50 @@ void xnn_qs8_gemm_minmax_fp32_ukernel_8x16c4__avx512vnni_prfm( vscaled6x0123456789ABCDEF = _mm512_min_ps(vscaled6x0123456789ABCDEF, voutput_max_less_zero_point); vscaled7x0123456789ABCDEF = _mm512_min_ps(vscaled7x0123456789ABCDEF, voutput_max_less_zero_point); - vacc0x0x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled0x0123456789ABCDEF); - vacc0x1x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled1x0123456789ABCDEF); - vacc0x2x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled2x0123456789ABCDEF); - vacc0x3x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled3x0123456789ABCDEF); - vacc0x4x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled4x0123456789ABCDEF); - vacc0x5x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled5x0123456789ABCDEF); - vacc0x6x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled6x0123456789ABCDEF); - vacc0x7x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled7x0123456789ABCDEF); + vacc0x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled0x0123456789ABCDEF); + vacc1x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled1x0123456789ABCDEF); + vacc2x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled2x0123456789ABCDEF); + vacc3x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled3x0123456789ABCDEF); + vacc4x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled4x0123456789ABCDEF); + vacc5x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled5x0123456789ABCDEF); + vacc6x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled6x0123456789ABCDEF); + vacc7x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled7x0123456789ABCDEF); - __m256i vacc0x0x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0x0123456789ABCDEF, 1)); - __m256i vacc0x1x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x1x0123456789ABCDEF, 1)); - __m256i vacc0x2x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x2x0123456789ABCDEF, 1)); - __m256i vacc0x3x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x3x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x3x0123456789ABCDEF, 1)); - __m256i vacc0x4x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x4x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x4x0123456789ABCDEF, 1)); - __m256i vacc0x5x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x5x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x5x0123456789ABCDEF, 1)); - __m256i vacc0x6x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x6x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x6x0123456789ABCDEF, 1)); - __m256i vacc0x7x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x7x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x7x0123456789ABCDEF, 1)); + __m256i vacc0x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); + __m256i vacc1x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); + __m256i vacc2x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc2x0123456789ABCDEF, 1)); + __m256i vacc3x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc3x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc3x0123456789ABCDEF, 1)); + __m256i vacc4x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc4x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc4x0123456789ABCDEF, 1)); + __m256i vacc5x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc5x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc5x0123456789ABCDEF, 1)); + __m256i vacc6x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc6x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc6x0123456789ABCDEF, 1)); + __m256i vacc7x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc7x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc7x0123456789ABCDEF, 1)); - vacc0x0x012389AB4567CDEF = _mm256_adds_epi16(vacc0x0x012389AB4567CDEF, voutput_zero_point); - vacc0x1x012389AB4567CDEF = _mm256_adds_epi16(vacc0x1x012389AB4567CDEF, voutput_zero_point); - vacc0x2x012389AB4567CDEF = _mm256_adds_epi16(vacc0x2x012389AB4567CDEF, voutput_zero_point); - vacc0x3x012389AB4567CDEF = _mm256_adds_epi16(vacc0x3x012389AB4567CDEF, voutput_zero_point); - vacc0x4x012389AB4567CDEF = _mm256_adds_epi16(vacc0x4x012389AB4567CDEF, voutput_zero_point); - vacc0x5x012389AB4567CDEF = _mm256_adds_epi16(vacc0x5x012389AB4567CDEF, voutput_zero_point); - vacc0x6x012389AB4567CDEF = _mm256_adds_epi16(vacc0x6x012389AB4567CDEF, voutput_zero_point); - vacc0x7x012389AB4567CDEF = _mm256_adds_epi16(vacc0x7x012389AB4567CDEF, voutput_zero_point); + vacc0x012389AB4567CDEF = _mm256_adds_epi16(vacc0x012389AB4567CDEF, voutput_zero_point); + vacc1x012389AB4567CDEF = _mm256_adds_epi16(vacc1x012389AB4567CDEF, voutput_zero_point); + vacc2x012389AB4567CDEF = _mm256_adds_epi16(vacc2x012389AB4567CDEF, voutput_zero_point); + vacc3x012389AB4567CDEF = _mm256_adds_epi16(vacc3x012389AB4567CDEF, voutput_zero_point); + vacc4x012389AB4567CDEF = _mm256_adds_epi16(vacc4x012389AB4567CDEF, voutput_zero_point); + vacc5x012389AB4567CDEF = _mm256_adds_epi16(vacc5x012389AB4567CDEF, voutput_zero_point); + vacc6x012389AB4567CDEF = _mm256_adds_epi16(vacc6x012389AB4567CDEF, voutput_zero_point); + vacc7x012389AB4567CDEF = _mm256_adds_epi16(vacc7x012389AB4567CDEF, voutput_zero_point); - const __m128i vout0x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x0x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x0x012389AB4567CDEF, 1)); - const __m128i vout1x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x1x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x1x012389AB4567CDEF, 1)); - const __m128i vout2x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x2x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x2x012389AB4567CDEF, 1)); - const __m128i vout3x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x3x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x3x012389AB4567CDEF, 1)); - const __m128i vout4x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x4x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x4x012389AB4567CDEF, 1)); - const __m128i vout5x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x5x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x5x012389AB4567CDEF, 1)); - const __m128i vout6x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x6x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x6x012389AB4567CDEF, 1)); - const __m128i vout7x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x7x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x7x012389AB4567CDEF, 1)); + const __m128i vout0x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x012389AB4567CDEF, 1)); + const __m128i vout1x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc1x012389AB4567CDEF), _mm256_extracti128_si256(vacc1x012389AB4567CDEF, 1)); + const __m128i vout2x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc2x012389AB4567CDEF), _mm256_extracti128_si256(vacc2x012389AB4567CDEF, 1)); + const __m128i vout3x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc3x012389AB4567CDEF), _mm256_extracti128_si256(vacc3x012389AB4567CDEF, 1)); + const __m128i vout4x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc4x012389AB4567CDEF), _mm256_extracti128_si256(vacc4x012389AB4567CDEF, 1)); + const __m128i vout5x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc5x012389AB4567CDEF), _mm256_extracti128_si256(vacc5x012389AB4567CDEF, 1)); + const __m128i vout6x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc6x012389AB4567CDEF), _mm256_extracti128_si256(vacc6x012389AB4567CDEF, 1)); + const __m128i vout7x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc7x012389AB4567CDEF), _mm256_extracti128_si256(vacc7x012389AB4567CDEF, 1)); - __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi8(vout0x012389AB4567CDEF, vshuffle_control_mask); - __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi8(vout1x012389AB4567CDEF, vshuffle_control_mask); - __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi8(vout2x012389AB4567CDEF, vshuffle_control_mask); - __m128i vout3x0123456789ABCDEF = _mm_shuffle_epi8(vout3x012389AB4567CDEF, vshuffle_control_mask); - __m128i vout4x0123456789ABCDEF = _mm_shuffle_epi8(vout4x012389AB4567CDEF, vshuffle_control_mask); - __m128i vout5x0123456789ABCDEF = _mm_shuffle_epi8(vout5x012389AB4567CDEF, vshuffle_control_mask); - __m128i vout6x0123456789ABCDEF = _mm_shuffle_epi8(vout6x012389AB4567CDEF, vshuffle_control_mask); - __m128i vout7x0123456789ABCDEF = _mm_shuffle_epi8(vout7x012389AB4567CDEF, vshuffle_control_mask); + __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi32(vout0x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi32(vout1x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi32(vout2x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout3x0123456789ABCDEF = _mm_shuffle_epi32(vout3x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout4x0123456789ABCDEF = _mm_shuffle_epi32(vout4x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout5x0123456789ABCDEF = _mm_shuffle_epi32(vout5x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout6x0123456789ABCDEF = _mm_shuffle_epi32(vout6x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout7x0123456789ABCDEF = _mm_shuffle_epi32(vout7x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); vout0x0123456789ABCDEF = _mm_max_epi8(vout0x0123456789ABCDEF, voutput_min); vout1x0123456789ABCDEF = _mm_max_epi8(vout1x0123456789ABCDEF, voutput_min); @@ -288,31 +287,29 @@ void xnn_qs8_gemm_minmax_fp32_ukernel_8x16c4__avx512vnni_prfm( if (nc >= 16) { _mm_storeu_si128((__m128i*) c0, vout0x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c1, vout1x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c2, vout2x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c3, vout3x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c4, vout4x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c5, vout5x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c6, vout6x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c7, vout7x0123456789ABCDEF); - - a0 = (const int8_t*) ((uintptr_t) a0 - kc); - a1 = (const int8_t*) ((uintptr_t) a1 - kc); - a2 = (const int8_t*) ((uintptr_t) a2 - kc); - a3 = (const int8_t*) ((uintptr_t) a3 - kc); - a4 = (const int8_t*) ((uintptr_t) a4 - kc); - a5 = (const int8_t*) ((uintptr_t) a5 - kc); - a6 = (const int8_t*) ((uintptr_t) a6 - kc); - a7 = (const int8_t*) ((uintptr_t) a7 - kc); - c0 = (int8_t*) ((uintptr_t) c0 + cn_stride); + a0 = (const int8_t*) ((uintptr_t) a0 - kc); + _mm_storeu_si128((__m128i*) c1, vout1x0123456789ABCDEF); c1 = (int8_t*) ((uintptr_t) c1 + cn_stride); + a1 = (const int8_t*) ((uintptr_t) a1 - kc); + _mm_storeu_si128((__m128i*) c2, vout2x0123456789ABCDEF); c2 = (int8_t*) ((uintptr_t) c2 + cn_stride); + a2 = (const int8_t*) ((uintptr_t) a2 - kc); + _mm_storeu_si128((__m128i*) c3, vout3x0123456789ABCDEF); c3 = (int8_t*) ((uintptr_t) c3 + cn_stride); + a3 = (const int8_t*) ((uintptr_t) a3 - kc); + _mm_storeu_si128((__m128i*) c4, vout4x0123456789ABCDEF); c4 = (int8_t*) ((uintptr_t) c4 + cn_stride); + a4 = (const int8_t*) ((uintptr_t) a4 - kc); + _mm_storeu_si128((__m128i*) c5, vout5x0123456789ABCDEF); c5 = (int8_t*) ((uintptr_t) c5 + cn_stride); + a5 = (const int8_t*) ((uintptr_t) a5 - kc); + _mm_storeu_si128((__m128i*) c6, vout6x0123456789ABCDEF); c6 = (int8_t*) ((uintptr_t) c6 + cn_stride); + a6 = (const int8_t*) ((uintptr_t) a6 - kc); + _mm_storeu_si128((__m128i*) c7, vout7x0123456789ABCDEF); c7 = (int8_t*) ((uintptr_t) c7 + cn_stride); + a7 = (const int8_t*) ((uintptr_t) a7 - kc); nc -= 16; } else { diff --git a/src/qs8-gemm/gen/qs8-gemm-8x16c4-minmax-fp32-avx512vnni.c b/src/qs8-gemm/gen/qs8-gemm-8x16c4-minmax-fp32-avx512vnni.c index edf4e2cacf2..8949d0e96c4 100644 --- a/src/qs8-gemm/gen/qs8-gemm-8x16c4-minmax-fp32-avx512vnni.c +++ b/src/qs8-gemm/gen/qs8-gemm-8x16c4-minmax-fp32-avx512vnni.c @@ -89,7 +89,6 @@ void xnn_qs8_gemm_minmax_fp32_ukernel_8x16c4__avx512vnni( const __m512 voutput_max_less_zero_point = _mm512_load_ps(params->fp32_avx512vnni.output_max_less_zero_point); const __m256i voutput_zero_point = _mm256_load_si256((const __m256i*) params->fp32_avx512vnni.output_zero_point); const __m128i voutput_min = _mm_load_si128((const __m128i*) params->fp32_avx512vnni.output_min); - const __m128i vshuffle_control_mask = _mm_loadu_si128((const __m128i*) params->fp32_avx512vnni.shuffle_control_mask); do { __m512i vacc0x0x0123456789ABCDEF = _mm512_load_epi32(w); __m512i vacc1x0x0123456789ABCDEF = _mm512_setzero_epi32(); @@ -159,14 +158,14 @@ void xnn_qs8_gemm_minmax_fp32_ukernel_8x16c4__avx512vnni( w = (const int8_t*) w + 128; k -= 8 * sizeof(int8_t); } - vacc0x0x0123456789ABCDEF = _mm512_add_epi32(vacc0x0x0123456789ABCDEF, vacc1x0x0123456789ABCDEF); - vacc0x1x0123456789ABCDEF = _mm512_add_epi32(vacc0x1x0123456789ABCDEF, vacc1x1x0123456789ABCDEF); - vacc0x2x0123456789ABCDEF = _mm512_add_epi32(vacc0x2x0123456789ABCDEF, vacc1x2x0123456789ABCDEF); - vacc0x3x0123456789ABCDEF = _mm512_add_epi32(vacc0x3x0123456789ABCDEF, vacc1x3x0123456789ABCDEF); - vacc0x4x0123456789ABCDEF = _mm512_add_epi32(vacc0x4x0123456789ABCDEF, vacc1x4x0123456789ABCDEF); - vacc0x5x0123456789ABCDEF = _mm512_add_epi32(vacc0x5x0123456789ABCDEF, vacc1x5x0123456789ABCDEF); - vacc0x6x0123456789ABCDEF = _mm512_add_epi32(vacc0x6x0123456789ABCDEF, vacc1x6x0123456789ABCDEF); - vacc0x7x0123456789ABCDEF = _mm512_add_epi32(vacc0x7x0123456789ABCDEF, vacc1x7x0123456789ABCDEF); + __m512i vacc0x0123456789ABCDEF = _mm512_add_epi32(vacc0x0x0123456789ABCDEF, vacc1x0x0123456789ABCDEF); + __m512i vacc1x0123456789ABCDEF = _mm512_add_epi32(vacc0x1x0123456789ABCDEF, vacc1x1x0123456789ABCDEF); + __m512i vacc2x0123456789ABCDEF = _mm512_add_epi32(vacc0x2x0123456789ABCDEF, vacc1x2x0123456789ABCDEF); + __m512i vacc3x0123456789ABCDEF = _mm512_add_epi32(vacc0x3x0123456789ABCDEF, vacc1x3x0123456789ABCDEF); + __m512i vacc4x0123456789ABCDEF = _mm512_add_epi32(vacc0x4x0123456789ABCDEF, vacc1x4x0123456789ABCDEF); + __m512i vacc5x0123456789ABCDEF = _mm512_add_epi32(vacc0x5x0123456789ABCDEF, vacc1x5x0123456789ABCDEF); + __m512i vacc6x0123456789ABCDEF = _mm512_add_epi32(vacc0x6x0123456789ABCDEF, vacc1x6x0123456789ABCDEF); + __m512i vacc7x0123456789ABCDEF = _mm512_add_epi32(vacc0x7x0123456789ABCDEF, vacc1x7x0123456789ABCDEF); if (k != 0) { const __m512i va0x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a0)), vsign_mask); @@ -188,27 +187,27 @@ void xnn_qs8_gemm_minmax_fp32_ukernel_8x16c4__avx512vnni( const __m512i vb0123456789ABCDEF = _mm512_load_si512(w); - vacc0x0x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x0x0123456789ABCDEF, va0x0123, vb0123456789ABCDEF); - vacc0x1x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x1x0123456789ABCDEF, va1x0123, vb0123456789ABCDEF); - vacc0x2x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x2x0123456789ABCDEF, va2x0123, vb0123456789ABCDEF); - vacc0x3x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x3x0123456789ABCDEF, va3x0123, vb0123456789ABCDEF); - vacc0x4x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x4x0123456789ABCDEF, va4x0123, vb0123456789ABCDEF); - vacc0x5x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x5x0123456789ABCDEF, va5x0123, vb0123456789ABCDEF); - vacc0x6x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x6x0123456789ABCDEF, va6x0123, vb0123456789ABCDEF); - vacc0x7x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x7x0123456789ABCDEF, va7x0123, vb0123456789ABCDEF); + vacc0x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x0123456789ABCDEF, va0x0123, vb0123456789ABCDEF); + vacc1x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc1x0123456789ABCDEF, va1x0123, vb0123456789ABCDEF); + vacc2x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc2x0123456789ABCDEF, va2x0123, vb0123456789ABCDEF); + vacc3x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc3x0123456789ABCDEF, va3x0123, vb0123456789ABCDEF); + vacc4x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc4x0123456789ABCDEF, va4x0123, vb0123456789ABCDEF); + vacc5x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc5x0123456789ABCDEF, va5x0123, vb0123456789ABCDEF); + vacc6x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc6x0123456789ABCDEF, va6x0123, vb0123456789ABCDEF); + vacc7x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc7x0123456789ABCDEF, va7x0123, vb0123456789ABCDEF); w = (const int8_t*) w + 64; k -= 4 * sizeof(int8_t); } - __m512 vscaled0x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x0x0123456789ABCDEF); - __m512 vscaled1x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x1x0123456789ABCDEF); - __m512 vscaled2x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x2x0123456789ABCDEF); - __m512 vscaled3x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x3x0123456789ABCDEF); - __m512 vscaled4x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x4x0123456789ABCDEF); - __m512 vscaled5x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x5x0123456789ABCDEF); - __m512 vscaled6x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x6x0123456789ABCDEF); - __m512 vscaled7x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x7x0123456789ABCDEF); + __m512 vscaled0x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x0123456789ABCDEF); + __m512 vscaled1x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc1x0123456789ABCDEF); + __m512 vscaled2x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc2x0123456789ABCDEF); + __m512 vscaled3x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc3x0123456789ABCDEF); + __m512 vscaled4x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc4x0123456789ABCDEF); + __m512 vscaled5x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc5x0123456789ABCDEF); + __m512 vscaled6x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc6x0123456789ABCDEF); + __m512 vscaled7x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc7x0123456789ABCDEF); vscaled0x0123456789ABCDEF = _mm512_mul_ps(vscaled0x0123456789ABCDEF, vscale); vscaled1x0123456789ABCDEF = _mm512_mul_ps(vscaled1x0123456789ABCDEF, vscale); @@ -228,50 +227,50 @@ void xnn_qs8_gemm_minmax_fp32_ukernel_8x16c4__avx512vnni( vscaled6x0123456789ABCDEF = _mm512_min_ps(vscaled6x0123456789ABCDEF, voutput_max_less_zero_point); vscaled7x0123456789ABCDEF = _mm512_min_ps(vscaled7x0123456789ABCDEF, voutput_max_less_zero_point); - vacc0x0x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled0x0123456789ABCDEF); - vacc0x1x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled1x0123456789ABCDEF); - vacc0x2x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled2x0123456789ABCDEF); - vacc0x3x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled3x0123456789ABCDEF); - vacc0x4x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled4x0123456789ABCDEF); - vacc0x5x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled5x0123456789ABCDEF); - vacc0x6x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled6x0123456789ABCDEF); - vacc0x7x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled7x0123456789ABCDEF); + vacc0x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled0x0123456789ABCDEF); + vacc1x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled1x0123456789ABCDEF); + vacc2x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled2x0123456789ABCDEF); + vacc3x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled3x0123456789ABCDEF); + vacc4x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled4x0123456789ABCDEF); + vacc5x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled5x0123456789ABCDEF); + vacc6x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled6x0123456789ABCDEF); + vacc7x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled7x0123456789ABCDEF); - __m256i vacc0x0x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0x0123456789ABCDEF, 1)); - __m256i vacc0x1x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x1x0123456789ABCDEF, 1)); - __m256i vacc0x2x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x2x0123456789ABCDEF, 1)); - __m256i vacc0x3x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x3x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x3x0123456789ABCDEF, 1)); - __m256i vacc0x4x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x4x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x4x0123456789ABCDEF, 1)); - __m256i vacc0x5x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x5x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x5x0123456789ABCDEF, 1)); - __m256i vacc0x6x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x6x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x6x0123456789ABCDEF, 1)); - __m256i vacc0x7x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x7x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x7x0123456789ABCDEF, 1)); + __m256i vacc0x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); + __m256i vacc1x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); + __m256i vacc2x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc2x0123456789ABCDEF, 1)); + __m256i vacc3x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc3x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc3x0123456789ABCDEF, 1)); + __m256i vacc4x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc4x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc4x0123456789ABCDEF, 1)); + __m256i vacc5x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc5x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc5x0123456789ABCDEF, 1)); + __m256i vacc6x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc6x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc6x0123456789ABCDEF, 1)); + __m256i vacc7x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc7x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc7x0123456789ABCDEF, 1)); - vacc0x0x012389AB4567CDEF = _mm256_adds_epi16(vacc0x0x012389AB4567CDEF, voutput_zero_point); - vacc0x1x012389AB4567CDEF = _mm256_adds_epi16(vacc0x1x012389AB4567CDEF, voutput_zero_point); - vacc0x2x012389AB4567CDEF = _mm256_adds_epi16(vacc0x2x012389AB4567CDEF, voutput_zero_point); - vacc0x3x012389AB4567CDEF = _mm256_adds_epi16(vacc0x3x012389AB4567CDEF, voutput_zero_point); - vacc0x4x012389AB4567CDEF = _mm256_adds_epi16(vacc0x4x012389AB4567CDEF, voutput_zero_point); - vacc0x5x012389AB4567CDEF = _mm256_adds_epi16(vacc0x5x012389AB4567CDEF, voutput_zero_point); - vacc0x6x012389AB4567CDEF = _mm256_adds_epi16(vacc0x6x012389AB4567CDEF, voutput_zero_point); - vacc0x7x012389AB4567CDEF = _mm256_adds_epi16(vacc0x7x012389AB4567CDEF, voutput_zero_point); + vacc0x012389AB4567CDEF = _mm256_adds_epi16(vacc0x012389AB4567CDEF, voutput_zero_point); + vacc1x012389AB4567CDEF = _mm256_adds_epi16(vacc1x012389AB4567CDEF, voutput_zero_point); + vacc2x012389AB4567CDEF = _mm256_adds_epi16(vacc2x012389AB4567CDEF, voutput_zero_point); + vacc3x012389AB4567CDEF = _mm256_adds_epi16(vacc3x012389AB4567CDEF, voutput_zero_point); + vacc4x012389AB4567CDEF = _mm256_adds_epi16(vacc4x012389AB4567CDEF, voutput_zero_point); + vacc5x012389AB4567CDEF = _mm256_adds_epi16(vacc5x012389AB4567CDEF, voutput_zero_point); + vacc6x012389AB4567CDEF = _mm256_adds_epi16(vacc6x012389AB4567CDEF, voutput_zero_point); + vacc7x012389AB4567CDEF = _mm256_adds_epi16(vacc7x012389AB4567CDEF, voutput_zero_point); - const __m128i vout0x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x0x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x0x012389AB4567CDEF, 1)); - const __m128i vout1x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x1x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x1x012389AB4567CDEF, 1)); - const __m128i vout2x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x2x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x2x012389AB4567CDEF, 1)); - const __m128i vout3x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x3x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x3x012389AB4567CDEF, 1)); - const __m128i vout4x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x4x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x4x012389AB4567CDEF, 1)); - const __m128i vout5x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x5x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x5x012389AB4567CDEF, 1)); - const __m128i vout6x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x6x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x6x012389AB4567CDEF, 1)); - const __m128i vout7x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x7x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x7x012389AB4567CDEF, 1)); + const __m128i vout0x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x012389AB4567CDEF, 1)); + const __m128i vout1x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc1x012389AB4567CDEF), _mm256_extracti128_si256(vacc1x012389AB4567CDEF, 1)); + const __m128i vout2x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc2x012389AB4567CDEF), _mm256_extracti128_si256(vacc2x012389AB4567CDEF, 1)); + const __m128i vout3x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc3x012389AB4567CDEF), _mm256_extracti128_si256(vacc3x012389AB4567CDEF, 1)); + const __m128i vout4x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc4x012389AB4567CDEF), _mm256_extracti128_si256(vacc4x012389AB4567CDEF, 1)); + const __m128i vout5x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc5x012389AB4567CDEF), _mm256_extracti128_si256(vacc5x012389AB4567CDEF, 1)); + const __m128i vout6x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc6x012389AB4567CDEF), _mm256_extracti128_si256(vacc6x012389AB4567CDEF, 1)); + const __m128i vout7x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc7x012389AB4567CDEF), _mm256_extracti128_si256(vacc7x012389AB4567CDEF, 1)); - __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi8(vout0x012389AB4567CDEF, vshuffle_control_mask); - __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi8(vout1x012389AB4567CDEF, vshuffle_control_mask); - __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi8(vout2x012389AB4567CDEF, vshuffle_control_mask); - __m128i vout3x0123456789ABCDEF = _mm_shuffle_epi8(vout3x012389AB4567CDEF, vshuffle_control_mask); - __m128i vout4x0123456789ABCDEF = _mm_shuffle_epi8(vout4x012389AB4567CDEF, vshuffle_control_mask); - __m128i vout5x0123456789ABCDEF = _mm_shuffle_epi8(vout5x012389AB4567CDEF, vshuffle_control_mask); - __m128i vout6x0123456789ABCDEF = _mm_shuffle_epi8(vout6x012389AB4567CDEF, vshuffle_control_mask); - __m128i vout7x0123456789ABCDEF = _mm_shuffle_epi8(vout7x012389AB4567CDEF, vshuffle_control_mask); + __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi32(vout0x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi32(vout1x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi32(vout2x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout3x0123456789ABCDEF = _mm_shuffle_epi32(vout3x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout4x0123456789ABCDEF = _mm_shuffle_epi32(vout4x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout5x0123456789ABCDEF = _mm_shuffle_epi32(vout5x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout6x0123456789ABCDEF = _mm_shuffle_epi32(vout6x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout7x0123456789ABCDEF = _mm_shuffle_epi32(vout7x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); vout0x0123456789ABCDEF = _mm_max_epi8(vout0x0123456789ABCDEF, voutput_min); vout1x0123456789ABCDEF = _mm_max_epi8(vout1x0123456789ABCDEF, voutput_min); @@ -284,31 +283,29 @@ void xnn_qs8_gemm_minmax_fp32_ukernel_8x16c4__avx512vnni( if (nc >= 16) { _mm_storeu_si128((__m128i*) c0, vout0x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c1, vout1x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c2, vout2x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c3, vout3x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c4, vout4x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c5, vout5x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c6, vout6x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c7, vout7x0123456789ABCDEF); - - a0 = (const int8_t*) ((uintptr_t) a0 - kc); - a1 = (const int8_t*) ((uintptr_t) a1 - kc); - a2 = (const int8_t*) ((uintptr_t) a2 - kc); - a3 = (const int8_t*) ((uintptr_t) a3 - kc); - a4 = (const int8_t*) ((uintptr_t) a4 - kc); - a5 = (const int8_t*) ((uintptr_t) a5 - kc); - a6 = (const int8_t*) ((uintptr_t) a6 - kc); - a7 = (const int8_t*) ((uintptr_t) a7 - kc); - c0 = (int8_t*) ((uintptr_t) c0 + cn_stride); + a0 = (const int8_t*) ((uintptr_t) a0 - kc); + _mm_storeu_si128((__m128i*) c1, vout1x0123456789ABCDEF); c1 = (int8_t*) ((uintptr_t) c1 + cn_stride); + a1 = (const int8_t*) ((uintptr_t) a1 - kc); + _mm_storeu_si128((__m128i*) c2, vout2x0123456789ABCDEF); c2 = (int8_t*) ((uintptr_t) c2 + cn_stride); + a2 = (const int8_t*) ((uintptr_t) a2 - kc); + _mm_storeu_si128((__m128i*) c3, vout3x0123456789ABCDEF); c3 = (int8_t*) ((uintptr_t) c3 + cn_stride); + a3 = (const int8_t*) ((uintptr_t) a3 - kc); + _mm_storeu_si128((__m128i*) c4, vout4x0123456789ABCDEF); c4 = (int8_t*) ((uintptr_t) c4 + cn_stride); + a4 = (const int8_t*) ((uintptr_t) a4 - kc); + _mm_storeu_si128((__m128i*) c5, vout5x0123456789ABCDEF); c5 = (int8_t*) ((uintptr_t) c5 + cn_stride); + a5 = (const int8_t*) ((uintptr_t) a5 - kc); + _mm_storeu_si128((__m128i*) c6, vout6x0123456789ABCDEF); c6 = (int8_t*) ((uintptr_t) c6 + cn_stride); + a6 = (const int8_t*) ((uintptr_t) a6 - kc); + _mm_storeu_si128((__m128i*) c7, vout7x0123456789ABCDEF); c7 = (int8_t*) ((uintptr_t) c7 + cn_stride); + a7 = (const int8_t*) ((uintptr_t) a7 - kc); nc -= 16; } else { diff --git a/src/qs8-gemm/gen/qs8-gemm-8x16c8-minmax-fp32-avx512skx-prfm.c b/src/qs8-gemm/gen/qs8-gemm-8x16c8-minmax-fp32-avx512skx-prfm.c index d884968d548..858dc905b43 100644 --- a/src/qs8-gemm/gen/qs8-gemm-8x16c8-minmax-fp32-avx512skx-prfm.c +++ b/src/qs8-gemm/gen/qs8-gemm-8x16c8-minmax-fp32-avx512skx-prfm.c @@ -88,7 +88,6 @@ void xnn_qs8_gemm_minmax_fp32_ukernel_8x16c8__avx512skx_prfm( const __m512 voutput_max_less_zero_point = _mm512_load_ps(params->fp32_avx512.output_max_less_zero_point); const __m256i voutput_zero_point = _mm256_load_si256((const __m256i*) params->fp32_avx512.output_zero_point); const __m128i voutput_min = _mm_load_si128((const __m128i*) params->fp32_avx512.output_min); - const __m128i vshuffle_control_mask = _mm_set_epi8(15, 14, 13, 12, 7, 6, 5, 4, 11, 10, 9, 8, 3, 2, 1, 0); do { __m512i vacc0x0123 = _mm512_maskz_expandloadu_epi32(vbias_mask, w); @@ -265,41 +264,41 @@ void xnn_qs8_gemm_minmax_fp32_ukernel_8x16c8__avx512skx_prfm( vacc6x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled6x0123456789ABCDEF); vacc7x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled7x0123456789ABCDEF); - __m256i vacc0x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); - __m256i vacc1x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); - __m256i vacc2x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc2x0123456789ABCDEF, 1)); - __m256i vacc3x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc3x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc3x0123456789ABCDEF, 1)); - __m256i vacc4x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc4x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc4x0123456789ABCDEF, 1)); - __m256i vacc5x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc5x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc5x0123456789ABCDEF, 1)); - __m256i vacc6x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc6x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc6x0123456789ABCDEF, 1)); - __m256i vacc7x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc7x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc7x0123456789ABCDEF, 1)); - - vacc0x0123456789AB4567CDEF = _mm256_adds_epi16(vacc0x0123456789AB4567CDEF, voutput_zero_point); - vacc1x0123456789AB4567CDEF = _mm256_adds_epi16(vacc1x0123456789AB4567CDEF, voutput_zero_point); - vacc2x0123456789AB4567CDEF = _mm256_adds_epi16(vacc2x0123456789AB4567CDEF, voutput_zero_point); - vacc3x0123456789AB4567CDEF = _mm256_adds_epi16(vacc3x0123456789AB4567CDEF, voutput_zero_point); - vacc4x0123456789AB4567CDEF = _mm256_adds_epi16(vacc4x0123456789AB4567CDEF, voutput_zero_point); - vacc5x0123456789AB4567CDEF = _mm256_adds_epi16(vacc5x0123456789AB4567CDEF, voutput_zero_point); - vacc6x0123456789AB4567CDEF = _mm256_adds_epi16(vacc6x0123456789AB4567CDEF, voutput_zero_point); - vacc7x0123456789AB4567CDEF = _mm256_adds_epi16(vacc7x0123456789AB4567CDEF, voutput_zero_point); - - const __m128i vout0x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc0x0123456789AB4567CDEF, 1)); - const __m128i vout1x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc1x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc1x0123456789AB4567CDEF, 1)); - const __m128i vout2x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc2x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc2x0123456789AB4567CDEF, 1)); - const __m128i vout3x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc3x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc3x0123456789AB4567CDEF, 1)); - const __m128i vout4x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc4x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc4x0123456789AB4567CDEF, 1)); - const __m128i vout5x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc5x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc5x0123456789AB4567CDEF, 1)); - const __m128i vout6x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc6x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc6x0123456789AB4567CDEF, 1)); - const __m128i vout7x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc7x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc7x0123456789AB4567CDEF, 1)); - - __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi8(vout0x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi8(vout1x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi8(vout2x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout3x0123456789ABCDEF = _mm_shuffle_epi8(vout3x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout4x0123456789ABCDEF = _mm_shuffle_epi8(vout4x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout5x0123456789ABCDEF = _mm_shuffle_epi8(vout5x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout6x0123456789ABCDEF = _mm_shuffle_epi8(vout6x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout7x0123456789ABCDEF = _mm_shuffle_epi8(vout7x0123456789AB4567CDEF, vshuffle_control_mask); + __m256i vacc0x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); + __m256i vacc1x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); + __m256i vacc2x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc2x0123456789ABCDEF, 1)); + __m256i vacc3x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc3x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc3x0123456789ABCDEF, 1)); + __m256i vacc4x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc4x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc4x0123456789ABCDEF, 1)); + __m256i vacc5x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc5x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc5x0123456789ABCDEF, 1)); + __m256i vacc6x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc6x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc6x0123456789ABCDEF, 1)); + __m256i vacc7x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc7x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc7x0123456789ABCDEF, 1)); + + vacc0x012389AB4567CDEF = _mm256_adds_epi16(vacc0x012389AB4567CDEF, voutput_zero_point); + vacc1x012389AB4567CDEF = _mm256_adds_epi16(vacc1x012389AB4567CDEF, voutput_zero_point); + vacc2x012389AB4567CDEF = _mm256_adds_epi16(vacc2x012389AB4567CDEF, voutput_zero_point); + vacc3x012389AB4567CDEF = _mm256_adds_epi16(vacc3x012389AB4567CDEF, voutput_zero_point); + vacc4x012389AB4567CDEF = _mm256_adds_epi16(vacc4x012389AB4567CDEF, voutput_zero_point); + vacc5x012389AB4567CDEF = _mm256_adds_epi16(vacc5x012389AB4567CDEF, voutput_zero_point); + vacc6x012389AB4567CDEF = _mm256_adds_epi16(vacc6x012389AB4567CDEF, voutput_zero_point); + vacc7x012389AB4567CDEF = _mm256_adds_epi16(vacc7x012389AB4567CDEF, voutput_zero_point); + + const __m128i vout0x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x012389AB4567CDEF, 1)); + const __m128i vout1x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc1x012389AB4567CDEF), _mm256_extracti128_si256(vacc1x012389AB4567CDEF, 1)); + const __m128i vout2x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc2x012389AB4567CDEF), _mm256_extracti128_si256(vacc2x012389AB4567CDEF, 1)); + const __m128i vout3x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc3x012389AB4567CDEF), _mm256_extracti128_si256(vacc3x012389AB4567CDEF, 1)); + const __m128i vout4x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc4x012389AB4567CDEF), _mm256_extracti128_si256(vacc4x012389AB4567CDEF, 1)); + const __m128i vout5x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc5x012389AB4567CDEF), _mm256_extracti128_si256(vacc5x012389AB4567CDEF, 1)); + const __m128i vout6x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc6x012389AB4567CDEF), _mm256_extracti128_si256(vacc6x012389AB4567CDEF, 1)); + const __m128i vout7x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc7x012389AB4567CDEF), _mm256_extracti128_si256(vacc7x012389AB4567CDEF, 1)); + + __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi32(vout0x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi32(vout1x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi32(vout2x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout3x0123456789ABCDEF = _mm_shuffle_epi32(vout3x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout4x0123456789ABCDEF = _mm_shuffle_epi32(vout4x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout5x0123456789ABCDEF = _mm_shuffle_epi32(vout5x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout6x0123456789ABCDEF = _mm_shuffle_epi32(vout6x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout7x0123456789ABCDEF = _mm_shuffle_epi32(vout7x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); vout0x0123456789ABCDEF = _mm_max_epi8(vout0x0123456789ABCDEF, voutput_min); vout1x0123456789ABCDEF = _mm_max_epi8(vout1x0123456789ABCDEF, voutput_min); @@ -312,31 +311,29 @@ void xnn_qs8_gemm_minmax_fp32_ukernel_8x16c8__avx512skx_prfm( if (nc >= 16) { _mm_storeu_si128((__m128i*) c0, vout0x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c1, vout1x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c2, vout2x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c3, vout3x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c4, vout4x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c5, vout5x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c6, vout6x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c7, vout7x0123456789ABCDEF); - - a0 = (const int8_t*) ((uintptr_t) a0 - kc); - a1 = (const int8_t*) ((uintptr_t) a1 - kc); - a2 = (const int8_t*) ((uintptr_t) a2 - kc); - a3 = (const int8_t*) ((uintptr_t) a3 - kc); - a4 = (const int8_t*) ((uintptr_t) a4 - kc); - a5 = (const int8_t*) ((uintptr_t) a5 - kc); - a6 = (const int8_t*) ((uintptr_t) a6 - kc); - a7 = (const int8_t*) ((uintptr_t) a7 - kc); - c0 = (int8_t*) ((uintptr_t) c0 + cn_stride); + a0 = (const int8_t*) ((uintptr_t) a0 - kc); + _mm_storeu_si128((__m128i*) c1, vout1x0123456789ABCDEF); c1 = (int8_t*) ((uintptr_t) c1 + cn_stride); + a1 = (const int8_t*) ((uintptr_t) a1 - kc); + _mm_storeu_si128((__m128i*) c2, vout2x0123456789ABCDEF); c2 = (int8_t*) ((uintptr_t) c2 + cn_stride); + a2 = (const int8_t*) ((uintptr_t) a2 - kc); + _mm_storeu_si128((__m128i*) c3, vout3x0123456789ABCDEF); c3 = (int8_t*) ((uintptr_t) c3 + cn_stride); + a3 = (const int8_t*) ((uintptr_t) a3 - kc); + _mm_storeu_si128((__m128i*) c4, vout4x0123456789ABCDEF); c4 = (int8_t*) ((uintptr_t) c4 + cn_stride); + a4 = (const int8_t*) ((uintptr_t) a4 - kc); + _mm_storeu_si128((__m128i*) c5, vout5x0123456789ABCDEF); c5 = (int8_t*) ((uintptr_t) c5 + cn_stride); + a5 = (const int8_t*) ((uintptr_t) a5 - kc); + _mm_storeu_si128((__m128i*) c6, vout6x0123456789ABCDEF); c6 = (int8_t*) ((uintptr_t) c6 + cn_stride); + a6 = (const int8_t*) ((uintptr_t) a6 - kc); + _mm_storeu_si128((__m128i*) c7, vout7x0123456789ABCDEF); c7 = (int8_t*) ((uintptr_t) c7 + cn_stride); + a7 = (const int8_t*) ((uintptr_t) a7 - kc); nc -= 16; } else { diff --git a/src/qs8-gemm/gen/qs8-gemm-8x16c8-minmax-fp32-avx512skx.c b/src/qs8-gemm/gen/qs8-gemm-8x16c8-minmax-fp32-avx512skx.c index bb0fc7bb6aa..1863fc6e28c 100644 --- a/src/qs8-gemm/gen/qs8-gemm-8x16c8-minmax-fp32-avx512skx.c +++ b/src/qs8-gemm/gen/qs8-gemm-8x16c8-minmax-fp32-avx512skx.c @@ -87,7 +87,6 @@ void xnn_qs8_gemm_minmax_fp32_ukernel_8x16c8__avx512skx( const __m512 voutput_max_less_zero_point = _mm512_load_ps(params->fp32_avx512.output_max_less_zero_point); const __m256i voutput_zero_point = _mm256_load_si256((const __m256i*) params->fp32_avx512.output_zero_point); const __m128i voutput_min = _mm_load_si128((const __m128i*) params->fp32_avx512.output_min); - const __m128i vshuffle_control_mask = _mm_set_epi8(15, 14, 13, 12, 7, 6, 5, 4, 11, 10, 9, 8, 3, 2, 1, 0); do { __m512i vacc0x0123 = _mm512_maskz_expandloadu_epi32(vbias_mask, w); @@ -262,41 +261,41 @@ void xnn_qs8_gemm_minmax_fp32_ukernel_8x16c8__avx512skx( vacc6x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled6x0123456789ABCDEF); vacc7x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled7x0123456789ABCDEF); - __m256i vacc0x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); - __m256i vacc1x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); - __m256i vacc2x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc2x0123456789ABCDEF, 1)); - __m256i vacc3x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc3x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc3x0123456789ABCDEF, 1)); - __m256i vacc4x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc4x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc4x0123456789ABCDEF, 1)); - __m256i vacc5x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc5x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc5x0123456789ABCDEF, 1)); - __m256i vacc6x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc6x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc6x0123456789ABCDEF, 1)); - __m256i vacc7x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc7x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc7x0123456789ABCDEF, 1)); - - vacc0x0123456789AB4567CDEF = _mm256_adds_epi16(vacc0x0123456789AB4567CDEF, voutput_zero_point); - vacc1x0123456789AB4567CDEF = _mm256_adds_epi16(vacc1x0123456789AB4567CDEF, voutput_zero_point); - vacc2x0123456789AB4567CDEF = _mm256_adds_epi16(vacc2x0123456789AB4567CDEF, voutput_zero_point); - vacc3x0123456789AB4567CDEF = _mm256_adds_epi16(vacc3x0123456789AB4567CDEF, voutput_zero_point); - vacc4x0123456789AB4567CDEF = _mm256_adds_epi16(vacc4x0123456789AB4567CDEF, voutput_zero_point); - vacc5x0123456789AB4567CDEF = _mm256_adds_epi16(vacc5x0123456789AB4567CDEF, voutput_zero_point); - vacc6x0123456789AB4567CDEF = _mm256_adds_epi16(vacc6x0123456789AB4567CDEF, voutput_zero_point); - vacc7x0123456789AB4567CDEF = _mm256_adds_epi16(vacc7x0123456789AB4567CDEF, voutput_zero_point); - - const __m128i vout0x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc0x0123456789AB4567CDEF, 1)); - const __m128i vout1x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc1x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc1x0123456789AB4567CDEF, 1)); - const __m128i vout2x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc2x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc2x0123456789AB4567CDEF, 1)); - const __m128i vout3x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc3x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc3x0123456789AB4567CDEF, 1)); - const __m128i vout4x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc4x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc4x0123456789AB4567CDEF, 1)); - const __m128i vout5x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc5x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc5x0123456789AB4567CDEF, 1)); - const __m128i vout6x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc6x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc6x0123456789AB4567CDEF, 1)); - const __m128i vout7x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc7x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc7x0123456789AB4567CDEF, 1)); - - __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi8(vout0x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi8(vout1x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi8(vout2x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout3x0123456789ABCDEF = _mm_shuffle_epi8(vout3x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout4x0123456789ABCDEF = _mm_shuffle_epi8(vout4x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout5x0123456789ABCDEF = _mm_shuffle_epi8(vout5x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout6x0123456789ABCDEF = _mm_shuffle_epi8(vout6x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout7x0123456789ABCDEF = _mm_shuffle_epi8(vout7x0123456789AB4567CDEF, vshuffle_control_mask); + __m256i vacc0x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); + __m256i vacc1x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); + __m256i vacc2x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc2x0123456789ABCDEF, 1)); + __m256i vacc3x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc3x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc3x0123456789ABCDEF, 1)); + __m256i vacc4x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc4x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc4x0123456789ABCDEF, 1)); + __m256i vacc5x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc5x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc5x0123456789ABCDEF, 1)); + __m256i vacc6x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc6x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc6x0123456789ABCDEF, 1)); + __m256i vacc7x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc7x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc7x0123456789ABCDEF, 1)); + + vacc0x012389AB4567CDEF = _mm256_adds_epi16(vacc0x012389AB4567CDEF, voutput_zero_point); + vacc1x012389AB4567CDEF = _mm256_adds_epi16(vacc1x012389AB4567CDEF, voutput_zero_point); + vacc2x012389AB4567CDEF = _mm256_adds_epi16(vacc2x012389AB4567CDEF, voutput_zero_point); + vacc3x012389AB4567CDEF = _mm256_adds_epi16(vacc3x012389AB4567CDEF, voutput_zero_point); + vacc4x012389AB4567CDEF = _mm256_adds_epi16(vacc4x012389AB4567CDEF, voutput_zero_point); + vacc5x012389AB4567CDEF = _mm256_adds_epi16(vacc5x012389AB4567CDEF, voutput_zero_point); + vacc6x012389AB4567CDEF = _mm256_adds_epi16(vacc6x012389AB4567CDEF, voutput_zero_point); + vacc7x012389AB4567CDEF = _mm256_adds_epi16(vacc7x012389AB4567CDEF, voutput_zero_point); + + const __m128i vout0x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x012389AB4567CDEF, 1)); + const __m128i vout1x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc1x012389AB4567CDEF), _mm256_extracti128_si256(vacc1x012389AB4567CDEF, 1)); + const __m128i vout2x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc2x012389AB4567CDEF), _mm256_extracti128_si256(vacc2x012389AB4567CDEF, 1)); + const __m128i vout3x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc3x012389AB4567CDEF), _mm256_extracti128_si256(vacc3x012389AB4567CDEF, 1)); + const __m128i vout4x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc4x012389AB4567CDEF), _mm256_extracti128_si256(vacc4x012389AB4567CDEF, 1)); + const __m128i vout5x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc5x012389AB4567CDEF), _mm256_extracti128_si256(vacc5x012389AB4567CDEF, 1)); + const __m128i vout6x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc6x012389AB4567CDEF), _mm256_extracti128_si256(vacc6x012389AB4567CDEF, 1)); + const __m128i vout7x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc7x012389AB4567CDEF), _mm256_extracti128_si256(vacc7x012389AB4567CDEF, 1)); + + __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi32(vout0x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi32(vout1x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi32(vout2x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout3x0123456789ABCDEF = _mm_shuffle_epi32(vout3x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout4x0123456789ABCDEF = _mm_shuffle_epi32(vout4x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout5x0123456789ABCDEF = _mm_shuffle_epi32(vout5x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout6x0123456789ABCDEF = _mm_shuffle_epi32(vout6x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout7x0123456789ABCDEF = _mm_shuffle_epi32(vout7x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); vout0x0123456789ABCDEF = _mm_max_epi8(vout0x0123456789ABCDEF, voutput_min); vout1x0123456789ABCDEF = _mm_max_epi8(vout1x0123456789ABCDEF, voutput_min); @@ -309,31 +308,29 @@ void xnn_qs8_gemm_minmax_fp32_ukernel_8x16c8__avx512skx( if (nc >= 16) { _mm_storeu_si128((__m128i*) c0, vout0x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c1, vout1x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c2, vout2x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c3, vout3x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c4, vout4x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c5, vout5x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c6, vout6x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c7, vout7x0123456789ABCDEF); - - a0 = (const int8_t*) ((uintptr_t) a0 - kc); - a1 = (const int8_t*) ((uintptr_t) a1 - kc); - a2 = (const int8_t*) ((uintptr_t) a2 - kc); - a3 = (const int8_t*) ((uintptr_t) a3 - kc); - a4 = (const int8_t*) ((uintptr_t) a4 - kc); - a5 = (const int8_t*) ((uintptr_t) a5 - kc); - a6 = (const int8_t*) ((uintptr_t) a6 - kc); - a7 = (const int8_t*) ((uintptr_t) a7 - kc); - c0 = (int8_t*) ((uintptr_t) c0 + cn_stride); + a0 = (const int8_t*) ((uintptr_t) a0 - kc); + _mm_storeu_si128((__m128i*) c1, vout1x0123456789ABCDEF); c1 = (int8_t*) ((uintptr_t) c1 + cn_stride); + a1 = (const int8_t*) ((uintptr_t) a1 - kc); + _mm_storeu_si128((__m128i*) c2, vout2x0123456789ABCDEF); c2 = (int8_t*) ((uintptr_t) c2 + cn_stride); + a2 = (const int8_t*) ((uintptr_t) a2 - kc); + _mm_storeu_si128((__m128i*) c3, vout3x0123456789ABCDEF); c3 = (int8_t*) ((uintptr_t) c3 + cn_stride); + a3 = (const int8_t*) ((uintptr_t) a3 - kc); + _mm_storeu_si128((__m128i*) c4, vout4x0123456789ABCDEF); c4 = (int8_t*) ((uintptr_t) c4 + cn_stride); + a4 = (const int8_t*) ((uintptr_t) a4 - kc); + _mm_storeu_si128((__m128i*) c5, vout5x0123456789ABCDEF); c5 = (int8_t*) ((uintptr_t) c5 + cn_stride); + a5 = (const int8_t*) ((uintptr_t) a5 - kc); + _mm_storeu_si128((__m128i*) c6, vout6x0123456789ABCDEF); c6 = (int8_t*) ((uintptr_t) c6 + cn_stride); + a6 = (const int8_t*) ((uintptr_t) a6 - kc); + _mm_storeu_si128((__m128i*) c7, vout7x0123456789ABCDEF); c7 = (int8_t*) ((uintptr_t) c7 + cn_stride); + a7 = (const int8_t*) ((uintptr_t) a7 - kc); nc -= 16; } else { diff --git a/src/qs8-gemm/gen/qs8-gemm-8x16c8-minmax-fp32-avx512vnni-prfm.c b/src/qs8-gemm/gen/qs8-gemm-8x16c8-minmax-fp32-avx512vnni-prfm.c index 6befd2cbe72..a11d4260792 100644 --- a/src/qs8-gemm/gen/qs8-gemm-8x16c8-minmax-fp32-avx512vnni-prfm.c +++ b/src/qs8-gemm/gen/qs8-gemm-8x16c8-minmax-fp32-avx512vnni-prfm.c @@ -90,7 +90,6 @@ void xnn_qs8_gemm_minmax_fp32_ukernel_8x16c8__avx512vnni_prfm( const __m512 voutput_max_less_zero_point = _mm512_load_ps(params->fp32_avx512vnni.output_max_less_zero_point); const __m256i voutput_zero_point = _mm256_load_si256((const __m256i*) params->fp32_avx512vnni.output_zero_point); const __m128i voutput_min = _mm_load_si128((const __m128i*) params->fp32_avx512vnni.output_min); - const __m128i vshuffle_control_mask = _mm_loadu_si128((const __m128i*) params->fp32_avx512vnni.shuffle_control_mask); do { __m512i vacc0x01234567 = _mm512_cvtepu32_epi64(_mm256_load_si256((const __m256i*) w)); __m512i vacc0x89ABCDEF = _mm512_cvtepu32_epi64(_mm256_load_si256((const __m256i*) ((const int32_t*) w + 8))); @@ -290,41 +289,41 @@ void xnn_qs8_gemm_minmax_fp32_ukernel_8x16c8__avx512vnni_prfm( vacc6x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled6x0123456789ABCDEF); vacc7x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled7x0123456789ABCDEF); - __m256i vacc0x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); - __m256i vacc1x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); - __m256i vacc2x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc2x0123456789ABCDEF, 1)); - __m256i vacc3x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc3x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc3x0123456789ABCDEF, 1)); - __m256i vacc4x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc4x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc4x0123456789ABCDEF, 1)); - __m256i vacc5x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc5x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc5x0123456789ABCDEF, 1)); - __m256i vacc6x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc6x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc6x0123456789ABCDEF, 1)); - __m256i vacc7x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc7x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc7x0123456789ABCDEF, 1)); + __m256i vacc0x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); + __m256i vacc1x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); + __m256i vacc2x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc2x0123456789ABCDEF, 1)); + __m256i vacc3x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc3x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc3x0123456789ABCDEF, 1)); + __m256i vacc4x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc4x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc4x0123456789ABCDEF, 1)); + __m256i vacc5x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc5x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc5x0123456789ABCDEF, 1)); + __m256i vacc6x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc6x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc6x0123456789ABCDEF, 1)); + __m256i vacc7x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc7x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc7x0123456789ABCDEF, 1)); - vacc0x0123456789AB4567CDEF = _mm256_adds_epi16(vacc0x0123456789AB4567CDEF, voutput_zero_point); - vacc1x0123456789AB4567CDEF = _mm256_adds_epi16(vacc1x0123456789AB4567CDEF, voutput_zero_point); - vacc2x0123456789AB4567CDEF = _mm256_adds_epi16(vacc2x0123456789AB4567CDEF, voutput_zero_point); - vacc3x0123456789AB4567CDEF = _mm256_adds_epi16(vacc3x0123456789AB4567CDEF, voutput_zero_point); - vacc4x0123456789AB4567CDEF = _mm256_adds_epi16(vacc4x0123456789AB4567CDEF, voutput_zero_point); - vacc5x0123456789AB4567CDEF = _mm256_adds_epi16(vacc5x0123456789AB4567CDEF, voutput_zero_point); - vacc6x0123456789AB4567CDEF = _mm256_adds_epi16(vacc6x0123456789AB4567CDEF, voutput_zero_point); - vacc7x0123456789AB4567CDEF = _mm256_adds_epi16(vacc7x0123456789AB4567CDEF, voutput_zero_point); + vacc0x012389AB4567CDEF = _mm256_adds_epi16(vacc0x012389AB4567CDEF, voutput_zero_point); + vacc1x012389AB4567CDEF = _mm256_adds_epi16(vacc1x012389AB4567CDEF, voutput_zero_point); + vacc2x012389AB4567CDEF = _mm256_adds_epi16(vacc2x012389AB4567CDEF, voutput_zero_point); + vacc3x012389AB4567CDEF = _mm256_adds_epi16(vacc3x012389AB4567CDEF, voutput_zero_point); + vacc4x012389AB4567CDEF = _mm256_adds_epi16(vacc4x012389AB4567CDEF, voutput_zero_point); + vacc5x012389AB4567CDEF = _mm256_adds_epi16(vacc5x012389AB4567CDEF, voutput_zero_point); + vacc6x012389AB4567CDEF = _mm256_adds_epi16(vacc6x012389AB4567CDEF, voutput_zero_point); + vacc7x012389AB4567CDEF = _mm256_adds_epi16(vacc7x012389AB4567CDEF, voutput_zero_point); - const __m128i vout0x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc0x0123456789AB4567CDEF, 1)); - const __m128i vout1x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc1x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc1x0123456789AB4567CDEF, 1)); - const __m128i vout2x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc2x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc2x0123456789AB4567CDEF, 1)); - const __m128i vout3x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc3x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc3x0123456789AB4567CDEF, 1)); - const __m128i vout4x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc4x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc4x0123456789AB4567CDEF, 1)); - const __m128i vout5x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc5x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc5x0123456789AB4567CDEF, 1)); - const __m128i vout6x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc6x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc6x0123456789AB4567CDEF, 1)); - const __m128i vout7x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc7x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc7x0123456789AB4567CDEF, 1)); + const __m128i vout0x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x012389AB4567CDEF, 1)); + const __m128i vout1x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc1x012389AB4567CDEF), _mm256_extracti128_si256(vacc1x012389AB4567CDEF, 1)); + const __m128i vout2x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc2x012389AB4567CDEF), _mm256_extracti128_si256(vacc2x012389AB4567CDEF, 1)); + const __m128i vout3x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc3x012389AB4567CDEF), _mm256_extracti128_si256(vacc3x012389AB4567CDEF, 1)); + const __m128i vout4x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc4x012389AB4567CDEF), _mm256_extracti128_si256(vacc4x012389AB4567CDEF, 1)); + const __m128i vout5x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc5x012389AB4567CDEF), _mm256_extracti128_si256(vacc5x012389AB4567CDEF, 1)); + const __m128i vout6x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc6x012389AB4567CDEF), _mm256_extracti128_si256(vacc6x012389AB4567CDEF, 1)); + const __m128i vout7x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc7x012389AB4567CDEF), _mm256_extracti128_si256(vacc7x012389AB4567CDEF, 1)); - __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi8(vout0x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi8(vout1x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi8(vout2x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout3x0123456789ABCDEF = _mm_shuffle_epi8(vout3x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout4x0123456789ABCDEF = _mm_shuffle_epi8(vout4x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout5x0123456789ABCDEF = _mm_shuffle_epi8(vout5x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout6x0123456789ABCDEF = _mm_shuffle_epi8(vout6x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout7x0123456789ABCDEF = _mm_shuffle_epi8(vout7x0123456789AB4567CDEF, vshuffle_control_mask); + __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi32(vout0x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi32(vout1x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi32(vout2x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout3x0123456789ABCDEF = _mm_shuffle_epi32(vout3x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout4x0123456789ABCDEF = _mm_shuffle_epi32(vout4x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout5x0123456789ABCDEF = _mm_shuffle_epi32(vout5x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout6x0123456789ABCDEF = _mm_shuffle_epi32(vout6x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout7x0123456789ABCDEF = _mm_shuffle_epi32(vout7x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); vout0x0123456789ABCDEF = _mm_max_epi8(vout0x0123456789ABCDEF, voutput_min); vout1x0123456789ABCDEF = _mm_max_epi8(vout1x0123456789ABCDEF, voutput_min); @@ -337,31 +336,29 @@ void xnn_qs8_gemm_minmax_fp32_ukernel_8x16c8__avx512vnni_prfm( if (nc >= 16) { _mm_storeu_si128((__m128i*) c0, vout0x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c1, vout1x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c2, vout2x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c3, vout3x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c4, vout4x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c5, vout5x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c6, vout6x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c7, vout7x0123456789ABCDEF); - - a0 = (const int8_t*) ((uintptr_t) a0 - kc); - a1 = (const int8_t*) ((uintptr_t) a1 - kc); - a2 = (const int8_t*) ((uintptr_t) a2 - kc); - a3 = (const int8_t*) ((uintptr_t) a3 - kc); - a4 = (const int8_t*) ((uintptr_t) a4 - kc); - a5 = (const int8_t*) ((uintptr_t) a5 - kc); - a6 = (const int8_t*) ((uintptr_t) a6 - kc); - a7 = (const int8_t*) ((uintptr_t) a7 - kc); - c0 = (int8_t*) ((uintptr_t) c0 + cn_stride); + a0 = (const int8_t*) ((uintptr_t) a0 - kc); + _mm_storeu_si128((__m128i*) c1, vout1x0123456789ABCDEF); c1 = (int8_t*) ((uintptr_t) c1 + cn_stride); + a1 = (const int8_t*) ((uintptr_t) a1 - kc); + _mm_storeu_si128((__m128i*) c2, vout2x0123456789ABCDEF); c2 = (int8_t*) ((uintptr_t) c2 + cn_stride); + a2 = (const int8_t*) ((uintptr_t) a2 - kc); + _mm_storeu_si128((__m128i*) c3, vout3x0123456789ABCDEF); c3 = (int8_t*) ((uintptr_t) c3 + cn_stride); + a3 = (const int8_t*) ((uintptr_t) a3 - kc); + _mm_storeu_si128((__m128i*) c4, vout4x0123456789ABCDEF); c4 = (int8_t*) ((uintptr_t) c4 + cn_stride); + a4 = (const int8_t*) ((uintptr_t) a4 - kc); + _mm_storeu_si128((__m128i*) c5, vout5x0123456789ABCDEF); c5 = (int8_t*) ((uintptr_t) c5 + cn_stride); + a5 = (const int8_t*) ((uintptr_t) a5 - kc); + _mm_storeu_si128((__m128i*) c6, vout6x0123456789ABCDEF); c6 = (int8_t*) ((uintptr_t) c6 + cn_stride); + a6 = (const int8_t*) ((uintptr_t) a6 - kc); + _mm_storeu_si128((__m128i*) c7, vout7x0123456789ABCDEF); c7 = (int8_t*) ((uintptr_t) c7 + cn_stride); + a7 = (const int8_t*) ((uintptr_t) a7 - kc); nc -= 16; } else { diff --git a/src/qs8-gemm/gen/qs8-gemm-8x16c8-minmax-fp32-avx512vnni.c b/src/qs8-gemm/gen/qs8-gemm-8x16c8-minmax-fp32-avx512vnni.c index 2818616cab6..e1fc15a039f 100644 --- a/src/qs8-gemm/gen/qs8-gemm-8x16c8-minmax-fp32-avx512vnni.c +++ b/src/qs8-gemm/gen/qs8-gemm-8x16c8-minmax-fp32-avx512vnni.c @@ -89,7 +89,6 @@ void xnn_qs8_gemm_minmax_fp32_ukernel_8x16c8__avx512vnni( const __m512 voutput_max_less_zero_point = _mm512_load_ps(params->fp32_avx512vnni.output_max_less_zero_point); const __m256i voutput_zero_point = _mm256_load_si256((const __m256i*) params->fp32_avx512vnni.output_zero_point); const __m128i voutput_min = _mm_load_si128((const __m128i*) params->fp32_avx512vnni.output_min); - const __m128i vshuffle_control_mask = _mm_loadu_si128((const __m128i*) params->fp32_avx512vnni.shuffle_control_mask); do { __m512i vacc0x01234567 = _mm512_cvtepu32_epi64(_mm256_load_si256((const __m256i*) w)); __m512i vacc0x89ABCDEF = _mm512_cvtepu32_epi64(_mm256_load_si256((const __m256i*) ((const int32_t*) w + 8))); @@ -283,41 +282,41 @@ void xnn_qs8_gemm_minmax_fp32_ukernel_8x16c8__avx512vnni( vacc6x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled6x0123456789ABCDEF); vacc7x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled7x0123456789ABCDEF); - __m256i vacc0x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); - __m256i vacc1x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); - __m256i vacc2x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc2x0123456789ABCDEF, 1)); - __m256i vacc3x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc3x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc3x0123456789ABCDEF, 1)); - __m256i vacc4x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc4x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc4x0123456789ABCDEF, 1)); - __m256i vacc5x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc5x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc5x0123456789ABCDEF, 1)); - __m256i vacc6x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc6x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc6x0123456789ABCDEF, 1)); - __m256i vacc7x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc7x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc7x0123456789ABCDEF, 1)); + __m256i vacc0x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); + __m256i vacc1x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); + __m256i vacc2x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc2x0123456789ABCDEF, 1)); + __m256i vacc3x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc3x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc3x0123456789ABCDEF, 1)); + __m256i vacc4x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc4x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc4x0123456789ABCDEF, 1)); + __m256i vacc5x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc5x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc5x0123456789ABCDEF, 1)); + __m256i vacc6x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc6x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc6x0123456789ABCDEF, 1)); + __m256i vacc7x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc7x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc7x0123456789ABCDEF, 1)); - vacc0x0123456789AB4567CDEF = _mm256_adds_epi16(vacc0x0123456789AB4567CDEF, voutput_zero_point); - vacc1x0123456789AB4567CDEF = _mm256_adds_epi16(vacc1x0123456789AB4567CDEF, voutput_zero_point); - vacc2x0123456789AB4567CDEF = _mm256_adds_epi16(vacc2x0123456789AB4567CDEF, voutput_zero_point); - vacc3x0123456789AB4567CDEF = _mm256_adds_epi16(vacc3x0123456789AB4567CDEF, voutput_zero_point); - vacc4x0123456789AB4567CDEF = _mm256_adds_epi16(vacc4x0123456789AB4567CDEF, voutput_zero_point); - vacc5x0123456789AB4567CDEF = _mm256_adds_epi16(vacc5x0123456789AB4567CDEF, voutput_zero_point); - vacc6x0123456789AB4567CDEF = _mm256_adds_epi16(vacc6x0123456789AB4567CDEF, voutput_zero_point); - vacc7x0123456789AB4567CDEF = _mm256_adds_epi16(vacc7x0123456789AB4567CDEF, voutput_zero_point); + vacc0x012389AB4567CDEF = _mm256_adds_epi16(vacc0x012389AB4567CDEF, voutput_zero_point); + vacc1x012389AB4567CDEF = _mm256_adds_epi16(vacc1x012389AB4567CDEF, voutput_zero_point); + vacc2x012389AB4567CDEF = _mm256_adds_epi16(vacc2x012389AB4567CDEF, voutput_zero_point); + vacc3x012389AB4567CDEF = _mm256_adds_epi16(vacc3x012389AB4567CDEF, voutput_zero_point); + vacc4x012389AB4567CDEF = _mm256_adds_epi16(vacc4x012389AB4567CDEF, voutput_zero_point); + vacc5x012389AB4567CDEF = _mm256_adds_epi16(vacc5x012389AB4567CDEF, voutput_zero_point); + vacc6x012389AB4567CDEF = _mm256_adds_epi16(vacc6x012389AB4567CDEF, voutput_zero_point); + vacc7x012389AB4567CDEF = _mm256_adds_epi16(vacc7x012389AB4567CDEF, voutput_zero_point); - const __m128i vout0x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc0x0123456789AB4567CDEF, 1)); - const __m128i vout1x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc1x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc1x0123456789AB4567CDEF, 1)); - const __m128i vout2x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc2x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc2x0123456789AB4567CDEF, 1)); - const __m128i vout3x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc3x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc3x0123456789AB4567CDEF, 1)); - const __m128i vout4x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc4x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc4x0123456789AB4567CDEF, 1)); - const __m128i vout5x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc5x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc5x0123456789AB4567CDEF, 1)); - const __m128i vout6x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc6x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc6x0123456789AB4567CDEF, 1)); - const __m128i vout7x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc7x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc7x0123456789AB4567CDEF, 1)); + const __m128i vout0x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x012389AB4567CDEF, 1)); + const __m128i vout1x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc1x012389AB4567CDEF), _mm256_extracti128_si256(vacc1x012389AB4567CDEF, 1)); + const __m128i vout2x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc2x012389AB4567CDEF), _mm256_extracti128_si256(vacc2x012389AB4567CDEF, 1)); + const __m128i vout3x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc3x012389AB4567CDEF), _mm256_extracti128_si256(vacc3x012389AB4567CDEF, 1)); + const __m128i vout4x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc4x012389AB4567CDEF), _mm256_extracti128_si256(vacc4x012389AB4567CDEF, 1)); + const __m128i vout5x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc5x012389AB4567CDEF), _mm256_extracti128_si256(vacc5x012389AB4567CDEF, 1)); + const __m128i vout6x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc6x012389AB4567CDEF), _mm256_extracti128_si256(vacc6x012389AB4567CDEF, 1)); + const __m128i vout7x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc7x012389AB4567CDEF), _mm256_extracti128_si256(vacc7x012389AB4567CDEF, 1)); - __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi8(vout0x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi8(vout1x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi8(vout2x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout3x0123456789ABCDEF = _mm_shuffle_epi8(vout3x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout4x0123456789ABCDEF = _mm_shuffle_epi8(vout4x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout5x0123456789ABCDEF = _mm_shuffle_epi8(vout5x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout6x0123456789ABCDEF = _mm_shuffle_epi8(vout6x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout7x0123456789ABCDEF = _mm_shuffle_epi8(vout7x0123456789AB4567CDEF, vshuffle_control_mask); + __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi32(vout0x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi32(vout1x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi32(vout2x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout3x0123456789ABCDEF = _mm_shuffle_epi32(vout3x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout4x0123456789ABCDEF = _mm_shuffle_epi32(vout4x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout5x0123456789ABCDEF = _mm_shuffle_epi32(vout5x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout6x0123456789ABCDEF = _mm_shuffle_epi32(vout6x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout7x0123456789ABCDEF = _mm_shuffle_epi32(vout7x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); vout0x0123456789ABCDEF = _mm_max_epi8(vout0x0123456789ABCDEF, voutput_min); vout1x0123456789ABCDEF = _mm_max_epi8(vout1x0123456789ABCDEF, voutput_min); @@ -330,31 +329,29 @@ void xnn_qs8_gemm_minmax_fp32_ukernel_8x16c8__avx512vnni( if (nc >= 16) { _mm_storeu_si128((__m128i*) c0, vout0x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c1, vout1x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c2, vout2x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c3, vout3x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c4, vout4x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c5, vout5x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c6, vout6x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c7, vout7x0123456789ABCDEF); - - a0 = (const int8_t*) ((uintptr_t) a0 - kc); - a1 = (const int8_t*) ((uintptr_t) a1 - kc); - a2 = (const int8_t*) ((uintptr_t) a2 - kc); - a3 = (const int8_t*) ((uintptr_t) a3 - kc); - a4 = (const int8_t*) ((uintptr_t) a4 - kc); - a5 = (const int8_t*) ((uintptr_t) a5 - kc); - a6 = (const int8_t*) ((uintptr_t) a6 - kc); - a7 = (const int8_t*) ((uintptr_t) a7 - kc); - c0 = (int8_t*) ((uintptr_t) c0 + cn_stride); + a0 = (const int8_t*) ((uintptr_t) a0 - kc); + _mm_storeu_si128((__m128i*) c1, vout1x0123456789ABCDEF); c1 = (int8_t*) ((uintptr_t) c1 + cn_stride); + a1 = (const int8_t*) ((uintptr_t) a1 - kc); + _mm_storeu_si128((__m128i*) c2, vout2x0123456789ABCDEF); c2 = (int8_t*) ((uintptr_t) c2 + cn_stride); + a2 = (const int8_t*) ((uintptr_t) a2 - kc); + _mm_storeu_si128((__m128i*) c3, vout3x0123456789ABCDEF); c3 = (int8_t*) ((uintptr_t) c3 + cn_stride); + a3 = (const int8_t*) ((uintptr_t) a3 - kc); + _mm_storeu_si128((__m128i*) c4, vout4x0123456789ABCDEF); c4 = (int8_t*) ((uintptr_t) c4 + cn_stride); + a4 = (const int8_t*) ((uintptr_t) a4 - kc); + _mm_storeu_si128((__m128i*) c5, vout5x0123456789ABCDEF); c5 = (int8_t*) ((uintptr_t) c5 + cn_stride); + a5 = (const int8_t*) ((uintptr_t) a5 - kc); + _mm_storeu_si128((__m128i*) c6, vout6x0123456789ABCDEF); c6 = (int8_t*) ((uintptr_t) c6 + cn_stride); + a6 = (const int8_t*) ((uintptr_t) a6 - kc); + _mm_storeu_si128((__m128i*) c7, vout7x0123456789ABCDEF); c7 = (int8_t*) ((uintptr_t) c7 + cn_stride); + a7 = (const int8_t*) ((uintptr_t) a7 - kc); nc -= 16; } else { diff --git a/src/qs8-igemm/MRx16c4-avx512vnni.c.in b/src/qs8-igemm/MRx16c4-avx512vnni.c.in index 2df1292cd5c..25928a3627a 100644 --- a/src/qs8-igemm/MRx16c4-avx512vnni.c.in +++ b/src/qs8-igemm/MRx16c4-avx512vnni.c.in @@ -19,14 +19,13 @@ $if PREFETCH: $DATATYPE_SPEC = {"QC8": "qs8_qc8w", "QD8": "qd8_f32_qc8w", "QS8": "qs8", "QU8": "qu8"}[DATATYPE] -$REQUANTIZATION_SPEC = "" if DATATYPE == "QD8" else "_" + REQUANTIZATION.lower() +$REQUANTIZATION_SPEC = "" if DATATYPE in ["QD8", "QC4"] else "_" + REQUANTIZATION.lower() $PARAMS_STRUCT = REQUANTIZATION.lower() + "_avx512vnni" if REQUANTIZATION else "avx512vnni" $PARAMS_UNION = {"QC8": "xnn_qs8_qc8w_conv_minmax_params", "QD8": "xnn_f32_minmax_params", "QS8": "xnn_qs8_conv_minmax_params", "QU8": "xnn_qu8_conv_minmax_params"}[DATATYPE] $XINT8_T = "uint8_t" if DATATYPE == "QU8" else "int8_t" -$OUT_T = "float" if DATATYPE == "QD8" else XINT8_T +$OUT_T = "float" if DATATYPE in ["QD8", "QC4"] else XINT8_T $_MM_PACKXS_EPI16 = "_mm_packus_epi16" if DATATYPE == "QU8" else "_mm_packs_epi16" $_MM_MAX_EPX8 = "_mm_max_epu8" if DATATYPE == "QU8" else "_mm_max_epi8" -$_MM_CVTEPX8_EPI16 = "_mm_cvtepu8_epi16" if DATATYPE == "QU8" else "_mm_cvtepi8_epi16" void xnn_${DATATYPE_SPEC}_igemm_minmax${REQUANTIZATION_SPEC}_ukernel_${MR}x16c4__avx512vnni${"_prfm" if PREFETCH else ""}( size_t mr, size_t nc, @@ -39,7 +38,7 @@ void xnn_${DATATYPE_SPEC}_igemm_minmax${REQUANTIZATION_SPEC}_ukernel_${MR}x16c4_ size_t cn_stride, size_t a_offset, const ${XINT8_T}* zero, - $if DATATYPE == "QD8": + $if DATATYPE in ["QD8", "QC4"]: const int8_t* zero_data, const union ${PARAMS_UNION} params[restrict XNN_MIN_ELEMENTS(1)], const struct xnn_qd8_quantization_params quantization_params[restrict XNN_MIN_ELEMENTS(1)]) XNN_OOB_READS @@ -79,7 +78,7 @@ void xnn_${DATATYPE_SPEC}_igemm_minmax${REQUANTIZATION_SPEC}_ukernel_${MR}x16c4_ const __m512 voutput_max = _mm512_set1_ps(params->avx512vnni.max); const __m512i vsign_mask = _mm512_set1_epi8(params->avx512vnni.sign_mask); // 0x80 $if DATATYPE == "QC4": - const __m256i vmask = _mm256_set1_epi8(params->avx512vnni.mask); + const __m256i vmask = _mm256_set1_epi8(params->avx512vnni.mask); // 0xF0 $else: const __m512i vsign_mask = _mm512_set1_epi8(params->${PARAMS_STRUCT}.sign_mask); // 0x80 $if DATATYPE != "QC8": @@ -89,7 +88,6 @@ void xnn_${DATATYPE_SPEC}_igemm_minmax${REQUANTIZATION_SPEC}_ukernel_${MR}x16c4_ const __m128i voutput_min = _mm_load_si128((const __m128i*) params->${PARAMS_STRUCT}.output_min); $if DATATYPE == "QU8": const __m512i vb_zero_point = _mm512_load_si512(params->${PARAMS_STRUCT}.kernel_zero_point); - const __m128i vshuffle_control_mask = _mm_loadu_si128((const __m128i*) params->${PARAMS_STRUCT}.shuffle_control_mask); do { $if DATATYPE in ["QD8", "QC4"]: const __m512i vksum0123456789ABCDEF = _mm512_load_epi32(w); @@ -113,8 +111,6 @@ void xnn_${DATATYPE_SPEC}_igemm_minmax${REQUANTIZATION_SPEC}_ukernel_${MR}x16c4_ a += ${MR}; size_t k = kc; - $if DATATYPE == "QU8": - const __m512i vb_zero_point = _mm512_load_si512(params->${PARAMS_STRUCT}.kernel_zero_point); while (k >= 8 * sizeof(int8_t)) { $for M in range(MR): const __m512i va${M}x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a${M})), vsign_mask); @@ -206,16 +202,16 @@ void xnn_${DATATYPE_SPEC}_igemm_minmax${REQUANTIZATION_SPEC}_ukernel_${MR}x16c4_ vacc${M}x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled${M}x0123456789ABCDEF); $for M in range(MR): - __m256i vacc${M}x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc${M}x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc${M}x0123456789ABCDEF, 1)); + __m256i vacc${M}x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc${M}x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc${M}x0123456789ABCDEF, 1)); $for M in range(MR): - vacc${M}x0123456789AB4567CDEF = _mm256_adds_epi16(vacc${M}x0123456789AB4567CDEF, voutput_zero_point); + vacc${M}x012389AB4567CDEF = _mm256_adds_epi16(vacc${M}x012389AB4567CDEF, voutput_zero_point); $for M in range(MR): - const __m128i vout${M}x0123456789AB4567CDEF = ${_MM_PACKXS_EPI16}(_mm256_castsi256_si128(vacc${M}x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc${M}x0123456789AB4567CDEF, 1)); + const __m128i vout${M}x012389AB4567CDEF = ${_MM_PACKXS_EPI16}(_mm256_castsi256_si128(vacc${M}x012389AB4567CDEF), _mm256_extracti128_si256(vacc${M}x012389AB4567CDEF, 1)); $for M in range(MR): - __m128i vout${M}x0123456789ABCDEF = _mm_shuffle_epi8(vout${M}x0123456789AB4567CDEF, vshuffle_control_mask); + __m128i vout${M}x0123456789ABCDEF = _mm_shuffle_epi32(vout${M}x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); $for M in range(MR): vout${M}x0123456789ABCDEF = ${_MM_MAX_EPX8}(vout${M}x0123456789ABCDEF, voutput_min); diff --git a/src/qs8-igemm/MRx16c8-avx512skx.c.in b/src/qs8-igemm/MRx16c8-avx512skx.c.in index bbde99c26f9..20e8e9ffdbd 100644 --- a/src/qs8-igemm/MRx16c8-avx512skx.c.in +++ b/src/qs8-igemm/MRx16c8-avx512skx.c.in @@ -85,8 +85,8 @@ void xnn_${DATATYPE_SPEC}_igemm_minmax${REQUANTIZATION_SPEC}_ukernel_${MR}x16c8_ const __m512 voutput_max_less_zero_point = _mm512_load_ps(params->${PARAMS_STRUCT}.output_max_less_zero_point); const __m256i voutput_zero_point = _mm256_load_si256((const __m256i*) params->${PARAMS_STRUCT}.output_zero_point); const __m128i voutput_min = _mm_load_si128((const __m128i*) params->${PARAMS_STRUCT}.output_min); - $if DATATYPE == "QU8": - const __m512i vb_zero_point = _mm512_load_si512(params->${PARAMS_STRUCT}.kernel_zero_point); + $if DATATYPE == "QU8": + const __m512i vb_zero_point = _mm512_load_si512(params->${PARAMS_STRUCT}.kernel_zero_point); do { $if DATATYPE in ["QD8", "QC4"]: const __m512i vksum0123 = _mm512_maskz_expandloadu_epi32(vbias_mask, w); @@ -248,16 +248,16 @@ void xnn_${DATATYPE_SPEC}_igemm_minmax${REQUANTIZATION_SPEC}_ukernel_${MR}x16c8_ vacc${M}x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled${M}x0123456789ABCDEF); $for M in range(MR): - __m256i vacc${M}x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc${M}x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc${M}x0123456789ABCDEF, 1)); + __m256i vacc${M}x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc${M}x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc${M}x0123456789ABCDEF, 1)); $for M in range(MR): - vacc${M}x0123456789AB4567CDEF = _mm256_adds_epi16(vacc${M}x0123456789AB4567CDEF, voutput_zero_point); + vacc${M}x012389AB4567CDEF = _mm256_adds_epi16(vacc${M}x012389AB4567CDEF, voutput_zero_point); $for M in range(MR): - const __m128i vout${M}x0123456789AB4567CDEF = ${_MM_PACKXS_EPI16}(_mm256_castsi256_si128(vacc${M}x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc${M}x0123456789AB4567CDEF, 1)); + const __m128i vout${M}x012389AB4567CDEF = ${_MM_PACKXS_EPI16}(_mm256_castsi256_si128(vacc${M}x012389AB4567CDEF), _mm256_extracti128_si256(vacc${M}x012389AB4567CDEF, 1)); $for M in range(MR): - __m128i vout${M}x0123456789ABCDEF = _mm_shuffle_epi32(vout${M}x0123456789AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout${M}x0123456789ABCDEF = _mm_shuffle_epi32(vout${M}x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); $for M in range(MR): vout${M}x0123456789ABCDEF = ${_MM_MAX_EPX8}(vout${M}x0123456789ABCDEF, voutput_min); diff --git a/src/qs8-igemm/MRx16c8-avx512vnni.c.in b/src/qs8-igemm/MRx16c8-avx512vnni.c.in index 4e40a6bc145..2a64d2dc6c0 100644 --- a/src/qs8-igemm/MRx16c8-avx512vnni.c.in +++ b/src/qs8-igemm/MRx16c8-avx512vnni.c.in @@ -26,7 +26,6 @@ $XINT8_T = "uint8_t" if DATATYPE == "QU8" else "int8_t" $OUT_T = "float" if DATATYPE in ["QD8", "QC4"] else XINT8_T $_MM_PACKXS_EPI16 = "_mm_packus_epi16" if DATATYPE == "QU8" else "_mm_packs_epi16" $_MM_MAX_EPX8 = "_mm_max_epu8" if DATATYPE == "QU8" else "_mm_max_epi8" -$_MM_CVTEPX8_EPI16 = "_mm_cvtepu8_epi16" if DATATYPE == "QU8" else "_mm_cvtepi8_epi16" void xnn_${DATATYPE_SPEC}_igemm_minmax${REQUANTIZATION_SPEC}_ukernel_${MR}x16c8__avx512vnni${"_prfm" if PREFETCH else ""}( size_t mr, size_t nc, @@ -87,9 +86,8 @@ void xnn_${DATATYPE_SPEC}_igemm_minmax${REQUANTIZATION_SPEC}_ukernel_${MR}x16c8_ const __m512 voutput_max_less_zero_point = _mm512_load_ps(params->${PARAMS_STRUCT}.output_max_less_zero_point); const __m256i voutput_zero_point = _mm256_load_si256((const __m256i*) params->${PARAMS_STRUCT}.output_zero_point); const __m128i voutput_min = _mm_load_si128((const __m128i*) params->${PARAMS_STRUCT}.output_min); - const __m128i vshuffle_control_mask = _mm_loadu_si128((const __m128i*) params->${PARAMS_STRUCT}.shuffle_control_mask); - $if DATATYPE == "QU8": - const __m512i vb_zero_point = _mm512_load_si512(params->${PARAMS_STRUCT}.kernel_zero_point); + $if DATATYPE == "QU8": + const __m512i vb_zero_point = _mm512_load_si512(params->${PARAMS_STRUCT}.kernel_zero_point); do { $if DATATYPE in ["QD8", "QC4"]: const __m512i vksum0123456789ABCDEF = _mm512_load_epi32(w); @@ -248,16 +246,16 @@ void xnn_${DATATYPE_SPEC}_igemm_minmax${REQUANTIZATION_SPEC}_ukernel_${MR}x16c8_ vacc${M}x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled${M}x0123456789ABCDEF); $for M in range(MR): - __m256i vacc${M}x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc${M}x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc${M}x0123456789ABCDEF, 1)); + __m256i vacc${M}x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc${M}x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc${M}x0123456789ABCDEF, 1)); $for M in range(MR): - vacc${M}x0123456789AB4567CDEF = _mm256_adds_epi16(vacc${M}x0123456789AB4567CDEF, voutput_zero_point); + vacc${M}x012389AB4567CDEF = _mm256_adds_epi16(vacc${M}x012389AB4567CDEF, voutput_zero_point); $for M in range(MR): - const __m128i vout${M}x0123456789AB4567CDEF = ${_MM_PACKXS_EPI16}(_mm256_castsi256_si128(vacc${M}x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc${M}x0123456789AB4567CDEF, 1)); + const __m128i vout${M}x012389AB4567CDEF = ${_MM_PACKXS_EPI16}(_mm256_castsi256_si128(vacc${M}x012389AB4567CDEF), _mm256_extracti128_si256(vacc${M}x012389AB4567CDEF, 1)); $for M in range(MR): - __m128i vout${M}x0123456789ABCDEF = _mm_shuffle_epi8(vout${M}x0123456789AB4567CDEF, vshuffle_control_mask); + __m128i vout${M}x0123456789ABCDEF = _mm_shuffle_epi32(vout${M}x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); $for M in range(MR): vout${M}x0123456789ABCDEF = ${_MM_MAX_EPX8}(vout${M}x0123456789ABCDEF, voutput_min); diff --git a/src/qs8-igemm/gen/qs8-igemm-1x16c4-minmax-fp32-avx512vnni-prfm.c b/src/qs8-igemm/gen/qs8-igemm-1x16c4-minmax-fp32-avx512vnni-prfm.c index e05f12e26cf..a0b5936bcd8 100644 --- a/src/qs8-igemm/gen/qs8-igemm-1x16c4-minmax-fp32-avx512vnni-prfm.c +++ b/src/qs8-igemm/gen/qs8-igemm-1x16c4-minmax-fp32-avx512vnni-prfm.c @@ -49,7 +49,6 @@ void xnn_qs8_igemm_minmax_fp32_ukernel_1x16c4__avx512vnni_prfm( const __m512 voutput_max_less_zero_point = _mm512_load_ps(params->fp32_avx512vnni.output_max_less_zero_point); const __m256i voutput_zero_point = _mm256_load_si256((const __m256i*) params->fp32_avx512vnni.output_zero_point); const __m128i voutput_min = _mm_load_si128((const __m128i*) params->fp32_avx512vnni.output_min); - const __m128i vshuffle_control_mask = _mm_loadu_si128((const __m128i*) params->fp32_avx512vnni.shuffle_control_mask); do { __m512i vacc0x0123456789ABCDEF = _mm512_load_epi32(w); w = (const int32_t*) w + 16; @@ -101,13 +100,13 @@ void xnn_qs8_igemm_minmax_fp32_ukernel_1x16c4__avx512vnni_prfm( vacc0x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled0x0123456789ABCDEF); - __m256i vacc0x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); + __m256i vacc0x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); - vacc0x0123456789AB4567CDEF = _mm256_adds_epi16(vacc0x0123456789AB4567CDEF, voutput_zero_point); + vacc0x012389AB4567CDEF = _mm256_adds_epi16(vacc0x012389AB4567CDEF, voutput_zero_point); - const __m128i vout0x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc0x0123456789AB4567CDEF, 1)); + const __m128i vout0x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x012389AB4567CDEF, 1)); - __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi8(vout0x0123456789AB4567CDEF, vshuffle_control_mask); + __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi32(vout0x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); vout0x0123456789ABCDEF = _mm_max_epi8(vout0x0123456789ABCDEF, voutput_min); diff --git a/src/qs8-igemm/gen/qs8-igemm-1x16c4-minmax-fp32-avx512vnni.c b/src/qs8-igemm/gen/qs8-igemm-1x16c4-minmax-fp32-avx512vnni.c index 57707c2fe3a..ecb705bb5ac 100644 --- a/src/qs8-igemm/gen/qs8-igemm-1x16c4-minmax-fp32-avx512vnni.c +++ b/src/qs8-igemm/gen/qs8-igemm-1x16c4-minmax-fp32-avx512vnni.c @@ -48,7 +48,6 @@ void xnn_qs8_igemm_minmax_fp32_ukernel_1x16c4__avx512vnni( const __m512 voutput_max_less_zero_point = _mm512_load_ps(params->fp32_avx512vnni.output_max_less_zero_point); const __m256i voutput_zero_point = _mm256_load_si256((const __m256i*) params->fp32_avx512vnni.output_zero_point); const __m128i voutput_min = _mm_load_si128((const __m128i*) params->fp32_avx512vnni.output_min); - const __m128i vshuffle_control_mask = _mm_loadu_si128((const __m128i*) params->fp32_avx512vnni.shuffle_control_mask); do { __m512i vacc0x0123456789ABCDEF = _mm512_load_epi32(w); w = (const int32_t*) w + 16; @@ -97,13 +96,13 @@ void xnn_qs8_igemm_minmax_fp32_ukernel_1x16c4__avx512vnni( vacc0x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled0x0123456789ABCDEF); - __m256i vacc0x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); + __m256i vacc0x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); - vacc0x0123456789AB4567CDEF = _mm256_adds_epi16(vacc0x0123456789AB4567CDEF, voutput_zero_point); + vacc0x012389AB4567CDEF = _mm256_adds_epi16(vacc0x012389AB4567CDEF, voutput_zero_point); - const __m128i vout0x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc0x0123456789AB4567CDEF, 1)); + const __m128i vout0x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x012389AB4567CDEF, 1)); - __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi8(vout0x0123456789AB4567CDEF, vshuffle_control_mask); + __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi32(vout0x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); vout0x0123456789ABCDEF = _mm_max_epi8(vout0x0123456789ABCDEF, voutput_min); diff --git a/src/qs8-igemm/gen/qs8-igemm-1x16c8-minmax-fp32-avx512skx-prfm.c b/src/qs8-igemm/gen/qs8-igemm-1x16c8-minmax-fp32-avx512skx-prfm.c index 5b59e175c62..db9d6e97d75 100644 --- a/src/qs8-igemm/gen/qs8-igemm-1x16c8-minmax-fp32-avx512skx-prfm.c +++ b/src/qs8-igemm/gen/qs8-igemm-1x16c8-minmax-fp32-avx512skx-prfm.c @@ -105,13 +105,13 @@ void xnn_qs8_igemm_minmax_fp32_ukernel_1x16c8__avx512skx_prfm( vacc0x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled0x0123456789ABCDEF); - __m256i vacc0x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); + __m256i vacc0x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); - vacc0x0123456789AB4567CDEF = _mm256_adds_epi16(vacc0x0123456789AB4567CDEF, voutput_zero_point); + vacc0x012389AB4567CDEF = _mm256_adds_epi16(vacc0x012389AB4567CDEF, voutput_zero_point); - const __m128i vout0x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc0x0123456789AB4567CDEF, 1)); + const __m128i vout0x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x012389AB4567CDEF, 1)); - __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi32(vout0x0123456789AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi32(vout0x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); vout0x0123456789ABCDEF = _mm_max_epi8(vout0x0123456789ABCDEF, voutput_min); diff --git a/src/qs8-igemm/gen/qs8-igemm-1x16c8-minmax-fp32-avx512skx.c b/src/qs8-igemm/gen/qs8-igemm-1x16c8-minmax-fp32-avx512skx.c index fb53d9030e2..82ffde0f333 100644 --- a/src/qs8-igemm/gen/qs8-igemm-1x16c8-minmax-fp32-avx512skx.c +++ b/src/qs8-igemm/gen/qs8-igemm-1x16c8-minmax-fp32-avx512skx.c @@ -102,13 +102,13 @@ void xnn_qs8_igemm_minmax_fp32_ukernel_1x16c8__avx512skx( vacc0x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled0x0123456789ABCDEF); - __m256i vacc0x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); + __m256i vacc0x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); - vacc0x0123456789AB4567CDEF = _mm256_adds_epi16(vacc0x0123456789AB4567CDEF, voutput_zero_point); + vacc0x012389AB4567CDEF = _mm256_adds_epi16(vacc0x012389AB4567CDEF, voutput_zero_point); - const __m128i vout0x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc0x0123456789AB4567CDEF, 1)); + const __m128i vout0x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x012389AB4567CDEF, 1)); - __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi32(vout0x0123456789AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi32(vout0x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); vout0x0123456789ABCDEF = _mm_max_epi8(vout0x0123456789ABCDEF, voutput_min); diff --git a/src/qs8-igemm/gen/qs8-igemm-1x16c8-minmax-fp32-avx512vnni-prfm.c b/src/qs8-igemm/gen/qs8-igemm-1x16c8-minmax-fp32-avx512vnni-prfm.c index 939f4fc92e4..c2b544f9d5f 100644 --- a/src/qs8-igemm/gen/qs8-igemm-1x16c8-minmax-fp32-avx512vnni-prfm.c +++ b/src/qs8-igemm/gen/qs8-igemm-1x16c8-minmax-fp32-avx512vnni-prfm.c @@ -49,7 +49,6 @@ void xnn_qs8_igemm_minmax_fp32_ukernel_1x16c8__avx512vnni_prfm( const __m512 voutput_max_less_zero_point = _mm512_load_ps(params->fp32_avx512vnni.output_max_less_zero_point); const __m256i voutput_zero_point = _mm256_load_si256((const __m256i*) params->fp32_avx512vnni.output_zero_point); const __m128i voutput_min = _mm_load_si128((const __m128i*) params->fp32_avx512vnni.output_min); - const __m128i vshuffle_control_mask = _mm_loadu_si128((const __m128i*) params->fp32_avx512vnni.shuffle_control_mask); do { __m512i vacc0x01234567 = _mm512_cvtepu32_epi64(_mm256_load_si256((const __m256i*) w)); __m512i vacc0x89ABCDEF = _mm512_cvtepu32_epi64(_mm256_load_si256((const __m256i*) ((const int32_t*) w + 8))); @@ -120,13 +119,13 @@ void xnn_qs8_igemm_minmax_fp32_ukernel_1x16c8__avx512vnni_prfm( vacc0x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled0x0123456789ABCDEF); - __m256i vacc0x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); + __m256i vacc0x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); - vacc0x0123456789AB4567CDEF = _mm256_adds_epi16(vacc0x0123456789AB4567CDEF, voutput_zero_point); + vacc0x012389AB4567CDEF = _mm256_adds_epi16(vacc0x012389AB4567CDEF, voutput_zero_point); - const __m128i vout0x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc0x0123456789AB4567CDEF, 1)); + const __m128i vout0x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x012389AB4567CDEF, 1)); - __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi8(vout0x0123456789AB4567CDEF, vshuffle_control_mask); + __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi32(vout0x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); vout0x0123456789ABCDEF = _mm_max_epi8(vout0x0123456789ABCDEF, voutput_min); diff --git a/src/qs8-igemm/gen/qs8-igemm-1x16c8-minmax-fp32-avx512vnni.c b/src/qs8-igemm/gen/qs8-igemm-1x16c8-minmax-fp32-avx512vnni.c index a7d998501a3..f3504e308cf 100644 --- a/src/qs8-igemm/gen/qs8-igemm-1x16c8-minmax-fp32-avx512vnni.c +++ b/src/qs8-igemm/gen/qs8-igemm-1x16c8-minmax-fp32-avx512vnni.c @@ -48,7 +48,6 @@ void xnn_qs8_igemm_minmax_fp32_ukernel_1x16c8__avx512vnni( const __m512 voutput_max_less_zero_point = _mm512_load_ps(params->fp32_avx512vnni.output_max_less_zero_point); const __m256i voutput_zero_point = _mm256_load_si256((const __m256i*) params->fp32_avx512vnni.output_zero_point); const __m128i voutput_min = _mm_load_si128((const __m128i*) params->fp32_avx512vnni.output_min); - const __m128i vshuffle_control_mask = _mm_loadu_si128((const __m128i*) params->fp32_avx512vnni.shuffle_control_mask); do { __m512i vacc0x01234567 = _mm512_cvtepu32_epi64(_mm256_load_si256((const __m256i*) w)); __m512i vacc0x89ABCDEF = _mm512_cvtepu32_epi64(_mm256_load_si256((const __m256i*) ((const int32_t*) w + 8))); @@ -113,13 +112,13 @@ void xnn_qs8_igemm_minmax_fp32_ukernel_1x16c8__avx512vnni( vacc0x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled0x0123456789ABCDEF); - __m256i vacc0x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); + __m256i vacc0x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); - vacc0x0123456789AB4567CDEF = _mm256_adds_epi16(vacc0x0123456789AB4567CDEF, voutput_zero_point); + vacc0x012389AB4567CDEF = _mm256_adds_epi16(vacc0x012389AB4567CDEF, voutput_zero_point); - const __m128i vout0x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc0x0123456789AB4567CDEF, 1)); + const __m128i vout0x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x012389AB4567CDEF, 1)); - __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi8(vout0x0123456789AB4567CDEF, vshuffle_control_mask); + __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi32(vout0x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); vout0x0123456789ABCDEF = _mm_max_epi8(vout0x0123456789ABCDEF, voutput_min); diff --git a/src/qs8-igemm/gen/qs8-igemm-2x16c4-minmax-fp32-avx512vnni-prfm.c b/src/qs8-igemm/gen/qs8-igemm-2x16c4-minmax-fp32-avx512vnni-prfm.c index 369f2e41ee1..d398cb01dff 100644 --- a/src/qs8-igemm/gen/qs8-igemm-2x16c4-minmax-fp32-avx512vnni-prfm.c +++ b/src/qs8-igemm/gen/qs8-igemm-2x16c4-minmax-fp32-avx512vnni-prfm.c @@ -53,7 +53,6 @@ void xnn_qs8_igemm_minmax_fp32_ukernel_2x16c4__avx512vnni_prfm( const __m512 voutput_max_less_zero_point = _mm512_load_ps(params->fp32_avx512vnni.output_max_less_zero_point); const __m256i voutput_zero_point = _mm256_load_si256((const __m256i*) params->fp32_avx512vnni.output_zero_point); const __m128i voutput_min = _mm_load_si128((const __m128i*) params->fp32_avx512vnni.output_min); - const __m128i vshuffle_control_mask = _mm_loadu_si128((const __m128i*) params->fp32_avx512vnni.shuffle_control_mask); do { __m512i vacc0x0123456789ABCDEF = _mm512_load_epi32(w); __m512i vacc1x0123456789ABCDEF = vacc0x0123456789ABCDEF; @@ -122,17 +121,17 @@ void xnn_qs8_igemm_minmax_fp32_ukernel_2x16c4__avx512vnni_prfm( vacc0x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled0x0123456789ABCDEF); vacc1x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled1x0123456789ABCDEF); - __m256i vacc0x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); - __m256i vacc1x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); + __m256i vacc0x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); + __m256i vacc1x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); - vacc0x0123456789AB4567CDEF = _mm256_adds_epi16(vacc0x0123456789AB4567CDEF, voutput_zero_point); - vacc1x0123456789AB4567CDEF = _mm256_adds_epi16(vacc1x0123456789AB4567CDEF, voutput_zero_point); + vacc0x012389AB4567CDEF = _mm256_adds_epi16(vacc0x012389AB4567CDEF, voutput_zero_point); + vacc1x012389AB4567CDEF = _mm256_adds_epi16(vacc1x012389AB4567CDEF, voutput_zero_point); - const __m128i vout0x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc0x0123456789AB4567CDEF, 1)); - const __m128i vout1x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc1x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc1x0123456789AB4567CDEF, 1)); + const __m128i vout0x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x012389AB4567CDEF, 1)); + const __m128i vout1x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc1x012389AB4567CDEF), _mm256_extracti128_si256(vacc1x012389AB4567CDEF, 1)); - __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi8(vout0x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi8(vout1x0123456789AB4567CDEF, vshuffle_control_mask); + __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi32(vout0x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi32(vout1x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); vout0x0123456789ABCDEF = _mm_max_epi8(vout0x0123456789ABCDEF, voutput_min); vout1x0123456789ABCDEF = _mm_max_epi8(vout1x0123456789ABCDEF, voutput_min); diff --git a/src/qs8-igemm/gen/qs8-igemm-2x16c4-minmax-fp32-avx512vnni.c b/src/qs8-igemm/gen/qs8-igemm-2x16c4-minmax-fp32-avx512vnni.c index 9ab2a142901..5ea90366f8e 100644 --- a/src/qs8-igemm/gen/qs8-igemm-2x16c4-minmax-fp32-avx512vnni.c +++ b/src/qs8-igemm/gen/qs8-igemm-2x16c4-minmax-fp32-avx512vnni.c @@ -52,7 +52,6 @@ void xnn_qs8_igemm_minmax_fp32_ukernel_2x16c4__avx512vnni( const __m512 voutput_max_less_zero_point = _mm512_load_ps(params->fp32_avx512vnni.output_max_less_zero_point); const __m256i voutput_zero_point = _mm256_load_si256((const __m256i*) params->fp32_avx512vnni.output_zero_point); const __m128i voutput_min = _mm_load_si128((const __m128i*) params->fp32_avx512vnni.output_min); - const __m128i vshuffle_control_mask = _mm_loadu_si128((const __m128i*) params->fp32_avx512vnni.shuffle_control_mask); do { __m512i vacc0x0123456789ABCDEF = _mm512_load_epi32(w); __m512i vacc1x0123456789ABCDEF = vacc0x0123456789ABCDEF; @@ -118,17 +117,17 @@ void xnn_qs8_igemm_minmax_fp32_ukernel_2x16c4__avx512vnni( vacc0x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled0x0123456789ABCDEF); vacc1x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled1x0123456789ABCDEF); - __m256i vacc0x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); - __m256i vacc1x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); + __m256i vacc0x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); + __m256i vacc1x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); - vacc0x0123456789AB4567CDEF = _mm256_adds_epi16(vacc0x0123456789AB4567CDEF, voutput_zero_point); - vacc1x0123456789AB4567CDEF = _mm256_adds_epi16(vacc1x0123456789AB4567CDEF, voutput_zero_point); + vacc0x012389AB4567CDEF = _mm256_adds_epi16(vacc0x012389AB4567CDEF, voutput_zero_point); + vacc1x012389AB4567CDEF = _mm256_adds_epi16(vacc1x012389AB4567CDEF, voutput_zero_point); - const __m128i vout0x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc0x0123456789AB4567CDEF, 1)); - const __m128i vout1x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc1x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc1x0123456789AB4567CDEF, 1)); + const __m128i vout0x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x012389AB4567CDEF, 1)); + const __m128i vout1x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc1x012389AB4567CDEF), _mm256_extracti128_si256(vacc1x012389AB4567CDEF, 1)); - __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi8(vout0x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi8(vout1x0123456789AB4567CDEF, vshuffle_control_mask); + __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi32(vout0x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi32(vout1x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); vout0x0123456789ABCDEF = _mm_max_epi8(vout0x0123456789ABCDEF, voutput_min); vout1x0123456789ABCDEF = _mm_max_epi8(vout1x0123456789ABCDEF, voutput_min); diff --git a/src/qs8-igemm/gen/qs8-igemm-2x16c8-minmax-fp32-avx512skx-prfm.c b/src/qs8-igemm/gen/qs8-igemm-2x16c8-minmax-fp32-avx512skx-prfm.c index 4d9cebd0461..e01b8e4ef45 100644 --- a/src/qs8-igemm/gen/qs8-igemm-2x16c8-minmax-fp32-avx512skx-prfm.c +++ b/src/qs8-igemm/gen/qs8-igemm-2x16c8-minmax-fp32-avx512skx-prfm.c @@ -131,17 +131,17 @@ void xnn_qs8_igemm_minmax_fp32_ukernel_2x16c8__avx512skx_prfm( vacc0x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled0x0123456789ABCDEF); vacc1x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled1x0123456789ABCDEF); - __m256i vacc0x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); - __m256i vacc1x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); + __m256i vacc0x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); + __m256i vacc1x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); - vacc0x0123456789AB4567CDEF = _mm256_adds_epi16(vacc0x0123456789AB4567CDEF, voutput_zero_point); - vacc1x0123456789AB4567CDEF = _mm256_adds_epi16(vacc1x0123456789AB4567CDEF, voutput_zero_point); + vacc0x012389AB4567CDEF = _mm256_adds_epi16(vacc0x012389AB4567CDEF, voutput_zero_point); + vacc1x012389AB4567CDEF = _mm256_adds_epi16(vacc1x012389AB4567CDEF, voutput_zero_point); - const __m128i vout0x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc0x0123456789AB4567CDEF, 1)); - const __m128i vout1x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc1x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc1x0123456789AB4567CDEF, 1)); + const __m128i vout0x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x012389AB4567CDEF, 1)); + const __m128i vout1x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc1x012389AB4567CDEF), _mm256_extracti128_si256(vacc1x012389AB4567CDEF, 1)); - __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi32(vout0x0123456789AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); - __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi32(vout1x0123456789AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi32(vout0x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi32(vout1x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); vout0x0123456789ABCDEF = _mm_max_epi8(vout0x0123456789ABCDEF, voutput_min); vout1x0123456789ABCDEF = _mm_max_epi8(vout1x0123456789ABCDEF, voutput_min); diff --git a/src/qs8-igemm/gen/qs8-igemm-2x16c8-minmax-fp32-avx512skx.c b/src/qs8-igemm/gen/qs8-igemm-2x16c8-minmax-fp32-avx512skx.c index 3f08eeb845e..cab4d28db26 100644 --- a/src/qs8-igemm/gen/qs8-igemm-2x16c8-minmax-fp32-avx512skx.c +++ b/src/qs8-igemm/gen/qs8-igemm-2x16c8-minmax-fp32-avx512skx.c @@ -128,17 +128,17 @@ void xnn_qs8_igemm_minmax_fp32_ukernel_2x16c8__avx512skx( vacc0x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled0x0123456789ABCDEF); vacc1x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled1x0123456789ABCDEF); - __m256i vacc0x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); - __m256i vacc1x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); + __m256i vacc0x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); + __m256i vacc1x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); - vacc0x0123456789AB4567CDEF = _mm256_adds_epi16(vacc0x0123456789AB4567CDEF, voutput_zero_point); - vacc1x0123456789AB4567CDEF = _mm256_adds_epi16(vacc1x0123456789AB4567CDEF, voutput_zero_point); + vacc0x012389AB4567CDEF = _mm256_adds_epi16(vacc0x012389AB4567CDEF, voutput_zero_point); + vacc1x012389AB4567CDEF = _mm256_adds_epi16(vacc1x012389AB4567CDEF, voutput_zero_point); - const __m128i vout0x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc0x0123456789AB4567CDEF, 1)); - const __m128i vout1x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc1x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc1x0123456789AB4567CDEF, 1)); + const __m128i vout0x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x012389AB4567CDEF, 1)); + const __m128i vout1x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc1x012389AB4567CDEF), _mm256_extracti128_si256(vacc1x012389AB4567CDEF, 1)); - __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi32(vout0x0123456789AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); - __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi32(vout1x0123456789AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi32(vout0x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi32(vout1x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); vout0x0123456789ABCDEF = _mm_max_epi8(vout0x0123456789ABCDEF, voutput_min); vout1x0123456789ABCDEF = _mm_max_epi8(vout1x0123456789ABCDEF, voutput_min); diff --git a/src/qs8-igemm/gen/qs8-igemm-2x16c8-minmax-fp32-avx512vnni-prfm.c b/src/qs8-igemm/gen/qs8-igemm-2x16c8-minmax-fp32-avx512vnni-prfm.c index 76e7ed7372d..c17ee0bfec3 100644 --- a/src/qs8-igemm/gen/qs8-igemm-2x16c8-minmax-fp32-avx512vnni-prfm.c +++ b/src/qs8-igemm/gen/qs8-igemm-2x16c8-minmax-fp32-avx512vnni-prfm.c @@ -53,7 +53,6 @@ void xnn_qs8_igemm_minmax_fp32_ukernel_2x16c8__avx512vnni_prfm( const __m512 voutput_max_less_zero_point = _mm512_load_ps(params->fp32_avx512vnni.output_max_less_zero_point); const __m256i voutput_zero_point = _mm256_load_si256((const __m256i*) params->fp32_avx512vnni.output_zero_point); const __m128i voutput_min = _mm_load_si128((const __m128i*) params->fp32_avx512vnni.output_min); - const __m128i vshuffle_control_mask = _mm_loadu_si128((const __m128i*) params->fp32_avx512vnni.shuffle_control_mask); do { __m512i vacc0x01234567 = _mm512_cvtepu32_epi64(_mm256_load_si256((const __m256i*) w)); __m512i vacc0x89ABCDEF = _mm512_cvtepu32_epi64(_mm256_load_si256((const __m256i*) ((const int32_t*) w + 8))); @@ -148,17 +147,17 @@ void xnn_qs8_igemm_minmax_fp32_ukernel_2x16c8__avx512vnni_prfm( vacc0x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled0x0123456789ABCDEF); vacc1x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled1x0123456789ABCDEF); - __m256i vacc0x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); - __m256i vacc1x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); + __m256i vacc0x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); + __m256i vacc1x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); - vacc0x0123456789AB4567CDEF = _mm256_adds_epi16(vacc0x0123456789AB4567CDEF, voutput_zero_point); - vacc1x0123456789AB4567CDEF = _mm256_adds_epi16(vacc1x0123456789AB4567CDEF, voutput_zero_point); + vacc0x012389AB4567CDEF = _mm256_adds_epi16(vacc0x012389AB4567CDEF, voutput_zero_point); + vacc1x012389AB4567CDEF = _mm256_adds_epi16(vacc1x012389AB4567CDEF, voutput_zero_point); - const __m128i vout0x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc0x0123456789AB4567CDEF, 1)); - const __m128i vout1x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc1x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc1x0123456789AB4567CDEF, 1)); + const __m128i vout0x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x012389AB4567CDEF, 1)); + const __m128i vout1x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc1x012389AB4567CDEF), _mm256_extracti128_si256(vacc1x012389AB4567CDEF, 1)); - __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi8(vout0x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi8(vout1x0123456789AB4567CDEF, vshuffle_control_mask); + __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi32(vout0x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi32(vout1x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); vout0x0123456789ABCDEF = _mm_max_epi8(vout0x0123456789ABCDEF, voutput_min); vout1x0123456789ABCDEF = _mm_max_epi8(vout1x0123456789ABCDEF, voutput_min); diff --git a/src/qs8-igemm/gen/qs8-igemm-2x16c8-minmax-fp32-avx512vnni.c b/src/qs8-igemm/gen/qs8-igemm-2x16c8-minmax-fp32-avx512vnni.c index 8c77349611c..bcee00b0510 100644 --- a/src/qs8-igemm/gen/qs8-igemm-2x16c8-minmax-fp32-avx512vnni.c +++ b/src/qs8-igemm/gen/qs8-igemm-2x16c8-minmax-fp32-avx512vnni.c @@ -52,7 +52,6 @@ void xnn_qs8_igemm_minmax_fp32_ukernel_2x16c8__avx512vnni( const __m512 voutput_max_less_zero_point = _mm512_load_ps(params->fp32_avx512vnni.output_max_less_zero_point); const __m256i voutput_zero_point = _mm256_load_si256((const __m256i*) params->fp32_avx512vnni.output_zero_point); const __m128i voutput_min = _mm_load_si128((const __m128i*) params->fp32_avx512vnni.output_min); - const __m128i vshuffle_control_mask = _mm_loadu_si128((const __m128i*) params->fp32_avx512vnni.shuffle_control_mask); do { __m512i vacc0x01234567 = _mm512_cvtepu32_epi64(_mm256_load_si256((const __m256i*) w)); __m512i vacc0x89ABCDEF = _mm512_cvtepu32_epi64(_mm256_load_si256((const __m256i*) ((const int32_t*) w + 8))); @@ -141,17 +140,17 @@ void xnn_qs8_igemm_minmax_fp32_ukernel_2x16c8__avx512vnni( vacc0x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled0x0123456789ABCDEF); vacc1x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled1x0123456789ABCDEF); - __m256i vacc0x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); - __m256i vacc1x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); + __m256i vacc0x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); + __m256i vacc1x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); - vacc0x0123456789AB4567CDEF = _mm256_adds_epi16(vacc0x0123456789AB4567CDEF, voutput_zero_point); - vacc1x0123456789AB4567CDEF = _mm256_adds_epi16(vacc1x0123456789AB4567CDEF, voutput_zero_point); + vacc0x012389AB4567CDEF = _mm256_adds_epi16(vacc0x012389AB4567CDEF, voutput_zero_point); + vacc1x012389AB4567CDEF = _mm256_adds_epi16(vacc1x012389AB4567CDEF, voutput_zero_point); - const __m128i vout0x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc0x0123456789AB4567CDEF, 1)); - const __m128i vout1x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc1x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc1x0123456789AB4567CDEF, 1)); + const __m128i vout0x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x012389AB4567CDEF, 1)); + const __m128i vout1x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc1x012389AB4567CDEF), _mm256_extracti128_si256(vacc1x012389AB4567CDEF, 1)); - __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi8(vout0x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi8(vout1x0123456789AB4567CDEF, vshuffle_control_mask); + __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi32(vout0x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi32(vout1x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); vout0x0123456789ABCDEF = _mm_max_epi8(vout0x0123456789ABCDEF, voutput_min); vout1x0123456789ABCDEF = _mm_max_epi8(vout1x0123456789ABCDEF, voutput_min); diff --git a/src/qs8-igemm/gen/qs8-igemm-3x16c4-minmax-fp32-avx512vnni-prfm.c b/src/qs8-igemm/gen/qs8-igemm-3x16c4-minmax-fp32-avx512vnni-prfm.c index a0c64a52eb1..6e24b0c1bfb 100644 --- a/src/qs8-igemm/gen/qs8-igemm-3x16c4-minmax-fp32-avx512vnni-prfm.c +++ b/src/qs8-igemm/gen/qs8-igemm-3x16c4-minmax-fp32-avx512vnni-prfm.c @@ -57,7 +57,6 @@ void xnn_qs8_igemm_minmax_fp32_ukernel_3x16c4__avx512vnni_prfm( const __m512 voutput_max_less_zero_point = _mm512_load_ps(params->fp32_avx512vnni.output_max_less_zero_point); const __m256i voutput_zero_point = _mm256_load_si256((const __m256i*) params->fp32_avx512vnni.output_zero_point); const __m128i voutput_min = _mm_load_si128((const __m128i*) params->fp32_avx512vnni.output_min); - const __m128i vshuffle_control_mask = _mm_loadu_si128((const __m128i*) params->fp32_avx512vnni.shuffle_control_mask); do { __m512i vacc0x0123456789ABCDEF = _mm512_load_epi32(w); __m512i vacc1x0123456789ABCDEF = vacc0x0123456789ABCDEF; @@ -143,21 +142,21 @@ void xnn_qs8_igemm_minmax_fp32_ukernel_3x16c4__avx512vnni_prfm( vacc1x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled1x0123456789ABCDEF); vacc2x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled2x0123456789ABCDEF); - __m256i vacc0x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); - __m256i vacc1x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); - __m256i vacc2x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc2x0123456789ABCDEF, 1)); + __m256i vacc0x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); + __m256i vacc1x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); + __m256i vacc2x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc2x0123456789ABCDEF, 1)); - vacc0x0123456789AB4567CDEF = _mm256_adds_epi16(vacc0x0123456789AB4567CDEF, voutput_zero_point); - vacc1x0123456789AB4567CDEF = _mm256_adds_epi16(vacc1x0123456789AB4567CDEF, voutput_zero_point); - vacc2x0123456789AB4567CDEF = _mm256_adds_epi16(vacc2x0123456789AB4567CDEF, voutput_zero_point); + vacc0x012389AB4567CDEF = _mm256_adds_epi16(vacc0x012389AB4567CDEF, voutput_zero_point); + vacc1x012389AB4567CDEF = _mm256_adds_epi16(vacc1x012389AB4567CDEF, voutput_zero_point); + vacc2x012389AB4567CDEF = _mm256_adds_epi16(vacc2x012389AB4567CDEF, voutput_zero_point); - const __m128i vout0x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc0x0123456789AB4567CDEF, 1)); - const __m128i vout1x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc1x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc1x0123456789AB4567CDEF, 1)); - const __m128i vout2x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc2x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc2x0123456789AB4567CDEF, 1)); + const __m128i vout0x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x012389AB4567CDEF, 1)); + const __m128i vout1x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc1x012389AB4567CDEF), _mm256_extracti128_si256(vacc1x012389AB4567CDEF, 1)); + const __m128i vout2x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc2x012389AB4567CDEF), _mm256_extracti128_si256(vacc2x012389AB4567CDEF, 1)); - __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi8(vout0x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi8(vout1x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi8(vout2x0123456789AB4567CDEF, vshuffle_control_mask); + __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi32(vout0x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi32(vout1x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi32(vout2x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); vout0x0123456789ABCDEF = _mm_max_epi8(vout0x0123456789ABCDEF, voutput_min); vout1x0123456789ABCDEF = _mm_max_epi8(vout1x0123456789ABCDEF, voutput_min); diff --git a/src/qs8-igemm/gen/qs8-igemm-3x16c4-minmax-fp32-avx512vnni.c b/src/qs8-igemm/gen/qs8-igemm-3x16c4-minmax-fp32-avx512vnni.c index 644cd627c00..c7ed3d9cdc4 100644 --- a/src/qs8-igemm/gen/qs8-igemm-3x16c4-minmax-fp32-avx512vnni.c +++ b/src/qs8-igemm/gen/qs8-igemm-3x16c4-minmax-fp32-avx512vnni.c @@ -56,7 +56,6 @@ void xnn_qs8_igemm_minmax_fp32_ukernel_3x16c4__avx512vnni( const __m512 voutput_max_less_zero_point = _mm512_load_ps(params->fp32_avx512vnni.output_max_less_zero_point); const __m256i voutput_zero_point = _mm256_load_si256((const __m256i*) params->fp32_avx512vnni.output_zero_point); const __m128i voutput_min = _mm_load_si128((const __m128i*) params->fp32_avx512vnni.output_min); - const __m128i vshuffle_control_mask = _mm_loadu_si128((const __m128i*) params->fp32_avx512vnni.shuffle_control_mask); do { __m512i vacc0x0123456789ABCDEF = _mm512_load_epi32(w); __m512i vacc1x0123456789ABCDEF = vacc0x0123456789ABCDEF; @@ -139,21 +138,21 @@ void xnn_qs8_igemm_minmax_fp32_ukernel_3x16c4__avx512vnni( vacc1x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled1x0123456789ABCDEF); vacc2x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled2x0123456789ABCDEF); - __m256i vacc0x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); - __m256i vacc1x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); - __m256i vacc2x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc2x0123456789ABCDEF, 1)); + __m256i vacc0x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); + __m256i vacc1x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); + __m256i vacc2x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc2x0123456789ABCDEF, 1)); - vacc0x0123456789AB4567CDEF = _mm256_adds_epi16(vacc0x0123456789AB4567CDEF, voutput_zero_point); - vacc1x0123456789AB4567CDEF = _mm256_adds_epi16(vacc1x0123456789AB4567CDEF, voutput_zero_point); - vacc2x0123456789AB4567CDEF = _mm256_adds_epi16(vacc2x0123456789AB4567CDEF, voutput_zero_point); + vacc0x012389AB4567CDEF = _mm256_adds_epi16(vacc0x012389AB4567CDEF, voutput_zero_point); + vacc1x012389AB4567CDEF = _mm256_adds_epi16(vacc1x012389AB4567CDEF, voutput_zero_point); + vacc2x012389AB4567CDEF = _mm256_adds_epi16(vacc2x012389AB4567CDEF, voutput_zero_point); - const __m128i vout0x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc0x0123456789AB4567CDEF, 1)); - const __m128i vout1x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc1x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc1x0123456789AB4567CDEF, 1)); - const __m128i vout2x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc2x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc2x0123456789AB4567CDEF, 1)); + const __m128i vout0x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x012389AB4567CDEF, 1)); + const __m128i vout1x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc1x012389AB4567CDEF), _mm256_extracti128_si256(vacc1x012389AB4567CDEF, 1)); + const __m128i vout2x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc2x012389AB4567CDEF), _mm256_extracti128_si256(vacc2x012389AB4567CDEF, 1)); - __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi8(vout0x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi8(vout1x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi8(vout2x0123456789AB4567CDEF, vshuffle_control_mask); + __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi32(vout0x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi32(vout1x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi32(vout2x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); vout0x0123456789ABCDEF = _mm_max_epi8(vout0x0123456789ABCDEF, voutput_min); vout1x0123456789ABCDEF = _mm_max_epi8(vout1x0123456789ABCDEF, voutput_min); diff --git a/src/qs8-igemm/gen/qs8-igemm-3x16c8-minmax-fp32-avx512skx-prfm.c b/src/qs8-igemm/gen/qs8-igemm-3x16c8-minmax-fp32-avx512skx-prfm.c index a876ef94eee..dd8c6452816 100644 --- a/src/qs8-igemm/gen/qs8-igemm-3x16c8-minmax-fp32-avx512skx-prfm.c +++ b/src/qs8-igemm/gen/qs8-igemm-3x16c8-minmax-fp32-avx512skx-prfm.c @@ -157,21 +157,21 @@ void xnn_qs8_igemm_minmax_fp32_ukernel_3x16c8__avx512skx_prfm( vacc1x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled1x0123456789ABCDEF); vacc2x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled2x0123456789ABCDEF); - __m256i vacc0x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); - __m256i vacc1x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); - __m256i vacc2x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc2x0123456789ABCDEF, 1)); + __m256i vacc0x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); + __m256i vacc1x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); + __m256i vacc2x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc2x0123456789ABCDEF, 1)); - vacc0x0123456789AB4567CDEF = _mm256_adds_epi16(vacc0x0123456789AB4567CDEF, voutput_zero_point); - vacc1x0123456789AB4567CDEF = _mm256_adds_epi16(vacc1x0123456789AB4567CDEF, voutput_zero_point); - vacc2x0123456789AB4567CDEF = _mm256_adds_epi16(vacc2x0123456789AB4567CDEF, voutput_zero_point); + vacc0x012389AB4567CDEF = _mm256_adds_epi16(vacc0x012389AB4567CDEF, voutput_zero_point); + vacc1x012389AB4567CDEF = _mm256_adds_epi16(vacc1x012389AB4567CDEF, voutput_zero_point); + vacc2x012389AB4567CDEF = _mm256_adds_epi16(vacc2x012389AB4567CDEF, voutput_zero_point); - const __m128i vout0x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc0x0123456789AB4567CDEF, 1)); - const __m128i vout1x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc1x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc1x0123456789AB4567CDEF, 1)); - const __m128i vout2x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc2x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc2x0123456789AB4567CDEF, 1)); + const __m128i vout0x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x012389AB4567CDEF, 1)); + const __m128i vout1x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc1x012389AB4567CDEF), _mm256_extracti128_si256(vacc1x012389AB4567CDEF, 1)); + const __m128i vout2x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc2x012389AB4567CDEF), _mm256_extracti128_si256(vacc2x012389AB4567CDEF, 1)); - __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi32(vout0x0123456789AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); - __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi32(vout1x0123456789AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); - __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi32(vout2x0123456789AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi32(vout0x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi32(vout1x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi32(vout2x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); vout0x0123456789ABCDEF = _mm_max_epi8(vout0x0123456789ABCDEF, voutput_min); vout1x0123456789ABCDEF = _mm_max_epi8(vout1x0123456789ABCDEF, voutput_min); diff --git a/src/qs8-igemm/gen/qs8-igemm-3x16c8-minmax-fp32-avx512skx.c b/src/qs8-igemm/gen/qs8-igemm-3x16c8-minmax-fp32-avx512skx.c index 11a9e7906bf..5390309b4a2 100644 --- a/src/qs8-igemm/gen/qs8-igemm-3x16c8-minmax-fp32-avx512skx.c +++ b/src/qs8-igemm/gen/qs8-igemm-3x16c8-minmax-fp32-avx512skx.c @@ -154,21 +154,21 @@ void xnn_qs8_igemm_minmax_fp32_ukernel_3x16c8__avx512skx( vacc1x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled1x0123456789ABCDEF); vacc2x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled2x0123456789ABCDEF); - __m256i vacc0x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); - __m256i vacc1x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); - __m256i vacc2x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc2x0123456789ABCDEF, 1)); + __m256i vacc0x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); + __m256i vacc1x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); + __m256i vacc2x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc2x0123456789ABCDEF, 1)); - vacc0x0123456789AB4567CDEF = _mm256_adds_epi16(vacc0x0123456789AB4567CDEF, voutput_zero_point); - vacc1x0123456789AB4567CDEF = _mm256_adds_epi16(vacc1x0123456789AB4567CDEF, voutput_zero_point); - vacc2x0123456789AB4567CDEF = _mm256_adds_epi16(vacc2x0123456789AB4567CDEF, voutput_zero_point); + vacc0x012389AB4567CDEF = _mm256_adds_epi16(vacc0x012389AB4567CDEF, voutput_zero_point); + vacc1x012389AB4567CDEF = _mm256_adds_epi16(vacc1x012389AB4567CDEF, voutput_zero_point); + vacc2x012389AB4567CDEF = _mm256_adds_epi16(vacc2x012389AB4567CDEF, voutput_zero_point); - const __m128i vout0x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc0x0123456789AB4567CDEF, 1)); - const __m128i vout1x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc1x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc1x0123456789AB4567CDEF, 1)); - const __m128i vout2x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc2x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc2x0123456789AB4567CDEF, 1)); + const __m128i vout0x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x012389AB4567CDEF, 1)); + const __m128i vout1x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc1x012389AB4567CDEF), _mm256_extracti128_si256(vacc1x012389AB4567CDEF, 1)); + const __m128i vout2x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc2x012389AB4567CDEF), _mm256_extracti128_si256(vacc2x012389AB4567CDEF, 1)); - __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi32(vout0x0123456789AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); - __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi32(vout1x0123456789AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); - __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi32(vout2x0123456789AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi32(vout0x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi32(vout1x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi32(vout2x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); vout0x0123456789ABCDEF = _mm_max_epi8(vout0x0123456789ABCDEF, voutput_min); vout1x0123456789ABCDEF = _mm_max_epi8(vout1x0123456789ABCDEF, voutput_min); diff --git a/src/qs8-igemm/gen/qs8-igemm-3x16c8-minmax-fp32-avx512vnni-prfm.c b/src/qs8-igemm/gen/qs8-igemm-3x16c8-minmax-fp32-avx512vnni-prfm.c index 90fea3bd7cb..6d581b67161 100644 --- a/src/qs8-igemm/gen/qs8-igemm-3x16c8-minmax-fp32-avx512vnni-prfm.c +++ b/src/qs8-igemm/gen/qs8-igemm-3x16c8-minmax-fp32-avx512vnni-prfm.c @@ -57,7 +57,6 @@ void xnn_qs8_igemm_minmax_fp32_ukernel_3x16c8__avx512vnni_prfm( const __m512 voutput_max_less_zero_point = _mm512_load_ps(params->fp32_avx512vnni.output_max_less_zero_point); const __m256i voutput_zero_point = _mm256_load_si256((const __m256i*) params->fp32_avx512vnni.output_zero_point); const __m128i voutput_min = _mm_load_si128((const __m128i*) params->fp32_avx512vnni.output_min); - const __m128i vshuffle_control_mask = _mm_loadu_si128((const __m128i*) params->fp32_avx512vnni.shuffle_control_mask); do { __m512i vacc0x01234567 = _mm512_cvtepu32_epi64(_mm256_load_si256((const __m256i*) w)); __m512i vacc0x89ABCDEF = _mm512_cvtepu32_epi64(_mm256_load_si256((const __m256i*) ((const int32_t*) w + 8))); @@ -176,21 +175,21 @@ void xnn_qs8_igemm_minmax_fp32_ukernel_3x16c8__avx512vnni_prfm( vacc1x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled1x0123456789ABCDEF); vacc2x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled2x0123456789ABCDEF); - __m256i vacc0x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); - __m256i vacc1x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); - __m256i vacc2x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc2x0123456789ABCDEF, 1)); + __m256i vacc0x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); + __m256i vacc1x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); + __m256i vacc2x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc2x0123456789ABCDEF, 1)); - vacc0x0123456789AB4567CDEF = _mm256_adds_epi16(vacc0x0123456789AB4567CDEF, voutput_zero_point); - vacc1x0123456789AB4567CDEF = _mm256_adds_epi16(vacc1x0123456789AB4567CDEF, voutput_zero_point); - vacc2x0123456789AB4567CDEF = _mm256_adds_epi16(vacc2x0123456789AB4567CDEF, voutput_zero_point); + vacc0x012389AB4567CDEF = _mm256_adds_epi16(vacc0x012389AB4567CDEF, voutput_zero_point); + vacc1x012389AB4567CDEF = _mm256_adds_epi16(vacc1x012389AB4567CDEF, voutput_zero_point); + vacc2x012389AB4567CDEF = _mm256_adds_epi16(vacc2x012389AB4567CDEF, voutput_zero_point); - const __m128i vout0x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc0x0123456789AB4567CDEF, 1)); - const __m128i vout1x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc1x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc1x0123456789AB4567CDEF, 1)); - const __m128i vout2x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc2x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc2x0123456789AB4567CDEF, 1)); + const __m128i vout0x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x012389AB4567CDEF, 1)); + const __m128i vout1x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc1x012389AB4567CDEF), _mm256_extracti128_si256(vacc1x012389AB4567CDEF, 1)); + const __m128i vout2x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc2x012389AB4567CDEF), _mm256_extracti128_si256(vacc2x012389AB4567CDEF, 1)); - __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi8(vout0x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi8(vout1x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi8(vout2x0123456789AB4567CDEF, vshuffle_control_mask); + __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi32(vout0x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi32(vout1x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi32(vout2x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); vout0x0123456789ABCDEF = _mm_max_epi8(vout0x0123456789ABCDEF, voutput_min); vout1x0123456789ABCDEF = _mm_max_epi8(vout1x0123456789ABCDEF, voutput_min); diff --git a/src/qs8-igemm/gen/qs8-igemm-3x16c8-minmax-fp32-avx512vnni.c b/src/qs8-igemm/gen/qs8-igemm-3x16c8-minmax-fp32-avx512vnni.c index 64b3fd80307..6a4e368f3a1 100644 --- a/src/qs8-igemm/gen/qs8-igemm-3x16c8-minmax-fp32-avx512vnni.c +++ b/src/qs8-igemm/gen/qs8-igemm-3x16c8-minmax-fp32-avx512vnni.c @@ -56,7 +56,6 @@ void xnn_qs8_igemm_minmax_fp32_ukernel_3x16c8__avx512vnni( const __m512 voutput_max_less_zero_point = _mm512_load_ps(params->fp32_avx512vnni.output_max_less_zero_point); const __m256i voutput_zero_point = _mm256_load_si256((const __m256i*) params->fp32_avx512vnni.output_zero_point); const __m128i voutput_min = _mm_load_si128((const __m128i*) params->fp32_avx512vnni.output_min); - const __m128i vshuffle_control_mask = _mm_loadu_si128((const __m128i*) params->fp32_avx512vnni.shuffle_control_mask); do { __m512i vacc0x01234567 = _mm512_cvtepu32_epi64(_mm256_load_si256((const __m256i*) w)); __m512i vacc0x89ABCDEF = _mm512_cvtepu32_epi64(_mm256_load_si256((const __m256i*) ((const int32_t*) w + 8))); @@ -169,21 +168,21 @@ void xnn_qs8_igemm_minmax_fp32_ukernel_3x16c8__avx512vnni( vacc1x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled1x0123456789ABCDEF); vacc2x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled2x0123456789ABCDEF); - __m256i vacc0x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); - __m256i vacc1x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); - __m256i vacc2x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc2x0123456789ABCDEF, 1)); + __m256i vacc0x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); + __m256i vacc1x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); + __m256i vacc2x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc2x0123456789ABCDEF, 1)); - vacc0x0123456789AB4567CDEF = _mm256_adds_epi16(vacc0x0123456789AB4567CDEF, voutput_zero_point); - vacc1x0123456789AB4567CDEF = _mm256_adds_epi16(vacc1x0123456789AB4567CDEF, voutput_zero_point); - vacc2x0123456789AB4567CDEF = _mm256_adds_epi16(vacc2x0123456789AB4567CDEF, voutput_zero_point); + vacc0x012389AB4567CDEF = _mm256_adds_epi16(vacc0x012389AB4567CDEF, voutput_zero_point); + vacc1x012389AB4567CDEF = _mm256_adds_epi16(vacc1x012389AB4567CDEF, voutput_zero_point); + vacc2x012389AB4567CDEF = _mm256_adds_epi16(vacc2x012389AB4567CDEF, voutput_zero_point); - const __m128i vout0x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc0x0123456789AB4567CDEF, 1)); - const __m128i vout1x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc1x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc1x0123456789AB4567CDEF, 1)); - const __m128i vout2x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc2x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc2x0123456789AB4567CDEF, 1)); + const __m128i vout0x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x012389AB4567CDEF, 1)); + const __m128i vout1x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc1x012389AB4567CDEF), _mm256_extracti128_si256(vacc1x012389AB4567CDEF, 1)); + const __m128i vout2x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc2x012389AB4567CDEF), _mm256_extracti128_si256(vacc2x012389AB4567CDEF, 1)); - __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi8(vout0x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi8(vout1x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi8(vout2x0123456789AB4567CDEF, vshuffle_control_mask); + __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi32(vout0x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi32(vout1x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi32(vout2x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); vout0x0123456789ABCDEF = _mm_max_epi8(vout0x0123456789ABCDEF, voutput_min); vout1x0123456789ABCDEF = _mm_max_epi8(vout1x0123456789ABCDEF, voutput_min); diff --git a/src/qs8-igemm/gen/qs8-igemm-4x16c4-minmax-fp32-avx512vnni-prfm.c b/src/qs8-igemm/gen/qs8-igemm-4x16c4-minmax-fp32-avx512vnni-prfm.c index faa7fd03137..467454a7a94 100644 --- a/src/qs8-igemm/gen/qs8-igemm-4x16c4-minmax-fp32-avx512vnni-prfm.c +++ b/src/qs8-igemm/gen/qs8-igemm-4x16c4-minmax-fp32-avx512vnni-prfm.c @@ -61,7 +61,6 @@ void xnn_qs8_igemm_minmax_fp32_ukernel_4x16c4__avx512vnni_prfm( const __m512 voutput_max_less_zero_point = _mm512_load_ps(params->fp32_avx512vnni.output_max_less_zero_point); const __m256i voutput_zero_point = _mm256_load_si256((const __m256i*) params->fp32_avx512vnni.output_zero_point); const __m128i voutput_min = _mm_load_si128((const __m128i*) params->fp32_avx512vnni.output_min); - const __m128i vshuffle_control_mask = _mm_loadu_si128((const __m128i*) params->fp32_avx512vnni.shuffle_control_mask); do { __m512i vacc0x0123456789ABCDEF = _mm512_load_epi32(w); __m512i vacc1x0123456789ABCDEF = vacc0x0123456789ABCDEF; @@ -164,25 +163,25 @@ void xnn_qs8_igemm_minmax_fp32_ukernel_4x16c4__avx512vnni_prfm( vacc2x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled2x0123456789ABCDEF); vacc3x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled3x0123456789ABCDEF); - __m256i vacc0x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); - __m256i vacc1x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); - __m256i vacc2x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc2x0123456789ABCDEF, 1)); - __m256i vacc3x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc3x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc3x0123456789ABCDEF, 1)); - - vacc0x0123456789AB4567CDEF = _mm256_adds_epi16(vacc0x0123456789AB4567CDEF, voutput_zero_point); - vacc1x0123456789AB4567CDEF = _mm256_adds_epi16(vacc1x0123456789AB4567CDEF, voutput_zero_point); - vacc2x0123456789AB4567CDEF = _mm256_adds_epi16(vacc2x0123456789AB4567CDEF, voutput_zero_point); - vacc3x0123456789AB4567CDEF = _mm256_adds_epi16(vacc3x0123456789AB4567CDEF, voutput_zero_point); - - const __m128i vout0x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc0x0123456789AB4567CDEF, 1)); - const __m128i vout1x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc1x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc1x0123456789AB4567CDEF, 1)); - const __m128i vout2x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc2x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc2x0123456789AB4567CDEF, 1)); - const __m128i vout3x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc3x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc3x0123456789AB4567CDEF, 1)); - - __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi8(vout0x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi8(vout1x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi8(vout2x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout3x0123456789ABCDEF = _mm_shuffle_epi8(vout3x0123456789AB4567CDEF, vshuffle_control_mask); + __m256i vacc0x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); + __m256i vacc1x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); + __m256i vacc2x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc2x0123456789ABCDEF, 1)); + __m256i vacc3x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc3x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc3x0123456789ABCDEF, 1)); + + vacc0x012389AB4567CDEF = _mm256_adds_epi16(vacc0x012389AB4567CDEF, voutput_zero_point); + vacc1x012389AB4567CDEF = _mm256_adds_epi16(vacc1x012389AB4567CDEF, voutput_zero_point); + vacc2x012389AB4567CDEF = _mm256_adds_epi16(vacc2x012389AB4567CDEF, voutput_zero_point); + vacc3x012389AB4567CDEF = _mm256_adds_epi16(vacc3x012389AB4567CDEF, voutput_zero_point); + + const __m128i vout0x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x012389AB4567CDEF, 1)); + const __m128i vout1x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc1x012389AB4567CDEF), _mm256_extracti128_si256(vacc1x012389AB4567CDEF, 1)); + const __m128i vout2x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc2x012389AB4567CDEF), _mm256_extracti128_si256(vacc2x012389AB4567CDEF, 1)); + const __m128i vout3x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc3x012389AB4567CDEF), _mm256_extracti128_si256(vacc3x012389AB4567CDEF, 1)); + + __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi32(vout0x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi32(vout1x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi32(vout2x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout3x0123456789ABCDEF = _mm_shuffle_epi32(vout3x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); vout0x0123456789ABCDEF = _mm_max_epi8(vout0x0123456789ABCDEF, voutput_min); vout1x0123456789ABCDEF = _mm_max_epi8(vout1x0123456789ABCDEF, voutput_min); diff --git a/src/qs8-igemm/gen/qs8-igemm-4x16c4-minmax-fp32-avx512vnni.c b/src/qs8-igemm/gen/qs8-igemm-4x16c4-minmax-fp32-avx512vnni.c index 7932cdcb776..32d4e4191e0 100644 --- a/src/qs8-igemm/gen/qs8-igemm-4x16c4-minmax-fp32-avx512vnni.c +++ b/src/qs8-igemm/gen/qs8-igemm-4x16c4-minmax-fp32-avx512vnni.c @@ -60,7 +60,6 @@ void xnn_qs8_igemm_minmax_fp32_ukernel_4x16c4__avx512vnni( const __m512 voutput_max_less_zero_point = _mm512_load_ps(params->fp32_avx512vnni.output_max_less_zero_point); const __m256i voutput_zero_point = _mm256_load_si256((const __m256i*) params->fp32_avx512vnni.output_zero_point); const __m128i voutput_min = _mm_load_si128((const __m128i*) params->fp32_avx512vnni.output_min); - const __m128i vshuffle_control_mask = _mm_loadu_si128((const __m128i*) params->fp32_avx512vnni.shuffle_control_mask); do { __m512i vacc0x0123456789ABCDEF = _mm512_load_epi32(w); __m512i vacc1x0123456789ABCDEF = vacc0x0123456789ABCDEF; @@ -160,25 +159,25 @@ void xnn_qs8_igemm_minmax_fp32_ukernel_4x16c4__avx512vnni( vacc2x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled2x0123456789ABCDEF); vacc3x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled3x0123456789ABCDEF); - __m256i vacc0x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); - __m256i vacc1x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); - __m256i vacc2x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc2x0123456789ABCDEF, 1)); - __m256i vacc3x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc3x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc3x0123456789ABCDEF, 1)); - - vacc0x0123456789AB4567CDEF = _mm256_adds_epi16(vacc0x0123456789AB4567CDEF, voutput_zero_point); - vacc1x0123456789AB4567CDEF = _mm256_adds_epi16(vacc1x0123456789AB4567CDEF, voutput_zero_point); - vacc2x0123456789AB4567CDEF = _mm256_adds_epi16(vacc2x0123456789AB4567CDEF, voutput_zero_point); - vacc3x0123456789AB4567CDEF = _mm256_adds_epi16(vacc3x0123456789AB4567CDEF, voutput_zero_point); - - const __m128i vout0x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc0x0123456789AB4567CDEF, 1)); - const __m128i vout1x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc1x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc1x0123456789AB4567CDEF, 1)); - const __m128i vout2x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc2x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc2x0123456789AB4567CDEF, 1)); - const __m128i vout3x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc3x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc3x0123456789AB4567CDEF, 1)); - - __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi8(vout0x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi8(vout1x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi8(vout2x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout3x0123456789ABCDEF = _mm_shuffle_epi8(vout3x0123456789AB4567CDEF, vshuffle_control_mask); + __m256i vacc0x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); + __m256i vacc1x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); + __m256i vacc2x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc2x0123456789ABCDEF, 1)); + __m256i vacc3x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc3x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc3x0123456789ABCDEF, 1)); + + vacc0x012389AB4567CDEF = _mm256_adds_epi16(vacc0x012389AB4567CDEF, voutput_zero_point); + vacc1x012389AB4567CDEF = _mm256_adds_epi16(vacc1x012389AB4567CDEF, voutput_zero_point); + vacc2x012389AB4567CDEF = _mm256_adds_epi16(vacc2x012389AB4567CDEF, voutput_zero_point); + vacc3x012389AB4567CDEF = _mm256_adds_epi16(vacc3x012389AB4567CDEF, voutput_zero_point); + + const __m128i vout0x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x012389AB4567CDEF, 1)); + const __m128i vout1x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc1x012389AB4567CDEF), _mm256_extracti128_si256(vacc1x012389AB4567CDEF, 1)); + const __m128i vout2x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc2x012389AB4567CDEF), _mm256_extracti128_si256(vacc2x012389AB4567CDEF, 1)); + const __m128i vout3x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc3x012389AB4567CDEF), _mm256_extracti128_si256(vacc3x012389AB4567CDEF, 1)); + + __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi32(vout0x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi32(vout1x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi32(vout2x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout3x0123456789ABCDEF = _mm_shuffle_epi32(vout3x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); vout0x0123456789ABCDEF = _mm_max_epi8(vout0x0123456789ABCDEF, voutput_min); vout1x0123456789ABCDEF = _mm_max_epi8(vout1x0123456789ABCDEF, voutput_min); diff --git a/src/qs8-igemm/gen/qs8-igemm-4x16c8-minmax-fp32-avx512skx-prfm.c b/src/qs8-igemm/gen/qs8-igemm-4x16c8-minmax-fp32-avx512skx-prfm.c index 206cecb9c4e..d19295b62b8 100644 --- a/src/qs8-igemm/gen/qs8-igemm-4x16c8-minmax-fp32-avx512skx-prfm.c +++ b/src/qs8-igemm/gen/qs8-igemm-4x16c8-minmax-fp32-avx512skx-prfm.c @@ -183,25 +183,25 @@ void xnn_qs8_igemm_minmax_fp32_ukernel_4x16c8__avx512skx_prfm( vacc2x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled2x0123456789ABCDEF); vacc3x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled3x0123456789ABCDEF); - __m256i vacc0x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); - __m256i vacc1x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); - __m256i vacc2x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc2x0123456789ABCDEF, 1)); - __m256i vacc3x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc3x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc3x0123456789ABCDEF, 1)); - - vacc0x0123456789AB4567CDEF = _mm256_adds_epi16(vacc0x0123456789AB4567CDEF, voutput_zero_point); - vacc1x0123456789AB4567CDEF = _mm256_adds_epi16(vacc1x0123456789AB4567CDEF, voutput_zero_point); - vacc2x0123456789AB4567CDEF = _mm256_adds_epi16(vacc2x0123456789AB4567CDEF, voutput_zero_point); - vacc3x0123456789AB4567CDEF = _mm256_adds_epi16(vacc3x0123456789AB4567CDEF, voutput_zero_point); - - const __m128i vout0x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc0x0123456789AB4567CDEF, 1)); - const __m128i vout1x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc1x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc1x0123456789AB4567CDEF, 1)); - const __m128i vout2x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc2x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc2x0123456789AB4567CDEF, 1)); - const __m128i vout3x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc3x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc3x0123456789AB4567CDEF, 1)); - - __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi32(vout0x0123456789AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); - __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi32(vout1x0123456789AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); - __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi32(vout2x0123456789AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); - __m128i vout3x0123456789ABCDEF = _mm_shuffle_epi32(vout3x0123456789AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m256i vacc0x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); + __m256i vacc1x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); + __m256i vacc2x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc2x0123456789ABCDEF, 1)); + __m256i vacc3x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc3x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc3x0123456789ABCDEF, 1)); + + vacc0x012389AB4567CDEF = _mm256_adds_epi16(vacc0x012389AB4567CDEF, voutput_zero_point); + vacc1x012389AB4567CDEF = _mm256_adds_epi16(vacc1x012389AB4567CDEF, voutput_zero_point); + vacc2x012389AB4567CDEF = _mm256_adds_epi16(vacc2x012389AB4567CDEF, voutput_zero_point); + vacc3x012389AB4567CDEF = _mm256_adds_epi16(vacc3x012389AB4567CDEF, voutput_zero_point); + + const __m128i vout0x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x012389AB4567CDEF, 1)); + const __m128i vout1x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc1x012389AB4567CDEF), _mm256_extracti128_si256(vacc1x012389AB4567CDEF, 1)); + const __m128i vout2x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc2x012389AB4567CDEF), _mm256_extracti128_si256(vacc2x012389AB4567CDEF, 1)); + const __m128i vout3x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc3x012389AB4567CDEF), _mm256_extracti128_si256(vacc3x012389AB4567CDEF, 1)); + + __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi32(vout0x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi32(vout1x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi32(vout2x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout3x0123456789ABCDEF = _mm_shuffle_epi32(vout3x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); vout0x0123456789ABCDEF = _mm_max_epi8(vout0x0123456789ABCDEF, voutput_min); vout1x0123456789ABCDEF = _mm_max_epi8(vout1x0123456789ABCDEF, voutput_min); diff --git a/src/qs8-igemm/gen/qs8-igemm-4x16c8-minmax-fp32-avx512skx.c b/src/qs8-igemm/gen/qs8-igemm-4x16c8-minmax-fp32-avx512skx.c index c2056b637da..faf4c7ad531 100644 --- a/src/qs8-igemm/gen/qs8-igemm-4x16c8-minmax-fp32-avx512skx.c +++ b/src/qs8-igemm/gen/qs8-igemm-4x16c8-minmax-fp32-avx512skx.c @@ -180,25 +180,25 @@ void xnn_qs8_igemm_minmax_fp32_ukernel_4x16c8__avx512skx( vacc2x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled2x0123456789ABCDEF); vacc3x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled3x0123456789ABCDEF); - __m256i vacc0x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); - __m256i vacc1x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); - __m256i vacc2x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc2x0123456789ABCDEF, 1)); - __m256i vacc3x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc3x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc3x0123456789ABCDEF, 1)); - - vacc0x0123456789AB4567CDEF = _mm256_adds_epi16(vacc0x0123456789AB4567CDEF, voutput_zero_point); - vacc1x0123456789AB4567CDEF = _mm256_adds_epi16(vacc1x0123456789AB4567CDEF, voutput_zero_point); - vacc2x0123456789AB4567CDEF = _mm256_adds_epi16(vacc2x0123456789AB4567CDEF, voutput_zero_point); - vacc3x0123456789AB4567CDEF = _mm256_adds_epi16(vacc3x0123456789AB4567CDEF, voutput_zero_point); - - const __m128i vout0x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc0x0123456789AB4567CDEF, 1)); - const __m128i vout1x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc1x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc1x0123456789AB4567CDEF, 1)); - const __m128i vout2x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc2x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc2x0123456789AB4567CDEF, 1)); - const __m128i vout3x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc3x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc3x0123456789AB4567CDEF, 1)); - - __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi32(vout0x0123456789AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); - __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi32(vout1x0123456789AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); - __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi32(vout2x0123456789AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); - __m128i vout3x0123456789ABCDEF = _mm_shuffle_epi32(vout3x0123456789AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m256i vacc0x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); + __m256i vacc1x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); + __m256i vacc2x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc2x0123456789ABCDEF, 1)); + __m256i vacc3x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc3x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc3x0123456789ABCDEF, 1)); + + vacc0x012389AB4567CDEF = _mm256_adds_epi16(vacc0x012389AB4567CDEF, voutput_zero_point); + vacc1x012389AB4567CDEF = _mm256_adds_epi16(vacc1x012389AB4567CDEF, voutput_zero_point); + vacc2x012389AB4567CDEF = _mm256_adds_epi16(vacc2x012389AB4567CDEF, voutput_zero_point); + vacc3x012389AB4567CDEF = _mm256_adds_epi16(vacc3x012389AB4567CDEF, voutput_zero_point); + + const __m128i vout0x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x012389AB4567CDEF, 1)); + const __m128i vout1x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc1x012389AB4567CDEF), _mm256_extracti128_si256(vacc1x012389AB4567CDEF, 1)); + const __m128i vout2x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc2x012389AB4567CDEF), _mm256_extracti128_si256(vacc2x012389AB4567CDEF, 1)); + const __m128i vout3x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc3x012389AB4567CDEF), _mm256_extracti128_si256(vacc3x012389AB4567CDEF, 1)); + + __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi32(vout0x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi32(vout1x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi32(vout2x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout3x0123456789ABCDEF = _mm_shuffle_epi32(vout3x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); vout0x0123456789ABCDEF = _mm_max_epi8(vout0x0123456789ABCDEF, voutput_min); vout1x0123456789ABCDEF = _mm_max_epi8(vout1x0123456789ABCDEF, voutput_min); diff --git a/src/qs8-igemm/gen/qs8-igemm-4x16c8-minmax-fp32-avx512vnni-prfm.c b/src/qs8-igemm/gen/qs8-igemm-4x16c8-minmax-fp32-avx512vnni-prfm.c index 9aee068e0d1..27b2c99d0b5 100644 --- a/src/qs8-igemm/gen/qs8-igemm-4x16c8-minmax-fp32-avx512vnni-prfm.c +++ b/src/qs8-igemm/gen/qs8-igemm-4x16c8-minmax-fp32-avx512vnni-prfm.c @@ -61,7 +61,6 @@ void xnn_qs8_igemm_minmax_fp32_ukernel_4x16c8__avx512vnni_prfm( const __m512 voutput_max_less_zero_point = _mm512_load_ps(params->fp32_avx512vnni.output_max_less_zero_point); const __m256i voutput_zero_point = _mm256_load_si256((const __m256i*) params->fp32_avx512vnni.output_zero_point); const __m128i voutput_min = _mm_load_si128((const __m128i*) params->fp32_avx512vnni.output_min); - const __m128i vshuffle_control_mask = _mm_loadu_si128((const __m128i*) params->fp32_avx512vnni.shuffle_control_mask); do { __m512i vacc0x01234567 = _mm512_cvtepu32_epi64(_mm256_load_si256((const __m256i*) w)); __m512i vacc0x89ABCDEF = _mm512_cvtepu32_epi64(_mm256_load_si256((const __m256i*) ((const int32_t*) w + 8))); @@ -204,25 +203,25 @@ void xnn_qs8_igemm_minmax_fp32_ukernel_4x16c8__avx512vnni_prfm( vacc2x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled2x0123456789ABCDEF); vacc3x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled3x0123456789ABCDEF); - __m256i vacc0x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); - __m256i vacc1x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); - __m256i vacc2x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc2x0123456789ABCDEF, 1)); - __m256i vacc3x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc3x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc3x0123456789ABCDEF, 1)); - - vacc0x0123456789AB4567CDEF = _mm256_adds_epi16(vacc0x0123456789AB4567CDEF, voutput_zero_point); - vacc1x0123456789AB4567CDEF = _mm256_adds_epi16(vacc1x0123456789AB4567CDEF, voutput_zero_point); - vacc2x0123456789AB4567CDEF = _mm256_adds_epi16(vacc2x0123456789AB4567CDEF, voutput_zero_point); - vacc3x0123456789AB4567CDEF = _mm256_adds_epi16(vacc3x0123456789AB4567CDEF, voutput_zero_point); - - const __m128i vout0x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc0x0123456789AB4567CDEF, 1)); - const __m128i vout1x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc1x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc1x0123456789AB4567CDEF, 1)); - const __m128i vout2x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc2x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc2x0123456789AB4567CDEF, 1)); - const __m128i vout3x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc3x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc3x0123456789AB4567CDEF, 1)); - - __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi8(vout0x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi8(vout1x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi8(vout2x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout3x0123456789ABCDEF = _mm_shuffle_epi8(vout3x0123456789AB4567CDEF, vshuffle_control_mask); + __m256i vacc0x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); + __m256i vacc1x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); + __m256i vacc2x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc2x0123456789ABCDEF, 1)); + __m256i vacc3x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc3x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc3x0123456789ABCDEF, 1)); + + vacc0x012389AB4567CDEF = _mm256_adds_epi16(vacc0x012389AB4567CDEF, voutput_zero_point); + vacc1x012389AB4567CDEF = _mm256_adds_epi16(vacc1x012389AB4567CDEF, voutput_zero_point); + vacc2x012389AB4567CDEF = _mm256_adds_epi16(vacc2x012389AB4567CDEF, voutput_zero_point); + vacc3x012389AB4567CDEF = _mm256_adds_epi16(vacc3x012389AB4567CDEF, voutput_zero_point); + + const __m128i vout0x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x012389AB4567CDEF, 1)); + const __m128i vout1x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc1x012389AB4567CDEF), _mm256_extracti128_si256(vacc1x012389AB4567CDEF, 1)); + const __m128i vout2x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc2x012389AB4567CDEF), _mm256_extracti128_si256(vacc2x012389AB4567CDEF, 1)); + const __m128i vout3x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc3x012389AB4567CDEF), _mm256_extracti128_si256(vacc3x012389AB4567CDEF, 1)); + + __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi32(vout0x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi32(vout1x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi32(vout2x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout3x0123456789ABCDEF = _mm_shuffle_epi32(vout3x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); vout0x0123456789ABCDEF = _mm_max_epi8(vout0x0123456789ABCDEF, voutput_min); vout1x0123456789ABCDEF = _mm_max_epi8(vout1x0123456789ABCDEF, voutput_min); diff --git a/src/qs8-igemm/gen/qs8-igemm-4x16c8-minmax-fp32-avx512vnni.c b/src/qs8-igemm/gen/qs8-igemm-4x16c8-minmax-fp32-avx512vnni.c index cd8ce580ee1..fe2fe046b8c 100644 --- a/src/qs8-igemm/gen/qs8-igemm-4x16c8-minmax-fp32-avx512vnni.c +++ b/src/qs8-igemm/gen/qs8-igemm-4x16c8-minmax-fp32-avx512vnni.c @@ -60,7 +60,6 @@ void xnn_qs8_igemm_minmax_fp32_ukernel_4x16c8__avx512vnni( const __m512 voutput_max_less_zero_point = _mm512_load_ps(params->fp32_avx512vnni.output_max_less_zero_point); const __m256i voutput_zero_point = _mm256_load_si256((const __m256i*) params->fp32_avx512vnni.output_zero_point); const __m128i voutput_min = _mm_load_si128((const __m128i*) params->fp32_avx512vnni.output_min); - const __m128i vshuffle_control_mask = _mm_loadu_si128((const __m128i*) params->fp32_avx512vnni.shuffle_control_mask); do { __m512i vacc0x01234567 = _mm512_cvtepu32_epi64(_mm256_load_si256((const __m256i*) w)); __m512i vacc0x89ABCDEF = _mm512_cvtepu32_epi64(_mm256_load_si256((const __m256i*) ((const int32_t*) w + 8))); @@ -197,25 +196,25 @@ void xnn_qs8_igemm_minmax_fp32_ukernel_4x16c8__avx512vnni( vacc2x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled2x0123456789ABCDEF); vacc3x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled3x0123456789ABCDEF); - __m256i vacc0x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); - __m256i vacc1x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); - __m256i vacc2x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc2x0123456789ABCDEF, 1)); - __m256i vacc3x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc3x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc3x0123456789ABCDEF, 1)); - - vacc0x0123456789AB4567CDEF = _mm256_adds_epi16(vacc0x0123456789AB4567CDEF, voutput_zero_point); - vacc1x0123456789AB4567CDEF = _mm256_adds_epi16(vacc1x0123456789AB4567CDEF, voutput_zero_point); - vacc2x0123456789AB4567CDEF = _mm256_adds_epi16(vacc2x0123456789AB4567CDEF, voutput_zero_point); - vacc3x0123456789AB4567CDEF = _mm256_adds_epi16(vacc3x0123456789AB4567CDEF, voutput_zero_point); - - const __m128i vout0x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc0x0123456789AB4567CDEF, 1)); - const __m128i vout1x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc1x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc1x0123456789AB4567CDEF, 1)); - const __m128i vout2x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc2x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc2x0123456789AB4567CDEF, 1)); - const __m128i vout3x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc3x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc3x0123456789AB4567CDEF, 1)); - - __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi8(vout0x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi8(vout1x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi8(vout2x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout3x0123456789ABCDEF = _mm_shuffle_epi8(vout3x0123456789AB4567CDEF, vshuffle_control_mask); + __m256i vacc0x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); + __m256i vacc1x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); + __m256i vacc2x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc2x0123456789ABCDEF, 1)); + __m256i vacc3x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc3x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc3x0123456789ABCDEF, 1)); + + vacc0x012389AB4567CDEF = _mm256_adds_epi16(vacc0x012389AB4567CDEF, voutput_zero_point); + vacc1x012389AB4567CDEF = _mm256_adds_epi16(vacc1x012389AB4567CDEF, voutput_zero_point); + vacc2x012389AB4567CDEF = _mm256_adds_epi16(vacc2x012389AB4567CDEF, voutput_zero_point); + vacc3x012389AB4567CDEF = _mm256_adds_epi16(vacc3x012389AB4567CDEF, voutput_zero_point); + + const __m128i vout0x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x012389AB4567CDEF, 1)); + const __m128i vout1x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc1x012389AB4567CDEF), _mm256_extracti128_si256(vacc1x012389AB4567CDEF, 1)); + const __m128i vout2x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc2x012389AB4567CDEF), _mm256_extracti128_si256(vacc2x012389AB4567CDEF, 1)); + const __m128i vout3x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc3x012389AB4567CDEF), _mm256_extracti128_si256(vacc3x012389AB4567CDEF, 1)); + + __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi32(vout0x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi32(vout1x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi32(vout2x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout3x0123456789ABCDEF = _mm_shuffle_epi32(vout3x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); vout0x0123456789ABCDEF = _mm_max_epi8(vout0x0123456789ABCDEF, voutput_min); vout1x0123456789ABCDEF = _mm_max_epi8(vout1x0123456789ABCDEF, voutput_min); diff --git a/src/qs8-igemm/gen/qs8-igemm-5x16c4-minmax-fp32-avx512vnni-prfm.c b/src/qs8-igemm/gen/qs8-igemm-5x16c4-minmax-fp32-avx512vnni-prfm.c index 4ee716213a3..0697cfc37f1 100644 --- a/src/qs8-igemm/gen/qs8-igemm-5x16c4-minmax-fp32-avx512vnni-prfm.c +++ b/src/qs8-igemm/gen/qs8-igemm-5x16c4-minmax-fp32-avx512vnni-prfm.c @@ -65,7 +65,6 @@ void xnn_qs8_igemm_minmax_fp32_ukernel_5x16c4__avx512vnni_prfm( const __m512 voutput_max_less_zero_point = _mm512_load_ps(params->fp32_avx512vnni.output_max_less_zero_point); const __m256i voutput_zero_point = _mm256_load_si256((const __m256i*) params->fp32_avx512vnni.output_zero_point); const __m128i voutput_min = _mm_load_si128((const __m128i*) params->fp32_avx512vnni.output_min); - const __m128i vshuffle_control_mask = _mm_loadu_si128((const __m128i*) params->fp32_avx512vnni.shuffle_control_mask); do { __m512i vacc0x0123456789ABCDEF = _mm512_load_epi32(w); __m512i vacc1x0123456789ABCDEF = vacc0x0123456789ABCDEF; @@ -185,29 +184,29 @@ void xnn_qs8_igemm_minmax_fp32_ukernel_5x16c4__avx512vnni_prfm( vacc3x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled3x0123456789ABCDEF); vacc4x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled4x0123456789ABCDEF); - __m256i vacc0x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); - __m256i vacc1x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); - __m256i vacc2x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc2x0123456789ABCDEF, 1)); - __m256i vacc3x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc3x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc3x0123456789ABCDEF, 1)); - __m256i vacc4x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc4x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc4x0123456789ABCDEF, 1)); - - vacc0x0123456789AB4567CDEF = _mm256_adds_epi16(vacc0x0123456789AB4567CDEF, voutput_zero_point); - vacc1x0123456789AB4567CDEF = _mm256_adds_epi16(vacc1x0123456789AB4567CDEF, voutput_zero_point); - vacc2x0123456789AB4567CDEF = _mm256_adds_epi16(vacc2x0123456789AB4567CDEF, voutput_zero_point); - vacc3x0123456789AB4567CDEF = _mm256_adds_epi16(vacc3x0123456789AB4567CDEF, voutput_zero_point); - vacc4x0123456789AB4567CDEF = _mm256_adds_epi16(vacc4x0123456789AB4567CDEF, voutput_zero_point); - - const __m128i vout0x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc0x0123456789AB4567CDEF, 1)); - const __m128i vout1x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc1x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc1x0123456789AB4567CDEF, 1)); - const __m128i vout2x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc2x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc2x0123456789AB4567CDEF, 1)); - const __m128i vout3x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc3x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc3x0123456789AB4567CDEF, 1)); - const __m128i vout4x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc4x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc4x0123456789AB4567CDEF, 1)); - - __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi8(vout0x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi8(vout1x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi8(vout2x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout3x0123456789ABCDEF = _mm_shuffle_epi8(vout3x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout4x0123456789ABCDEF = _mm_shuffle_epi8(vout4x0123456789AB4567CDEF, vshuffle_control_mask); + __m256i vacc0x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); + __m256i vacc1x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); + __m256i vacc2x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc2x0123456789ABCDEF, 1)); + __m256i vacc3x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc3x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc3x0123456789ABCDEF, 1)); + __m256i vacc4x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc4x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc4x0123456789ABCDEF, 1)); + + vacc0x012389AB4567CDEF = _mm256_adds_epi16(vacc0x012389AB4567CDEF, voutput_zero_point); + vacc1x012389AB4567CDEF = _mm256_adds_epi16(vacc1x012389AB4567CDEF, voutput_zero_point); + vacc2x012389AB4567CDEF = _mm256_adds_epi16(vacc2x012389AB4567CDEF, voutput_zero_point); + vacc3x012389AB4567CDEF = _mm256_adds_epi16(vacc3x012389AB4567CDEF, voutput_zero_point); + vacc4x012389AB4567CDEF = _mm256_adds_epi16(vacc4x012389AB4567CDEF, voutput_zero_point); + + const __m128i vout0x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x012389AB4567CDEF, 1)); + const __m128i vout1x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc1x012389AB4567CDEF), _mm256_extracti128_si256(vacc1x012389AB4567CDEF, 1)); + const __m128i vout2x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc2x012389AB4567CDEF), _mm256_extracti128_si256(vacc2x012389AB4567CDEF, 1)); + const __m128i vout3x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc3x012389AB4567CDEF), _mm256_extracti128_si256(vacc3x012389AB4567CDEF, 1)); + const __m128i vout4x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc4x012389AB4567CDEF), _mm256_extracti128_si256(vacc4x012389AB4567CDEF, 1)); + + __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi32(vout0x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi32(vout1x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi32(vout2x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout3x0123456789ABCDEF = _mm_shuffle_epi32(vout3x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout4x0123456789ABCDEF = _mm_shuffle_epi32(vout4x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); vout0x0123456789ABCDEF = _mm_max_epi8(vout0x0123456789ABCDEF, voutput_min); vout1x0123456789ABCDEF = _mm_max_epi8(vout1x0123456789ABCDEF, voutput_min); diff --git a/src/qs8-igemm/gen/qs8-igemm-5x16c4-minmax-fp32-avx512vnni.c b/src/qs8-igemm/gen/qs8-igemm-5x16c4-minmax-fp32-avx512vnni.c index 5dd55bcc848..30385e9ad2a 100644 --- a/src/qs8-igemm/gen/qs8-igemm-5x16c4-minmax-fp32-avx512vnni.c +++ b/src/qs8-igemm/gen/qs8-igemm-5x16c4-minmax-fp32-avx512vnni.c @@ -64,7 +64,6 @@ void xnn_qs8_igemm_minmax_fp32_ukernel_5x16c4__avx512vnni( const __m512 voutput_max_less_zero_point = _mm512_load_ps(params->fp32_avx512vnni.output_max_less_zero_point); const __m256i voutput_zero_point = _mm256_load_si256((const __m256i*) params->fp32_avx512vnni.output_zero_point); const __m128i voutput_min = _mm_load_si128((const __m128i*) params->fp32_avx512vnni.output_min); - const __m128i vshuffle_control_mask = _mm_loadu_si128((const __m128i*) params->fp32_avx512vnni.shuffle_control_mask); do { __m512i vacc0x0123456789ABCDEF = _mm512_load_epi32(w); __m512i vacc1x0123456789ABCDEF = vacc0x0123456789ABCDEF; @@ -181,29 +180,29 @@ void xnn_qs8_igemm_minmax_fp32_ukernel_5x16c4__avx512vnni( vacc3x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled3x0123456789ABCDEF); vacc4x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled4x0123456789ABCDEF); - __m256i vacc0x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); - __m256i vacc1x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); - __m256i vacc2x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc2x0123456789ABCDEF, 1)); - __m256i vacc3x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc3x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc3x0123456789ABCDEF, 1)); - __m256i vacc4x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc4x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc4x0123456789ABCDEF, 1)); - - vacc0x0123456789AB4567CDEF = _mm256_adds_epi16(vacc0x0123456789AB4567CDEF, voutput_zero_point); - vacc1x0123456789AB4567CDEF = _mm256_adds_epi16(vacc1x0123456789AB4567CDEF, voutput_zero_point); - vacc2x0123456789AB4567CDEF = _mm256_adds_epi16(vacc2x0123456789AB4567CDEF, voutput_zero_point); - vacc3x0123456789AB4567CDEF = _mm256_adds_epi16(vacc3x0123456789AB4567CDEF, voutput_zero_point); - vacc4x0123456789AB4567CDEF = _mm256_adds_epi16(vacc4x0123456789AB4567CDEF, voutput_zero_point); - - const __m128i vout0x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc0x0123456789AB4567CDEF, 1)); - const __m128i vout1x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc1x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc1x0123456789AB4567CDEF, 1)); - const __m128i vout2x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc2x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc2x0123456789AB4567CDEF, 1)); - const __m128i vout3x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc3x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc3x0123456789AB4567CDEF, 1)); - const __m128i vout4x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc4x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc4x0123456789AB4567CDEF, 1)); - - __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi8(vout0x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi8(vout1x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi8(vout2x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout3x0123456789ABCDEF = _mm_shuffle_epi8(vout3x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout4x0123456789ABCDEF = _mm_shuffle_epi8(vout4x0123456789AB4567CDEF, vshuffle_control_mask); + __m256i vacc0x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); + __m256i vacc1x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); + __m256i vacc2x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc2x0123456789ABCDEF, 1)); + __m256i vacc3x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc3x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc3x0123456789ABCDEF, 1)); + __m256i vacc4x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc4x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc4x0123456789ABCDEF, 1)); + + vacc0x012389AB4567CDEF = _mm256_adds_epi16(vacc0x012389AB4567CDEF, voutput_zero_point); + vacc1x012389AB4567CDEF = _mm256_adds_epi16(vacc1x012389AB4567CDEF, voutput_zero_point); + vacc2x012389AB4567CDEF = _mm256_adds_epi16(vacc2x012389AB4567CDEF, voutput_zero_point); + vacc3x012389AB4567CDEF = _mm256_adds_epi16(vacc3x012389AB4567CDEF, voutput_zero_point); + vacc4x012389AB4567CDEF = _mm256_adds_epi16(vacc4x012389AB4567CDEF, voutput_zero_point); + + const __m128i vout0x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x012389AB4567CDEF, 1)); + const __m128i vout1x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc1x012389AB4567CDEF), _mm256_extracti128_si256(vacc1x012389AB4567CDEF, 1)); + const __m128i vout2x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc2x012389AB4567CDEF), _mm256_extracti128_si256(vacc2x012389AB4567CDEF, 1)); + const __m128i vout3x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc3x012389AB4567CDEF), _mm256_extracti128_si256(vacc3x012389AB4567CDEF, 1)); + const __m128i vout4x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc4x012389AB4567CDEF), _mm256_extracti128_si256(vacc4x012389AB4567CDEF, 1)); + + __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi32(vout0x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi32(vout1x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi32(vout2x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout3x0123456789ABCDEF = _mm_shuffle_epi32(vout3x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout4x0123456789ABCDEF = _mm_shuffle_epi32(vout4x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); vout0x0123456789ABCDEF = _mm_max_epi8(vout0x0123456789ABCDEF, voutput_min); vout1x0123456789ABCDEF = _mm_max_epi8(vout1x0123456789ABCDEF, voutput_min); diff --git a/src/qs8-igemm/gen/qs8-igemm-5x16c8-minmax-fp32-avx512skx-prfm.c b/src/qs8-igemm/gen/qs8-igemm-5x16c8-minmax-fp32-avx512skx-prfm.c index 185bf28164a..c7a8f808aa6 100644 --- a/src/qs8-igemm/gen/qs8-igemm-5x16c8-minmax-fp32-avx512skx-prfm.c +++ b/src/qs8-igemm/gen/qs8-igemm-5x16c8-minmax-fp32-avx512skx-prfm.c @@ -209,29 +209,29 @@ void xnn_qs8_igemm_minmax_fp32_ukernel_5x16c8__avx512skx_prfm( vacc3x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled3x0123456789ABCDEF); vacc4x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled4x0123456789ABCDEF); - __m256i vacc0x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); - __m256i vacc1x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); - __m256i vacc2x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc2x0123456789ABCDEF, 1)); - __m256i vacc3x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc3x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc3x0123456789ABCDEF, 1)); - __m256i vacc4x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc4x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc4x0123456789ABCDEF, 1)); + __m256i vacc0x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); + __m256i vacc1x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); + __m256i vacc2x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc2x0123456789ABCDEF, 1)); + __m256i vacc3x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc3x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc3x0123456789ABCDEF, 1)); + __m256i vacc4x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc4x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc4x0123456789ABCDEF, 1)); - vacc0x0123456789AB4567CDEF = _mm256_adds_epi16(vacc0x0123456789AB4567CDEF, voutput_zero_point); - vacc1x0123456789AB4567CDEF = _mm256_adds_epi16(vacc1x0123456789AB4567CDEF, voutput_zero_point); - vacc2x0123456789AB4567CDEF = _mm256_adds_epi16(vacc2x0123456789AB4567CDEF, voutput_zero_point); - vacc3x0123456789AB4567CDEF = _mm256_adds_epi16(vacc3x0123456789AB4567CDEF, voutput_zero_point); - vacc4x0123456789AB4567CDEF = _mm256_adds_epi16(vacc4x0123456789AB4567CDEF, voutput_zero_point); + vacc0x012389AB4567CDEF = _mm256_adds_epi16(vacc0x012389AB4567CDEF, voutput_zero_point); + vacc1x012389AB4567CDEF = _mm256_adds_epi16(vacc1x012389AB4567CDEF, voutput_zero_point); + vacc2x012389AB4567CDEF = _mm256_adds_epi16(vacc2x012389AB4567CDEF, voutput_zero_point); + vacc3x012389AB4567CDEF = _mm256_adds_epi16(vacc3x012389AB4567CDEF, voutput_zero_point); + vacc4x012389AB4567CDEF = _mm256_adds_epi16(vacc4x012389AB4567CDEF, voutput_zero_point); - const __m128i vout0x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc0x0123456789AB4567CDEF, 1)); - const __m128i vout1x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc1x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc1x0123456789AB4567CDEF, 1)); - const __m128i vout2x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc2x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc2x0123456789AB4567CDEF, 1)); - const __m128i vout3x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc3x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc3x0123456789AB4567CDEF, 1)); - const __m128i vout4x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc4x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc4x0123456789AB4567CDEF, 1)); + const __m128i vout0x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x012389AB4567CDEF, 1)); + const __m128i vout1x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc1x012389AB4567CDEF), _mm256_extracti128_si256(vacc1x012389AB4567CDEF, 1)); + const __m128i vout2x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc2x012389AB4567CDEF), _mm256_extracti128_si256(vacc2x012389AB4567CDEF, 1)); + const __m128i vout3x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc3x012389AB4567CDEF), _mm256_extracti128_si256(vacc3x012389AB4567CDEF, 1)); + const __m128i vout4x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc4x012389AB4567CDEF), _mm256_extracti128_si256(vacc4x012389AB4567CDEF, 1)); - __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi32(vout0x0123456789AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); - __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi32(vout1x0123456789AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); - __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi32(vout2x0123456789AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); - __m128i vout3x0123456789ABCDEF = _mm_shuffle_epi32(vout3x0123456789AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); - __m128i vout4x0123456789ABCDEF = _mm_shuffle_epi32(vout4x0123456789AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi32(vout0x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi32(vout1x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi32(vout2x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout3x0123456789ABCDEF = _mm_shuffle_epi32(vout3x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout4x0123456789ABCDEF = _mm_shuffle_epi32(vout4x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); vout0x0123456789ABCDEF = _mm_max_epi8(vout0x0123456789ABCDEF, voutput_min); vout1x0123456789ABCDEF = _mm_max_epi8(vout1x0123456789ABCDEF, voutput_min); diff --git a/src/qs8-igemm/gen/qs8-igemm-5x16c8-minmax-fp32-avx512skx.c b/src/qs8-igemm/gen/qs8-igemm-5x16c8-minmax-fp32-avx512skx.c index c7c7445a71f..55047400239 100644 --- a/src/qs8-igemm/gen/qs8-igemm-5x16c8-minmax-fp32-avx512skx.c +++ b/src/qs8-igemm/gen/qs8-igemm-5x16c8-minmax-fp32-avx512skx.c @@ -206,29 +206,29 @@ void xnn_qs8_igemm_minmax_fp32_ukernel_5x16c8__avx512skx( vacc3x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled3x0123456789ABCDEF); vacc4x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled4x0123456789ABCDEF); - __m256i vacc0x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); - __m256i vacc1x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); - __m256i vacc2x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc2x0123456789ABCDEF, 1)); - __m256i vacc3x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc3x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc3x0123456789ABCDEF, 1)); - __m256i vacc4x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc4x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc4x0123456789ABCDEF, 1)); + __m256i vacc0x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); + __m256i vacc1x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); + __m256i vacc2x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc2x0123456789ABCDEF, 1)); + __m256i vacc3x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc3x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc3x0123456789ABCDEF, 1)); + __m256i vacc4x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc4x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc4x0123456789ABCDEF, 1)); - vacc0x0123456789AB4567CDEF = _mm256_adds_epi16(vacc0x0123456789AB4567CDEF, voutput_zero_point); - vacc1x0123456789AB4567CDEF = _mm256_adds_epi16(vacc1x0123456789AB4567CDEF, voutput_zero_point); - vacc2x0123456789AB4567CDEF = _mm256_adds_epi16(vacc2x0123456789AB4567CDEF, voutput_zero_point); - vacc3x0123456789AB4567CDEF = _mm256_adds_epi16(vacc3x0123456789AB4567CDEF, voutput_zero_point); - vacc4x0123456789AB4567CDEF = _mm256_adds_epi16(vacc4x0123456789AB4567CDEF, voutput_zero_point); + vacc0x012389AB4567CDEF = _mm256_adds_epi16(vacc0x012389AB4567CDEF, voutput_zero_point); + vacc1x012389AB4567CDEF = _mm256_adds_epi16(vacc1x012389AB4567CDEF, voutput_zero_point); + vacc2x012389AB4567CDEF = _mm256_adds_epi16(vacc2x012389AB4567CDEF, voutput_zero_point); + vacc3x012389AB4567CDEF = _mm256_adds_epi16(vacc3x012389AB4567CDEF, voutput_zero_point); + vacc4x012389AB4567CDEF = _mm256_adds_epi16(vacc4x012389AB4567CDEF, voutput_zero_point); - const __m128i vout0x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc0x0123456789AB4567CDEF, 1)); - const __m128i vout1x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc1x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc1x0123456789AB4567CDEF, 1)); - const __m128i vout2x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc2x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc2x0123456789AB4567CDEF, 1)); - const __m128i vout3x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc3x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc3x0123456789AB4567CDEF, 1)); - const __m128i vout4x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc4x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc4x0123456789AB4567CDEF, 1)); + const __m128i vout0x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x012389AB4567CDEF, 1)); + const __m128i vout1x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc1x012389AB4567CDEF), _mm256_extracti128_si256(vacc1x012389AB4567CDEF, 1)); + const __m128i vout2x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc2x012389AB4567CDEF), _mm256_extracti128_si256(vacc2x012389AB4567CDEF, 1)); + const __m128i vout3x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc3x012389AB4567CDEF), _mm256_extracti128_si256(vacc3x012389AB4567CDEF, 1)); + const __m128i vout4x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc4x012389AB4567CDEF), _mm256_extracti128_si256(vacc4x012389AB4567CDEF, 1)); - __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi32(vout0x0123456789AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); - __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi32(vout1x0123456789AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); - __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi32(vout2x0123456789AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); - __m128i vout3x0123456789ABCDEF = _mm_shuffle_epi32(vout3x0123456789AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); - __m128i vout4x0123456789ABCDEF = _mm_shuffle_epi32(vout4x0123456789AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi32(vout0x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi32(vout1x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi32(vout2x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout3x0123456789ABCDEF = _mm_shuffle_epi32(vout3x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout4x0123456789ABCDEF = _mm_shuffle_epi32(vout4x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); vout0x0123456789ABCDEF = _mm_max_epi8(vout0x0123456789ABCDEF, voutput_min); vout1x0123456789ABCDEF = _mm_max_epi8(vout1x0123456789ABCDEF, voutput_min); diff --git a/src/qs8-igemm/gen/qs8-igemm-5x16c8-minmax-fp32-avx512vnni-prfm.c b/src/qs8-igemm/gen/qs8-igemm-5x16c8-minmax-fp32-avx512vnni-prfm.c index 5f2ff529f0b..bdd3488aa7a 100644 --- a/src/qs8-igemm/gen/qs8-igemm-5x16c8-minmax-fp32-avx512vnni-prfm.c +++ b/src/qs8-igemm/gen/qs8-igemm-5x16c8-minmax-fp32-avx512vnni-prfm.c @@ -65,7 +65,6 @@ void xnn_qs8_igemm_minmax_fp32_ukernel_5x16c8__avx512vnni_prfm( const __m512 voutput_max_less_zero_point = _mm512_load_ps(params->fp32_avx512vnni.output_max_less_zero_point); const __m256i voutput_zero_point = _mm256_load_si256((const __m256i*) params->fp32_avx512vnni.output_zero_point); const __m128i voutput_min = _mm_load_si128((const __m128i*) params->fp32_avx512vnni.output_min); - const __m128i vshuffle_control_mask = _mm_loadu_si128((const __m128i*) params->fp32_avx512vnni.shuffle_control_mask); do { __m512i vacc0x01234567 = _mm512_cvtepu32_epi64(_mm256_load_si256((const __m256i*) w)); __m512i vacc0x89ABCDEF = _mm512_cvtepu32_epi64(_mm256_load_si256((const __m256i*) ((const int32_t*) w + 8))); @@ -232,29 +231,29 @@ void xnn_qs8_igemm_minmax_fp32_ukernel_5x16c8__avx512vnni_prfm( vacc3x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled3x0123456789ABCDEF); vacc4x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled4x0123456789ABCDEF); - __m256i vacc0x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); - __m256i vacc1x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); - __m256i vacc2x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc2x0123456789ABCDEF, 1)); - __m256i vacc3x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc3x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc3x0123456789ABCDEF, 1)); - __m256i vacc4x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc4x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc4x0123456789ABCDEF, 1)); + __m256i vacc0x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); + __m256i vacc1x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); + __m256i vacc2x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc2x0123456789ABCDEF, 1)); + __m256i vacc3x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc3x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc3x0123456789ABCDEF, 1)); + __m256i vacc4x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc4x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc4x0123456789ABCDEF, 1)); - vacc0x0123456789AB4567CDEF = _mm256_adds_epi16(vacc0x0123456789AB4567CDEF, voutput_zero_point); - vacc1x0123456789AB4567CDEF = _mm256_adds_epi16(vacc1x0123456789AB4567CDEF, voutput_zero_point); - vacc2x0123456789AB4567CDEF = _mm256_adds_epi16(vacc2x0123456789AB4567CDEF, voutput_zero_point); - vacc3x0123456789AB4567CDEF = _mm256_adds_epi16(vacc3x0123456789AB4567CDEF, voutput_zero_point); - vacc4x0123456789AB4567CDEF = _mm256_adds_epi16(vacc4x0123456789AB4567CDEF, voutput_zero_point); + vacc0x012389AB4567CDEF = _mm256_adds_epi16(vacc0x012389AB4567CDEF, voutput_zero_point); + vacc1x012389AB4567CDEF = _mm256_adds_epi16(vacc1x012389AB4567CDEF, voutput_zero_point); + vacc2x012389AB4567CDEF = _mm256_adds_epi16(vacc2x012389AB4567CDEF, voutput_zero_point); + vacc3x012389AB4567CDEF = _mm256_adds_epi16(vacc3x012389AB4567CDEF, voutput_zero_point); + vacc4x012389AB4567CDEF = _mm256_adds_epi16(vacc4x012389AB4567CDEF, voutput_zero_point); - const __m128i vout0x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc0x0123456789AB4567CDEF, 1)); - const __m128i vout1x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc1x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc1x0123456789AB4567CDEF, 1)); - const __m128i vout2x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc2x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc2x0123456789AB4567CDEF, 1)); - const __m128i vout3x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc3x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc3x0123456789AB4567CDEF, 1)); - const __m128i vout4x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc4x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc4x0123456789AB4567CDEF, 1)); + const __m128i vout0x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x012389AB4567CDEF, 1)); + const __m128i vout1x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc1x012389AB4567CDEF), _mm256_extracti128_si256(vacc1x012389AB4567CDEF, 1)); + const __m128i vout2x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc2x012389AB4567CDEF), _mm256_extracti128_si256(vacc2x012389AB4567CDEF, 1)); + const __m128i vout3x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc3x012389AB4567CDEF), _mm256_extracti128_si256(vacc3x012389AB4567CDEF, 1)); + const __m128i vout4x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc4x012389AB4567CDEF), _mm256_extracti128_si256(vacc4x012389AB4567CDEF, 1)); - __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi8(vout0x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi8(vout1x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi8(vout2x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout3x0123456789ABCDEF = _mm_shuffle_epi8(vout3x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout4x0123456789ABCDEF = _mm_shuffle_epi8(vout4x0123456789AB4567CDEF, vshuffle_control_mask); + __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi32(vout0x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi32(vout1x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi32(vout2x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout3x0123456789ABCDEF = _mm_shuffle_epi32(vout3x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout4x0123456789ABCDEF = _mm_shuffle_epi32(vout4x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); vout0x0123456789ABCDEF = _mm_max_epi8(vout0x0123456789ABCDEF, voutput_min); vout1x0123456789ABCDEF = _mm_max_epi8(vout1x0123456789ABCDEF, voutput_min); diff --git a/src/qs8-igemm/gen/qs8-igemm-5x16c8-minmax-fp32-avx512vnni.c b/src/qs8-igemm/gen/qs8-igemm-5x16c8-minmax-fp32-avx512vnni.c index 8b0ccd88fd1..10821b1c9f6 100644 --- a/src/qs8-igemm/gen/qs8-igemm-5x16c8-minmax-fp32-avx512vnni.c +++ b/src/qs8-igemm/gen/qs8-igemm-5x16c8-minmax-fp32-avx512vnni.c @@ -64,7 +64,6 @@ void xnn_qs8_igemm_minmax_fp32_ukernel_5x16c8__avx512vnni( const __m512 voutput_max_less_zero_point = _mm512_load_ps(params->fp32_avx512vnni.output_max_less_zero_point); const __m256i voutput_zero_point = _mm256_load_si256((const __m256i*) params->fp32_avx512vnni.output_zero_point); const __m128i voutput_min = _mm_load_si128((const __m128i*) params->fp32_avx512vnni.output_min); - const __m128i vshuffle_control_mask = _mm_loadu_si128((const __m128i*) params->fp32_avx512vnni.shuffle_control_mask); do { __m512i vacc0x01234567 = _mm512_cvtepu32_epi64(_mm256_load_si256((const __m256i*) w)); __m512i vacc0x89ABCDEF = _mm512_cvtepu32_epi64(_mm256_load_si256((const __m256i*) ((const int32_t*) w + 8))); @@ -225,29 +224,29 @@ void xnn_qs8_igemm_minmax_fp32_ukernel_5x16c8__avx512vnni( vacc3x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled3x0123456789ABCDEF); vacc4x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled4x0123456789ABCDEF); - __m256i vacc0x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); - __m256i vacc1x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); - __m256i vacc2x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc2x0123456789ABCDEF, 1)); - __m256i vacc3x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc3x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc3x0123456789ABCDEF, 1)); - __m256i vacc4x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc4x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc4x0123456789ABCDEF, 1)); + __m256i vacc0x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); + __m256i vacc1x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); + __m256i vacc2x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc2x0123456789ABCDEF, 1)); + __m256i vacc3x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc3x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc3x0123456789ABCDEF, 1)); + __m256i vacc4x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc4x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc4x0123456789ABCDEF, 1)); - vacc0x0123456789AB4567CDEF = _mm256_adds_epi16(vacc0x0123456789AB4567CDEF, voutput_zero_point); - vacc1x0123456789AB4567CDEF = _mm256_adds_epi16(vacc1x0123456789AB4567CDEF, voutput_zero_point); - vacc2x0123456789AB4567CDEF = _mm256_adds_epi16(vacc2x0123456789AB4567CDEF, voutput_zero_point); - vacc3x0123456789AB4567CDEF = _mm256_adds_epi16(vacc3x0123456789AB4567CDEF, voutput_zero_point); - vacc4x0123456789AB4567CDEF = _mm256_adds_epi16(vacc4x0123456789AB4567CDEF, voutput_zero_point); + vacc0x012389AB4567CDEF = _mm256_adds_epi16(vacc0x012389AB4567CDEF, voutput_zero_point); + vacc1x012389AB4567CDEF = _mm256_adds_epi16(vacc1x012389AB4567CDEF, voutput_zero_point); + vacc2x012389AB4567CDEF = _mm256_adds_epi16(vacc2x012389AB4567CDEF, voutput_zero_point); + vacc3x012389AB4567CDEF = _mm256_adds_epi16(vacc3x012389AB4567CDEF, voutput_zero_point); + vacc4x012389AB4567CDEF = _mm256_adds_epi16(vacc4x012389AB4567CDEF, voutput_zero_point); - const __m128i vout0x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc0x0123456789AB4567CDEF, 1)); - const __m128i vout1x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc1x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc1x0123456789AB4567CDEF, 1)); - const __m128i vout2x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc2x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc2x0123456789AB4567CDEF, 1)); - const __m128i vout3x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc3x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc3x0123456789AB4567CDEF, 1)); - const __m128i vout4x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc4x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc4x0123456789AB4567CDEF, 1)); + const __m128i vout0x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x012389AB4567CDEF, 1)); + const __m128i vout1x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc1x012389AB4567CDEF), _mm256_extracti128_si256(vacc1x012389AB4567CDEF, 1)); + const __m128i vout2x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc2x012389AB4567CDEF), _mm256_extracti128_si256(vacc2x012389AB4567CDEF, 1)); + const __m128i vout3x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc3x012389AB4567CDEF), _mm256_extracti128_si256(vacc3x012389AB4567CDEF, 1)); + const __m128i vout4x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc4x012389AB4567CDEF), _mm256_extracti128_si256(vacc4x012389AB4567CDEF, 1)); - __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi8(vout0x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi8(vout1x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi8(vout2x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout3x0123456789ABCDEF = _mm_shuffle_epi8(vout3x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout4x0123456789ABCDEF = _mm_shuffle_epi8(vout4x0123456789AB4567CDEF, vshuffle_control_mask); + __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi32(vout0x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi32(vout1x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi32(vout2x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout3x0123456789ABCDEF = _mm_shuffle_epi32(vout3x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout4x0123456789ABCDEF = _mm_shuffle_epi32(vout4x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); vout0x0123456789ABCDEF = _mm_max_epi8(vout0x0123456789ABCDEF, voutput_min); vout1x0123456789ABCDEF = _mm_max_epi8(vout1x0123456789ABCDEF, voutput_min); diff --git a/src/qs8-igemm/gen/qs8-igemm-6x16c4-minmax-fp32-avx512vnni-prfm.c b/src/qs8-igemm/gen/qs8-igemm-6x16c4-minmax-fp32-avx512vnni-prfm.c index 2760bf98956..c9368b8b340 100644 --- a/src/qs8-igemm/gen/qs8-igemm-6x16c4-minmax-fp32-avx512vnni-prfm.c +++ b/src/qs8-igemm/gen/qs8-igemm-6x16c4-minmax-fp32-avx512vnni-prfm.c @@ -69,7 +69,6 @@ void xnn_qs8_igemm_minmax_fp32_ukernel_6x16c4__avx512vnni_prfm( const __m512 voutput_max_less_zero_point = _mm512_load_ps(params->fp32_avx512vnni.output_max_less_zero_point); const __m256i voutput_zero_point = _mm256_load_si256((const __m256i*) params->fp32_avx512vnni.output_zero_point); const __m128i voutput_min = _mm_load_si128((const __m128i*) params->fp32_avx512vnni.output_min); - const __m128i vshuffle_control_mask = _mm_loadu_si128((const __m128i*) params->fp32_avx512vnni.shuffle_control_mask); do { __m512i vacc0x0123456789ABCDEF = _mm512_load_epi32(w); __m512i vacc1x0123456789ABCDEF = vacc0x0123456789ABCDEF; @@ -206,33 +205,33 @@ void xnn_qs8_igemm_minmax_fp32_ukernel_6x16c4__avx512vnni_prfm( vacc4x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled4x0123456789ABCDEF); vacc5x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled5x0123456789ABCDEF); - __m256i vacc0x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); - __m256i vacc1x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); - __m256i vacc2x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc2x0123456789ABCDEF, 1)); - __m256i vacc3x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc3x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc3x0123456789ABCDEF, 1)); - __m256i vacc4x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc4x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc4x0123456789ABCDEF, 1)); - __m256i vacc5x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc5x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc5x0123456789ABCDEF, 1)); + __m256i vacc0x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); + __m256i vacc1x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); + __m256i vacc2x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc2x0123456789ABCDEF, 1)); + __m256i vacc3x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc3x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc3x0123456789ABCDEF, 1)); + __m256i vacc4x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc4x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc4x0123456789ABCDEF, 1)); + __m256i vacc5x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc5x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc5x0123456789ABCDEF, 1)); - vacc0x0123456789AB4567CDEF = _mm256_adds_epi16(vacc0x0123456789AB4567CDEF, voutput_zero_point); - vacc1x0123456789AB4567CDEF = _mm256_adds_epi16(vacc1x0123456789AB4567CDEF, voutput_zero_point); - vacc2x0123456789AB4567CDEF = _mm256_adds_epi16(vacc2x0123456789AB4567CDEF, voutput_zero_point); - vacc3x0123456789AB4567CDEF = _mm256_adds_epi16(vacc3x0123456789AB4567CDEF, voutput_zero_point); - vacc4x0123456789AB4567CDEF = _mm256_adds_epi16(vacc4x0123456789AB4567CDEF, voutput_zero_point); - vacc5x0123456789AB4567CDEF = _mm256_adds_epi16(vacc5x0123456789AB4567CDEF, voutput_zero_point); + vacc0x012389AB4567CDEF = _mm256_adds_epi16(vacc0x012389AB4567CDEF, voutput_zero_point); + vacc1x012389AB4567CDEF = _mm256_adds_epi16(vacc1x012389AB4567CDEF, voutput_zero_point); + vacc2x012389AB4567CDEF = _mm256_adds_epi16(vacc2x012389AB4567CDEF, voutput_zero_point); + vacc3x012389AB4567CDEF = _mm256_adds_epi16(vacc3x012389AB4567CDEF, voutput_zero_point); + vacc4x012389AB4567CDEF = _mm256_adds_epi16(vacc4x012389AB4567CDEF, voutput_zero_point); + vacc5x012389AB4567CDEF = _mm256_adds_epi16(vacc5x012389AB4567CDEF, voutput_zero_point); - const __m128i vout0x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc0x0123456789AB4567CDEF, 1)); - const __m128i vout1x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc1x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc1x0123456789AB4567CDEF, 1)); - const __m128i vout2x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc2x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc2x0123456789AB4567CDEF, 1)); - const __m128i vout3x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc3x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc3x0123456789AB4567CDEF, 1)); - const __m128i vout4x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc4x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc4x0123456789AB4567CDEF, 1)); - const __m128i vout5x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc5x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc5x0123456789AB4567CDEF, 1)); + const __m128i vout0x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x012389AB4567CDEF, 1)); + const __m128i vout1x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc1x012389AB4567CDEF), _mm256_extracti128_si256(vacc1x012389AB4567CDEF, 1)); + const __m128i vout2x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc2x012389AB4567CDEF), _mm256_extracti128_si256(vacc2x012389AB4567CDEF, 1)); + const __m128i vout3x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc3x012389AB4567CDEF), _mm256_extracti128_si256(vacc3x012389AB4567CDEF, 1)); + const __m128i vout4x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc4x012389AB4567CDEF), _mm256_extracti128_si256(vacc4x012389AB4567CDEF, 1)); + const __m128i vout5x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc5x012389AB4567CDEF), _mm256_extracti128_si256(vacc5x012389AB4567CDEF, 1)); - __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi8(vout0x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi8(vout1x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi8(vout2x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout3x0123456789ABCDEF = _mm_shuffle_epi8(vout3x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout4x0123456789ABCDEF = _mm_shuffle_epi8(vout4x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout5x0123456789ABCDEF = _mm_shuffle_epi8(vout5x0123456789AB4567CDEF, vshuffle_control_mask); + __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi32(vout0x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi32(vout1x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi32(vout2x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout3x0123456789ABCDEF = _mm_shuffle_epi32(vout3x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout4x0123456789ABCDEF = _mm_shuffle_epi32(vout4x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout5x0123456789ABCDEF = _mm_shuffle_epi32(vout5x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); vout0x0123456789ABCDEF = _mm_max_epi8(vout0x0123456789ABCDEF, voutput_min); vout1x0123456789ABCDEF = _mm_max_epi8(vout1x0123456789ABCDEF, voutput_min); diff --git a/src/qs8-igemm/gen/qs8-igemm-6x16c4-minmax-fp32-avx512vnni.c b/src/qs8-igemm/gen/qs8-igemm-6x16c4-minmax-fp32-avx512vnni.c index 9e5b086b73f..0c9341f6347 100644 --- a/src/qs8-igemm/gen/qs8-igemm-6x16c4-minmax-fp32-avx512vnni.c +++ b/src/qs8-igemm/gen/qs8-igemm-6x16c4-minmax-fp32-avx512vnni.c @@ -68,7 +68,6 @@ void xnn_qs8_igemm_minmax_fp32_ukernel_6x16c4__avx512vnni( const __m512 voutput_max_less_zero_point = _mm512_load_ps(params->fp32_avx512vnni.output_max_less_zero_point); const __m256i voutput_zero_point = _mm256_load_si256((const __m256i*) params->fp32_avx512vnni.output_zero_point); const __m128i voutput_min = _mm_load_si128((const __m128i*) params->fp32_avx512vnni.output_min); - const __m128i vshuffle_control_mask = _mm_loadu_si128((const __m128i*) params->fp32_avx512vnni.shuffle_control_mask); do { __m512i vacc0x0123456789ABCDEF = _mm512_load_epi32(w); __m512i vacc1x0123456789ABCDEF = vacc0x0123456789ABCDEF; @@ -202,33 +201,33 @@ void xnn_qs8_igemm_minmax_fp32_ukernel_6x16c4__avx512vnni( vacc4x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled4x0123456789ABCDEF); vacc5x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled5x0123456789ABCDEF); - __m256i vacc0x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); - __m256i vacc1x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); - __m256i vacc2x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc2x0123456789ABCDEF, 1)); - __m256i vacc3x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc3x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc3x0123456789ABCDEF, 1)); - __m256i vacc4x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc4x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc4x0123456789ABCDEF, 1)); - __m256i vacc5x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc5x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc5x0123456789ABCDEF, 1)); + __m256i vacc0x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); + __m256i vacc1x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); + __m256i vacc2x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc2x0123456789ABCDEF, 1)); + __m256i vacc3x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc3x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc3x0123456789ABCDEF, 1)); + __m256i vacc4x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc4x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc4x0123456789ABCDEF, 1)); + __m256i vacc5x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc5x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc5x0123456789ABCDEF, 1)); - vacc0x0123456789AB4567CDEF = _mm256_adds_epi16(vacc0x0123456789AB4567CDEF, voutput_zero_point); - vacc1x0123456789AB4567CDEF = _mm256_adds_epi16(vacc1x0123456789AB4567CDEF, voutput_zero_point); - vacc2x0123456789AB4567CDEF = _mm256_adds_epi16(vacc2x0123456789AB4567CDEF, voutput_zero_point); - vacc3x0123456789AB4567CDEF = _mm256_adds_epi16(vacc3x0123456789AB4567CDEF, voutput_zero_point); - vacc4x0123456789AB4567CDEF = _mm256_adds_epi16(vacc4x0123456789AB4567CDEF, voutput_zero_point); - vacc5x0123456789AB4567CDEF = _mm256_adds_epi16(vacc5x0123456789AB4567CDEF, voutput_zero_point); + vacc0x012389AB4567CDEF = _mm256_adds_epi16(vacc0x012389AB4567CDEF, voutput_zero_point); + vacc1x012389AB4567CDEF = _mm256_adds_epi16(vacc1x012389AB4567CDEF, voutput_zero_point); + vacc2x012389AB4567CDEF = _mm256_adds_epi16(vacc2x012389AB4567CDEF, voutput_zero_point); + vacc3x012389AB4567CDEF = _mm256_adds_epi16(vacc3x012389AB4567CDEF, voutput_zero_point); + vacc4x012389AB4567CDEF = _mm256_adds_epi16(vacc4x012389AB4567CDEF, voutput_zero_point); + vacc5x012389AB4567CDEF = _mm256_adds_epi16(vacc5x012389AB4567CDEF, voutput_zero_point); - const __m128i vout0x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc0x0123456789AB4567CDEF, 1)); - const __m128i vout1x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc1x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc1x0123456789AB4567CDEF, 1)); - const __m128i vout2x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc2x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc2x0123456789AB4567CDEF, 1)); - const __m128i vout3x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc3x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc3x0123456789AB4567CDEF, 1)); - const __m128i vout4x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc4x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc4x0123456789AB4567CDEF, 1)); - const __m128i vout5x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc5x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc5x0123456789AB4567CDEF, 1)); + const __m128i vout0x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x012389AB4567CDEF, 1)); + const __m128i vout1x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc1x012389AB4567CDEF), _mm256_extracti128_si256(vacc1x012389AB4567CDEF, 1)); + const __m128i vout2x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc2x012389AB4567CDEF), _mm256_extracti128_si256(vacc2x012389AB4567CDEF, 1)); + const __m128i vout3x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc3x012389AB4567CDEF), _mm256_extracti128_si256(vacc3x012389AB4567CDEF, 1)); + const __m128i vout4x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc4x012389AB4567CDEF), _mm256_extracti128_si256(vacc4x012389AB4567CDEF, 1)); + const __m128i vout5x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc5x012389AB4567CDEF), _mm256_extracti128_si256(vacc5x012389AB4567CDEF, 1)); - __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi8(vout0x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi8(vout1x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi8(vout2x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout3x0123456789ABCDEF = _mm_shuffle_epi8(vout3x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout4x0123456789ABCDEF = _mm_shuffle_epi8(vout4x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout5x0123456789ABCDEF = _mm_shuffle_epi8(vout5x0123456789AB4567CDEF, vshuffle_control_mask); + __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi32(vout0x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi32(vout1x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi32(vout2x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout3x0123456789ABCDEF = _mm_shuffle_epi32(vout3x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout4x0123456789ABCDEF = _mm_shuffle_epi32(vout4x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout5x0123456789ABCDEF = _mm_shuffle_epi32(vout5x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); vout0x0123456789ABCDEF = _mm_max_epi8(vout0x0123456789ABCDEF, voutput_min); vout1x0123456789ABCDEF = _mm_max_epi8(vout1x0123456789ABCDEF, voutput_min); diff --git a/src/qs8-igemm/gen/qs8-igemm-6x16c8-minmax-fp32-avx512skx-prfm.c b/src/qs8-igemm/gen/qs8-igemm-6x16c8-minmax-fp32-avx512skx-prfm.c index bfe8933c097..2ff546fe795 100644 --- a/src/qs8-igemm/gen/qs8-igemm-6x16c8-minmax-fp32-avx512skx-prfm.c +++ b/src/qs8-igemm/gen/qs8-igemm-6x16c8-minmax-fp32-avx512skx-prfm.c @@ -235,33 +235,33 @@ void xnn_qs8_igemm_minmax_fp32_ukernel_6x16c8__avx512skx_prfm( vacc4x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled4x0123456789ABCDEF); vacc5x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled5x0123456789ABCDEF); - __m256i vacc0x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); - __m256i vacc1x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); - __m256i vacc2x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc2x0123456789ABCDEF, 1)); - __m256i vacc3x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc3x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc3x0123456789ABCDEF, 1)); - __m256i vacc4x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc4x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc4x0123456789ABCDEF, 1)); - __m256i vacc5x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc5x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc5x0123456789ABCDEF, 1)); + __m256i vacc0x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); + __m256i vacc1x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); + __m256i vacc2x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc2x0123456789ABCDEF, 1)); + __m256i vacc3x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc3x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc3x0123456789ABCDEF, 1)); + __m256i vacc4x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc4x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc4x0123456789ABCDEF, 1)); + __m256i vacc5x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc5x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc5x0123456789ABCDEF, 1)); - vacc0x0123456789AB4567CDEF = _mm256_adds_epi16(vacc0x0123456789AB4567CDEF, voutput_zero_point); - vacc1x0123456789AB4567CDEF = _mm256_adds_epi16(vacc1x0123456789AB4567CDEF, voutput_zero_point); - vacc2x0123456789AB4567CDEF = _mm256_adds_epi16(vacc2x0123456789AB4567CDEF, voutput_zero_point); - vacc3x0123456789AB4567CDEF = _mm256_adds_epi16(vacc3x0123456789AB4567CDEF, voutput_zero_point); - vacc4x0123456789AB4567CDEF = _mm256_adds_epi16(vacc4x0123456789AB4567CDEF, voutput_zero_point); - vacc5x0123456789AB4567CDEF = _mm256_adds_epi16(vacc5x0123456789AB4567CDEF, voutput_zero_point); + vacc0x012389AB4567CDEF = _mm256_adds_epi16(vacc0x012389AB4567CDEF, voutput_zero_point); + vacc1x012389AB4567CDEF = _mm256_adds_epi16(vacc1x012389AB4567CDEF, voutput_zero_point); + vacc2x012389AB4567CDEF = _mm256_adds_epi16(vacc2x012389AB4567CDEF, voutput_zero_point); + vacc3x012389AB4567CDEF = _mm256_adds_epi16(vacc3x012389AB4567CDEF, voutput_zero_point); + vacc4x012389AB4567CDEF = _mm256_adds_epi16(vacc4x012389AB4567CDEF, voutput_zero_point); + vacc5x012389AB4567CDEF = _mm256_adds_epi16(vacc5x012389AB4567CDEF, voutput_zero_point); - const __m128i vout0x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc0x0123456789AB4567CDEF, 1)); - const __m128i vout1x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc1x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc1x0123456789AB4567CDEF, 1)); - const __m128i vout2x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc2x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc2x0123456789AB4567CDEF, 1)); - const __m128i vout3x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc3x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc3x0123456789AB4567CDEF, 1)); - const __m128i vout4x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc4x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc4x0123456789AB4567CDEF, 1)); - const __m128i vout5x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc5x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc5x0123456789AB4567CDEF, 1)); + const __m128i vout0x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x012389AB4567CDEF, 1)); + const __m128i vout1x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc1x012389AB4567CDEF), _mm256_extracti128_si256(vacc1x012389AB4567CDEF, 1)); + const __m128i vout2x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc2x012389AB4567CDEF), _mm256_extracti128_si256(vacc2x012389AB4567CDEF, 1)); + const __m128i vout3x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc3x012389AB4567CDEF), _mm256_extracti128_si256(vacc3x012389AB4567CDEF, 1)); + const __m128i vout4x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc4x012389AB4567CDEF), _mm256_extracti128_si256(vacc4x012389AB4567CDEF, 1)); + const __m128i vout5x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc5x012389AB4567CDEF), _mm256_extracti128_si256(vacc5x012389AB4567CDEF, 1)); - __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi32(vout0x0123456789AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); - __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi32(vout1x0123456789AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); - __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi32(vout2x0123456789AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); - __m128i vout3x0123456789ABCDEF = _mm_shuffle_epi32(vout3x0123456789AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); - __m128i vout4x0123456789ABCDEF = _mm_shuffle_epi32(vout4x0123456789AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); - __m128i vout5x0123456789ABCDEF = _mm_shuffle_epi32(vout5x0123456789AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi32(vout0x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi32(vout1x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi32(vout2x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout3x0123456789ABCDEF = _mm_shuffle_epi32(vout3x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout4x0123456789ABCDEF = _mm_shuffle_epi32(vout4x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout5x0123456789ABCDEF = _mm_shuffle_epi32(vout5x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); vout0x0123456789ABCDEF = _mm_max_epi8(vout0x0123456789ABCDEF, voutput_min); vout1x0123456789ABCDEF = _mm_max_epi8(vout1x0123456789ABCDEF, voutput_min); diff --git a/src/qs8-igemm/gen/qs8-igemm-6x16c8-minmax-fp32-avx512skx.c b/src/qs8-igemm/gen/qs8-igemm-6x16c8-minmax-fp32-avx512skx.c index d9cf94e2144..1c4253f7736 100644 --- a/src/qs8-igemm/gen/qs8-igemm-6x16c8-minmax-fp32-avx512skx.c +++ b/src/qs8-igemm/gen/qs8-igemm-6x16c8-minmax-fp32-avx512skx.c @@ -232,33 +232,33 @@ void xnn_qs8_igemm_minmax_fp32_ukernel_6x16c8__avx512skx( vacc4x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled4x0123456789ABCDEF); vacc5x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled5x0123456789ABCDEF); - __m256i vacc0x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); - __m256i vacc1x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); - __m256i vacc2x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc2x0123456789ABCDEF, 1)); - __m256i vacc3x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc3x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc3x0123456789ABCDEF, 1)); - __m256i vacc4x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc4x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc4x0123456789ABCDEF, 1)); - __m256i vacc5x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc5x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc5x0123456789ABCDEF, 1)); + __m256i vacc0x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); + __m256i vacc1x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); + __m256i vacc2x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc2x0123456789ABCDEF, 1)); + __m256i vacc3x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc3x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc3x0123456789ABCDEF, 1)); + __m256i vacc4x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc4x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc4x0123456789ABCDEF, 1)); + __m256i vacc5x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc5x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc5x0123456789ABCDEF, 1)); - vacc0x0123456789AB4567CDEF = _mm256_adds_epi16(vacc0x0123456789AB4567CDEF, voutput_zero_point); - vacc1x0123456789AB4567CDEF = _mm256_adds_epi16(vacc1x0123456789AB4567CDEF, voutput_zero_point); - vacc2x0123456789AB4567CDEF = _mm256_adds_epi16(vacc2x0123456789AB4567CDEF, voutput_zero_point); - vacc3x0123456789AB4567CDEF = _mm256_adds_epi16(vacc3x0123456789AB4567CDEF, voutput_zero_point); - vacc4x0123456789AB4567CDEF = _mm256_adds_epi16(vacc4x0123456789AB4567CDEF, voutput_zero_point); - vacc5x0123456789AB4567CDEF = _mm256_adds_epi16(vacc5x0123456789AB4567CDEF, voutput_zero_point); + vacc0x012389AB4567CDEF = _mm256_adds_epi16(vacc0x012389AB4567CDEF, voutput_zero_point); + vacc1x012389AB4567CDEF = _mm256_adds_epi16(vacc1x012389AB4567CDEF, voutput_zero_point); + vacc2x012389AB4567CDEF = _mm256_adds_epi16(vacc2x012389AB4567CDEF, voutput_zero_point); + vacc3x012389AB4567CDEF = _mm256_adds_epi16(vacc3x012389AB4567CDEF, voutput_zero_point); + vacc4x012389AB4567CDEF = _mm256_adds_epi16(vacc4x012389AB4567CDEF, voutput_zero_point); + vacc5x012389AB4567CDEF = _mm256_adds_epi16(vacc5x012389AB4567CDEF, voutput_zero_point); - const __m128i vout0x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc0x0123456789AB4567CDEF, 1)); - const __m128i vout1x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc1x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc1x0123456789AB4567CDEF, 1)); - const __m128i vout2x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc2x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc2x0123456789AB4567CDEF, 1)); - const __m128i vout3x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc3x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc3x0123456789AB4567CDEF, 1)); - const __m128i vout4x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc4x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc4x0123456789AB4567CDEF, 1)); - const __m128i vout5x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc5x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc5x0123456789AB4567CDEF, 1)); + const __m128i vout0x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x012389AB4567CDEF, 1)); + const __m128i vout1x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc1x012389AB4567CDEF), _mm256_extracti128_si256(vacc1x012389AB4567CDEF, 1)); + const __m128i vout2x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc2x012389AB4567CDEF), _mm256_extracti128_si256(vacc2x012389AB4567CDEF, 1)); + const __m128i vout3x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc3x012389AB4567CDEF), _mm256_extracti128_si256(vacc3x012389AB4567CDEF, 1)); + const __m128i vout4x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc4x012389AB4567CDEF), _mm256_extracti128_si256(vacc4x012389AB4567CDEF, 1)); + const __m128i vout5x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc5x012389AB4567CDEF), _mm256_extracti128_si256(vacc5x012389AB4567CDEF, 1)); - __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi32(vout0x0123456789AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); - __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi32(vout1x0123456789AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); - __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi32(vout2x0123456789AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); - __m128i vout3x0123456789ABCDEF = _mm_shuffle_epi32(vout3x0123456789AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); - __m128i vout4x0123456789ABCDEF = _mm_shuffle_epi32(vout4x0123456789AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); - __m128i vout5x0123456789ABCDEF = _mm_shuffle_epi32(vout5x0123456789AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi32(vout0x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi32(vout1x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi32(vout2x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout3x0123456789ABCDEF = _mm_shuffle_epi32(vout3x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout4x0123456789ABCDEF = _mm_shuffle_epi32(vout4x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout5x0123456789ABCDEF = _mm_shuffle_epi32(vout5x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); vout0x0123456789ABCDEF = _mm_max_epi8(vout0x0123456789ABCDEF, voutput_min); vout1x0123456789ABCDEF = _mm_max_epi8(vout1x0123456789ABCDEF, voutput_min); diff --git a/src/qs8-igemm/gen/qs8-igemm-6x16c8-minmax-fp32-avx512vnni-prfm.c b/src/qs8-igemm/gen/qs8-igemm-6x16c8-minmax-fp32-avx512vnni-prfm.c index 62cd3fcff7b..ab79d9562d2 100644 --- a/src/qs8-igemm/gen/qs8-igemm-6x16c8-minmax-fp32-avx512vnni-prfm.c +++ b/src/qs8-igemm/gen/qs8-igemm-6x16c8-minmax-fp32-avx512vnni-prfm.c @@ -69,7 +69,6 @@ void xnn_qs8_igemm_minmax_fp32_ukernel_6x16c8__avx512vnni_prfm( const __m512 voutput_max_less_zero_point = _mm512_load_ps(params->fp32_avx512vnni.output_max_less_zero_point); const __m256i voutput_zero_point = _mm256_load_si256((const __m256i*) params->fp32_avx512vnni.output_zero_point); const __m128i voutput_min = _mm_load_si128((const __m128i*) params->fp32_avx512vnni.output_min); - const __m128i vshuffle_control_mask = _mm_loadu_si128((const __m128i*) params->fp32_avx512vnni.shuffle_control_mask); do { __m512i vacc0x01234567 = _mm512_cvtepu32_epi64(_mm256_load_si256((const __m256i*) w)); __m512i vacc0x89ABCDEF = _mm512_cvtepu32_epi64(_mm256_load_si256((const __m256i*) ((const int32_t*) w + 8))); @@ -260,33 +259,33 @@ void xnn_qs8_igemm_minmax_fp32_ukernel_6x16c8__avx512vnni_prfm( vacc4x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled4x0123456789ABCDEF); vacc5x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled5x0123456789ABCDEF); - __m256i vacc0x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); - __m256i vacc1x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); - __m256i vacc2x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc2x0123456789ABCDEF, 1)); - __m256i vacc3x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc3x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc3x0123456789ABCDEF, 1)); - __m256i vacc4x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc4x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc4x0123456789ABCDEF, 1)); - __m256i vacc5x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc5x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc5x0123456789ABCDEF, 1)); + __m256i vacc0x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); + __m256i vacc1x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); + __m256i vacc2x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc2x0123456789ABCDEF, 1)); + __m256i vacc3x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc3x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc3x0123456789ABCDEF, 1)); + __m256i vacc4x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc4x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc4x0123456789ABCDEF, 1)); + __m256i vacc5x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc5x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc5x0123456789ABCDEF, 1)); - vacc0x0123456789AB4567CDEF = _mm256_adds_epi16(vacc0x0123456789AB4567CDEF, voutput_zero_point); - vacc1x0123456789AB4567CDEF = _mm256_adds_epi16(vacc1x0123456789AB4567CDEF, voutput_zero_point); - vacc2x0123456789AB4567CDEF = _mm256_adds_epi16(vacc2x0123456789AB4567CDEF, voutput_zero_point); - vacc3x0123456789AB4567CDEF = _mm256_adds_epi16(vacc3x0123456789AB4567CDEF, voutput_zero_point); - vacc4x0123456789AB4567CDEF = _mm256_adds_epi16(vacc4x0123456789AB4567CDEF, voutput_zero_point); - vacc5x0123456789AB4567CDEF = _mm256_adds_epi16(vacc5x0123456789AB4567CDEF, voutput_zero_point); + vacc0x012389AB4567CDEF = _mm256_adds_epi16(vacc0x012389AB4567CDEF, voutput_zero_point); + vacc1x012389AB4567CDEF = _mm256_adds_epi16(vacc1x012389AB4567CDEF, voutput_zero_point); + vacc2x012389AB4567CDEF = _mm256_adds_epi16(vacc2x012389AB4567CDEF, voutput_zero_point); + vacc3x012389AB4567CDEF = _mm256_adds_epi16(vacc3x012389AB4567CDEF, voutput_zero_point); + vacc4x012389AB4567CDEF = _mm256_adds_epi16(vacc4x012389AB4567CDEF, voutput_zero_point); + vacc5x012389AB4567CDEF = _mm256_adds_epi16(vacc5x012389AB4567CDEF, voutput_zero_point); - const __m128i vout0x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc0x0123456789AB4567CDEF, 1)); - const __m128i vout1x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc1x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc1x0123456789AB4567CDEF, 1)); - const __m128i vout2x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc2x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc2x0123456789AB4567CDEF, 1)); - const __m128i vout3x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc3x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc3x0123456789AB4567CDEF, 1)); - const __m128i vout4x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc4x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc4x0123456789AB4567CDEF, 1)); - const __m128i vout5x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc5x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc5x0123456789AB4567CDEF, 1)); + const __m128i vout0x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x012389AB4567CDEF, 1)); + const __m128i vout1x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc1x012389AB4567CDEF), _mm256_extracti128_si256(vacc1x012389AB4567CDEF, 1)); + const __m128i vout2x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc2x012389AB4567CDEF), _mm256_extracti128_si256(vacc2x012389AB4567CDEF, 1)); + const __m128i vout3x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc3x012389AB4567CDEF), _mm256_extracti128_si256(vacc3x012389AB4567CDEF, 1)); + const __m128i vout4x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc4x012389AB4567CDEF), _mm256_extracti128_si256(vacc4x012389AB4567CDEF, 1)); + const __m128i vout5x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc5x012389AB4567CDEF), _mm256_extracti128_si256(vacc5x012389AB4567CDEF, 1)); - __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi8(vout0x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi8(vout1x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi8(vout2x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout3x0123456789ABCDEF = _mm_shuffle_epi8(vout3x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout4x0123456789ABCDEF = _mm_shuffle_epi8(vout4x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout5x0123456789ABCDEF = _mm_shuffle_epi8(vout5x0123456789AB4567CDEF, vshuffle_control_mask); + __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi32(vout0x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi32(vout1x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi32(vout2x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout3x0123456789ABCDEF = _mm_shuffle_epi32(vout3x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout4x0123456789ABCDEF = _mm_shuffle_epi32(vout4x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout5x0123456789ABCDEF = _mm_shuffle_epi32(vout5x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); vout0x0123456789ABCDEF = _mm_max_epi8(vout0x0123456789ABCDEF, voutput_min); vout1x0123456789ABCDEF = _mm_max_epi8(vout1x0123456789ABCDEF, voutput_min); diff --git a/src/qs8-igemm/gen/qs8-igemm-6x16c8-minmax-fp32-avx512vnni.c b/src/qs8-igemm/gen/qs8-igemm-6x16c8-minmax-fp32-avx512vnni.c index aa4adaa7460..214f24d8b11 100644 --- a/src/qs8-igemm/gen/qs8-igemm-6x16c8-minmax-fp32-avx512vnni.c +++ b/src/qs8-igemm/gen/qs8-igemm-6x16c8-minmax-fp32-avx512vnni.c @@ -68,7 +68,6 @@ void xnn_qs8_igemm_minmax_fp32_ukernel_6x16c8__avx512vnni( const __m512 voutput_max_less_zero_point = _mm512_load_ps(params->fp32_avx512vnni.output_max_less_zero_point); const __m256i voutput_zero_point = _mm256_load_si256((const __m256i*) params->fp32_avx512vnni.output_zero_point); const __m128i voutput_min = _mm_load_si128((const __m128i*) params->fp32_avx512vnni.output_min); - const __m128i vshuffle_control_mask = _mm_loadu_si128((const __m128i*) params->fp32_avx512vnni.shuffle_control_mask); do { __m512i vacc0x01234567 = _mm512_cvtepu32_epi64(_mm256_load_si256((const __m256i*) w)); __m512i vacc0x89ABCDEF = _mm512_cvtepu32_epi64(_mm256_load_si256((const __m256i*) ((const int32_t*) w + 8))); @@ -253,33 +252,33 @@ void xnn_qs8_igemm_minmax_fp32_ukernel_6x16c8__avx512vnni( vacc4x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled4x0123456789ABCDEF); vacc5x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled5x0123456789ABCDEF); - __m256i vacc0x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); - __m256i vacc1x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); - __m256i vacc2x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc2x0123456789ABCDEF, 1)); - __m256i vacc3x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc3x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc3x0123456789ABCDEF, 1)); - __m256i vacc4x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc4x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc4x0123456789ABCDEF, 1)); - __m256i vacc5x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc5x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc5x0123456789ABCDEF, 1)); + __m256i vacc0x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); + __m256i vacc1x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); + __m256i vacc2x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc2x0123456789ABCDEF, 1)); + __m256i vacc3x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc3x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc3x0123456789ABCDEF, 1)); + __m256i vacc4x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc4x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc4x0123456789ABCDEF, 1)); + __m256i vacc5x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc5x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc5x0123456789ABCDEF, 1)); - vacc0x0123456789AB4567CDEF = _mm256_adds_epi16(vacc0x0123456789AB4567CDEF, voutput_zero_point); - vacc1x0123456789AB4567CDEF = _mm256_adds_epi16(vacc1x0123456789AB4567CDEF, voutput_zero_point); - vacc2x0123456789AB4567CDEF = _mm256_adds_epi16(vacc2x0123456789AB4567CDEF, voutput_zero_point); - vacc3x0123456789AB4567CDEF = _mm256_adds_epi16(vacc3x0123456789AB4567CDEF, voutput_zero_point); - vacc4x0123456789AB4567CDEF = _mm256_adds_epi16(vacc4x0123456789AB4567CDEF, voutput_zero_point); - vacc5x0123456789AB4567CDEF = _mm256_adds_epi16(vacc5x0123456789AB4567CDEF, voutput_zero_point); + vacc0x012389AB4567CDEF = _mm256_adds_epi16(vacc0x012389AB4567CDEF, voutput_zero_point); + vacc1x012389AB4567CDEF = _mm256_adds_epi16(vacc1x012389AB4567CDEF, voutput_zero_point); + vacc2x012389AB4567CDEF = _mm256_adds_epi16(vacc2x012389AB4567CDEF, voutput_zero_point); + vacc3x012389AB4567CDEF = _mm256_adds_epi16(vacc3x012389AB4567CDEF, voutput_zero_point); + vacc4x012389AB4567CDEF = _mm256_adds_epi16(vacc4x012389AB4567CDEF, voutput_zero_point); + vacc5x012389AB4567CDEF = _mm256_adds_epi16(vacc5x012389AB4567CDEF, voutput_zero_point); - const __m128i vout0x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc0x0123456789AB4567CDEF, 1)); - const __m128i vout1x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc1x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc1x0123456789AB4567CDEF, 1)); - const __m128i vout2x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc2x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc2x0123456789AB4567CDEF, 1)); - const __m128i vout3x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc3x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc3x0123456789AB4567CDEF, 1)); - const __m128i vout4x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc4x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc4x0123456789AB4567CDEF, 1)); - const __m128i vout5x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc5x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc5x0123456789AB4567CDEF, 1)); + const __m128i vout0x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x012389AB4567CDEF, 1)); + const __m128i vout1x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc1x012389AB4567CDEF), _mm256_extracti128_si256(vacc1x012389AB4567CDEF, 1)); + const __m128i vout2x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc2x012389AB4567CDEF), _mm256_extracti128_si256(vacc2x012389AB4567CDEF, 1)); + const __m128i vout3x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc3x012389AB4567CDEF), _mm256_extracti128_si256(vacc3x012389AB4567CDEF, 1)); + const __m128i vout4x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc4x012389AB4567CDEF), _mm256_extracti128_si256(vacc4x012389AB4567CDEF, 1)); + const __m128i vout5x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc5x012389AB4567CDEF), _mm256_extracti128_si256(vacc5x012389AB4567CDEF, 1)); - __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi8(vout0x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi8(vout1x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi8(vout2x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout3x0123456789ABCDEF = _mm_shuffle_epi8(vout3x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout4x0123456789ABCDEF = _mm_shuffle_epi8(vout4x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout5x0123456789ABCDEF = _mm_shuffle_epi8(vout5x0123456789AB4567CDEF, vshuffle_control_mask); + __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi32(vout0x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi32(vout1x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi32(vout2x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout3x0123456789ABCDEF = _mm_shuffle_epi32(vout3x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout4x0123456789ABCDEF = _mm_shuffle_epi32(vout4x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout5x0123456789ABCDEF = _mm_shuffle_epi32(vout5x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); vout0x0123456789ABCDEF = _mm_max_epi8(vout0x0123456789ABCDEF, voutput_min); vout1x0123456789ABCDEF = _mm_max_epi8(vout1x0123456789ABCDEF, voutput_min); diff --git a/src/qs8-igemm/gen/qs8-igemm-7x16c4-minmax-fp32-avx512vnni-prfm.c b/src/qs8-igemm/gen/qs8-igemm-7x16c4-minmax-fp32-avx512vnni-prfm.c index 631d332f6f7..830545ced22 100644 --- a/src/qs8-igemm/gen/qs8-igemm-7x16c4-minmax-fp32-avx512vnni-prfm.c +++ b/src/qs8-igemm/gen/qs8-igemm-7x16c4-minmax-fp32-avx512vnni-prfm.c @@ -73,7 +73,6 @@ void xnn_qs8_igemm_minmax_fp32_ukernel_7x16c4__avx512vnni_prfm( const __m512 voutput_max_less_zero_point = _mm512_load_ps(params->fp32_avx512vnni.output_max_less_zero_point); const __m256i voutput_zero_point = _mm256_load_si256((const __m256i*) params->fp32_avx512vnni.output_zero_point); const __m128i voutput_min = _mm_load_si128((const __m128i*) params->fp32_avx512vnni.output_min); - const __m128i vshuffle_control_mask = _mm_loadu_si128((const __m128i*) params->fp32_avx512vnni.shuffle_control_mask); do { __m512i vacc0x0123456789ABCDEF = _mm512_load_epi32(w); __m512i vacc1x0123456789ABCDEF = vacc0x0123456789ABCDEF; @@ -227,37 +226,37 @@ void xnn_qs8_igemm_minmax_fp32_ukernel_7x16c4__avx512vnni_prfm( vacc5x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled5x0123456789ABCDEF); vacc6x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled6x0123456789ABCDEF); - __m256i vacc0x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); - __m256i vacc1x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); - __m256i vacc2x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc2x0123456789ABCDEF, 1)); - __m256i vacc3x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc3x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc3x0123456789ABCDEF, 1)); - __m256i vacc4x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc4x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc4x0123456789ABCDEF, 1)); - __m256i vacc5x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc5x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc5x0123456789ABCDEF, 1)); - __m256i vacc6x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc6x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc6x0123456789ABCDEF, 1)); + __m256i vacc0x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); + __m256i vacc1x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); + __m256i vacc2x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc2x0123456789ABCDEF, 1)); + __m256i vacc3x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc3x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc3x0123456789ABCDEF, 1)); + __m256i vacc4x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc4x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc4x0123456789ABCDEF, 1)); + __m256i vacc5x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc5x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc5x0123456789ABCDEF, 1)); + __m256i vacc6x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc6x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc6x0123456789ABCDEF, 1)); - vacc0x0123456789AB4567CDEF = _mm256_adds_epi16(vacc0x0123456789AB4567CDEF, voutput_zero_point); - vacc1x0123456789AB4567CDEF = _mm256_adds_epi16(vacc1x0123456789AB4567CDEF, voutput_zero_point); - vacc2x0123456789AB4567CDEF = _mm256_adds_epi16(vacc2x0123456789AB4567CDEF, voutput_zero_point); - vacc3x0123456789AB4567CDEF = _mm256_adds_epi16(vacc3x0123456789AB4567CDEF, voutput_zero_point); - vacc4x0123456789AB4567CDEF = _mm256_adds_epi16(vacc4x0123456789AB4567CDEF, voutput_zero_point); - vacc5x0123456789AB4567CDEF = _mm256_adds_epi16(vacc5x0123456789AB4567CDEF, voutput_zero_point); - vacc6x0123456789AB4567CDEF = _mm256_adds_epi16(vacc6x0123456789AB4567CDEF, voutput_zero_point); + vacc0x012389AB4567CDEF = _mm256_adds_epi16(vacc0x012389AB4567CDEF, voutput_zero_point); + vacc1x012389AB4567CDEF = _mm256_adds_epi16(vacc1x012389AB4567CDEF, voutput_zero_point); + vacc2x012389AB4567CDEF = _mm256_adds_epi16(vacc2x012389AB4567CDEF, voutput_zero_point); + vacc3x012389AB4567CDEF = _mm256_adds_epi16(vacc3x012389AB4567CDEF, voutput_zero_point); + vacc4x012389AB4567CDEF = _mm256_adds_epi16(vacc4x012389AB4567CDEF, voutput_zero_point); + vacc5x012389AB4567CDEF = _mm256_adds_epi16(vacc5x012389AB4567CDEF, voutput_zero_point); + vacc6x012389AB4567CDEF = _mm256_adds_epi16(vacc6x012389AB4567CDEF, voutput_zero_point); - const __m128i vout0x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc0x0123456789AB4567CDEF, 1)); - const __m128i vout1x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc1x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc1x0123456789AB4567CDEF, 1)); - const __m128i vout2x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc2x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc2x0123456789AB4567CDEF, 1)); - const __m128i vout3x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc3x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc3x0123456789AB4567CDEF, 1)); - const __m128i vout4x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc4x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc4x0123456789AB4567CDEF, 1)); - const __m128i vout5x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc5x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc5x0123456789AB4567CDEF, 1)); - const __m128i vout6x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc6x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc6x0123456789AB4567CDEF, 1)); + const __m128i vout0x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x012389AB4567CDEF, 1)); + const __m128i vout1x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc1x012389AB4567CDEF), _mm256_extracti128_si256(vacc1x012389AB4567CDEF, 1)); + const __m128i vout2x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc2x012389AB4567CDEF), _mm256_extracti128_si256(vacc2x012389AB4567CDEF, 1)); + const __m128i vout3x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc3x012389AB4567CDEF), _mm256_extracti128_si256(vacc3x012389AB4567CDEF, 1)); + const __m128i vout4x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc4x012389AB4567CDEF), _mm256_extracti128_si256(vacc4x012389AB4567CDEF, 1)); + const __m128i vout5x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc5x012389AB4567CDEF), _mm256_extracti128_si256(vacc5x012389AB4567CDEF, 1)); + const __m128i vout6x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc6x012389AB4567CDEF), _mm256_extracti128_si256(vacc6x012389AB4567CDEF, 1)); - __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi8(vout0x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi8(vout1x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi8(vout2x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout3x0123456789ABCDEF = _mm_shuffle_epi8(vout3x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout4x0123456789ABCDEF = _mm_shuffle_epi8(vout4x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout5x0123456789ABCDEF = _mm_shuffle_epi8(vout5x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout6x0123456789ABCDEF = _mm_shuffle_epi8(vout6x0123456789AB4567CDEF, vshuffle_control_mask); + __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi32(vout0x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi32(vout1x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi32(vout2x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout3x0123456789ABCDEF = _mm_shuffle_epi32(vout3x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout4x0123456789ABCDEF = _mm_shuffle_epi32(vout4x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout5x0123456789ABCDEF = _mm_shuffle_epi32(vout5x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout6x0123456789ABCDEF = _mm_shuffle_epi32(vout6x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); vout0x0123456789ABCDEF = _mm_max_epi8(vout0x0123456789ABCDEF, voutput_min); vout1x0123456789ABCDEF = _mm_max_epi8(vout1x0123456789ABCDEF, voutput_min); diff --git a/src/qs8-igemm/gen/qs8-igemm-7x16c4-minmax-fp32-avx512vnni.c b/src/qs8-igemm/gen/qs8-igemm-7x16c4-minmax-fp32-avx512vnni.c index 441bf217736..9b6b672d9be 100644 --- a/src/qs8-igemm/gen/qs8-igemm-7x16c4-minmax-fp32-avx512vnni.c +++ b/src/qs8-igemm/gen/qs8-igemm-7x16c4-minmax-fp32-avx512vnni.c @@ -72,7 +72,6 @@ void xnn_qs8_igemm_minmax_fp32_ukernel_7x16c4__avx512vnni( const __m512 voutput_max_less_zero_point = _mm512_load_ps(params->fp32_avx512vnni.output_max_less_zero_point); const __m256i voutput_zero_point = _mm256_load_si256((const __m256i*) params->fp32_avx512vnni.output_zero_point); const __m128i voutput_min = _mm_load_si128((const __m128i*) params->fp32_avx512vnni.output_min); - const __m128i vshuffle_control_mask = _mm_loadu_si128((const __m128i*) params->fp32_avx512vnni.shuffle_control_mask); do { __m512i vacc0x0123456789ABCDEF = _mm512_load_epi32(w); __m512i vacc1x0123456789ABCDEF = vacc0x0123456789ABCDEF; @@ -223,37 +222,37 @@ void xnn_qs8_igemm_minmax_fp32_ukernel_7x16c4__avx512vnni( vacc5x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled5x0123456789ABCDEF); vacc6x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled6x0123456789ABCDEF); - __m256i vacc0x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); - __m256i vacc1x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); - __m256i vacc2x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc2x0123456789ABCDEF, 1)); - __m256i vacc3x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc3x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc3x0123456789ABCDEF, 1)); - __m256i vacc4x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc4x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc4x0123456789ABCDEF, 1)); - __m256i vacc5x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc5x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc5x0123456789ABCDEF, 1)); - __m256i vacc6x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc6x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc6x0123456789ABCDEF, 1)); + __m256i vacc0x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); + __m256i vacc1x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); + __m256i vacc2x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc2x0123456789ABCDEF, 1)); + __m256i vacc3x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc3x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc3x0123456789ABCDEF, 1)); + __m256i vacc4x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc4x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc4x0123456789ABCDEF, 1)); + __m256i vacc5x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc5x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc5x0123456789ABCDEF, 1)); + __m256i vacc6x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc6x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc6x0123456789ABCDEF, 1)); - vacc0x0123456789AB4567CDEF = _mm256_adds_epi16(vacc0x0123456789AB4567CDEF, voutput_zero_point); - vacc1x0123456789AB4567CDEF = _mm256_adds_epi16(vacc1x0123456789AB4567CDEF, voutput_zero_point); - vacc2x0123456789AB4567CDEF = _mm256_adds_epi16(vacc2x0123456789AB4567CDEF, voutput_zero_point); - vacc3x0123456789AB4567CDEF = _mm256_adds_epi16(vacc3x0123456789AB4567CDEF, voutput_zero_point); - vacc4x0123456789AB4567CDEF = _mm256_adds_epi16(vacc4x0123456789AB4567CDEF, voutput_zero_point); - vacc5x0123456789AB4567CDEF = _mm256_adds_epi16(vacc5x0123456789AB4567CDEF, voutput_zero_point); - vacc6x0123456789AB4567CDEF = _mm256_adds_epi16(vacc6x0123456789AB4567CDEF, voutput_zero_point); + vacc0x012389AB4567CDEF = _mm256_adds_epi16(vacc0x012389AB4567CDEF, voutput_zero_point); + vacc1x012389AB4567CDEF = _mm256_adds_epi16(vacc1x012389AB4567CDEF, voutput_zero_point); + vacc2x012389AB4567CDEF = _mm256_adds_epi16(vacc2x012389AB4567CDEF, voutput_zero_point); + vacc3x012389AB4567CDEF = _mm256_adds_epi16(vacc3x012389AB4567CDEF, voutput_zero_point); + vacc4x012389AB4567CDEF = _mm256_adds_epi16(vacc4x012389AB4567CDEF, voutput_zero_point); + vacc5x012389AB4567CDEF = _mm256_adds_epi16(vacc5x012389AB4567CDEF, voutput_zero_point); + vacc6x012389AB4567CDEF = _mm256_adds_epi16(vacc6x012389AB4567CDEF, voutput_zero_point); - const __m128i vout0x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc0x0123456789AB4567CDEF, 1)); - const __m128i vout1x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc1x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc1x0123456789AB4567CDEF, 1)); - const __m128i vout2x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc2x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc2x0123456789AB4567CDEF, 1)); - const __m128i vout3x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc3x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc3x0123456789AB4567CDEF, 1)); - const __m128i vout4x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc4x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc4x0123456789AB4567CDEF, 1)); - const __m128i vout5x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc5x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc5x0123456789AB4567CDEF, 1)); - const __m128i vout6x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc6x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc6x0123456789AB4567CDEF, 1)); + const __m128i vout0x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x012389AB4567CDEF, 1)); + const __m128i vout1x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc1x012389AB4567CDEF), _mm256_extracti128_si256(vacc1x012389AB4567CDEF, 1)); + const __m128i vout2x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc2x012389AB4567CDEF), _mm256_extracti128_si256(vacc2x012389AB4567CDEF, 1)); + const __m128i vout3x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc3x012389AB4567CDEF), _mm256_extracti128_si256(vacc3x012389AB4567CDEF, 1)); + const __m128i vout4x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc4x012389AB4567CDEF), _mm256_extracti128_si256(vacc4x012389AB4567CDEF, 1)); + const __m128i vout5x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc5x012389AB4567CDEF), _mm256_extracti128_si256(vacc5x012389AB4567CDEF, 1)); + const __m128i vout6x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc6x012389AB4567CDEF), _mm256_extracti128_si256(vacc6x012389AB4567CDEF, 1)); - __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi8(vout0x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi8(vout1x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi8(vout2x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout3x0123456789ABCDEF = _mm_shuffle_epi8(vout3x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout4x0123456789ABCDEF = _mm_shuffle_epi8(vout4x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout5x0123456789ABCDEF = _mm_shuffle_epi8(vout5x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout6x0123456789ABCDEF = _mm_shuffle_epi8(vout6x0123456789AB4567CDEF, vshuffle_control_mask); + __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi32(vout0x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi32(vout1x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi32(vout2x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout3x0123456789ABCDEF = _mm_shuffle_epi32(vout3x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout4x0123456789ABCDEF = _mm_shuffle_epi32(vout4x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout5x0123456789ABCDEF = _mm_shuffle_epi32(vout5x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout6x0123456789ABCDEF = _mm_shuffle_epi32(vout6x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); vout0x0123456789ABCDEF = _mm_max_epi8(vout0x0123456789ABCDEF, voutput_min); vout1x0123456789ABCDEF = _mm_max_epi8(vout1x0123456789ABCDEF, voutput_min); diff --git a/src/qs8-igemm/gen/qs8-igemm-7x16c8-minmax-fp32-avx512skx-prfm.c b/src/qs8-igemm/gen/qs8-igemm-7x16c8-minmax-fp32-avx512skx-prfm.c index 33d980a7579..b92d9bf9c6e 100644 --- a/src/qs8-igemm/gen/qs8-igemm-7x16c8-minmax-fp32-avx512skx-prfm.c +++ b/src/qs8-igemm/gen/qs8-igemm-7x16c8-minmax-fp32-avx512skx-prfm.c @@ -261,37 +261,37 @@ void xnn_qs8_igemm_minmax_fp32_ukernel_7x16c8__avx512skx_prfm( vacc5x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled5x0123456789ABCDEF); vacc6x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled6x0123456789ABCDEF); - __m256i vacc0x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); - __m256i vacc1x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); - __m256i vacc2x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc2x0123456789ABCDEF, 1)); - __m256i vacc3x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc3x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc3x0123456789ABCDEF, 1)); - __m256i vacc4x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc4x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc4x0123456789ABCDEF, 1)); - __m256i vacc5x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc5x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc5x0123456789ABCDEF, 1)); - __m256i vacc6x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc6x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc6x0123456789ABCDEF, 1)); + __m256i vacc0x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); + __m256i vacc1x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); + __m256i vacc2x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc2x0123456789ABCDEF, 1)); + __m256i vacc3x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc3x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc3x0123456789ABCDEF, 1)); + __m256i vacc4x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc4x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc4x0123456789ABCDEF, 1)); + __m256i vacc5x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc5x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc5x0123456789ABCDEF, 1)); + __m256i vacc6x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc6x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc6x0123456789ABCDEF, 1)); - vacc0x0123456789AB4567CDEF = _mm256_adds_epi16(vacc0x0123456789AB4567CDEF, voutput_zero_point); - vacc1x0123456789AB4567CDEF = _mm256_adds_epi16(vacc1x0123456789AB4567CDEF, voutput_zero_point); - vacc2x0123456789AB4567CDEF = _mm256_adds_epi16(vacc2x0123456789AB4567CDEF, voutput_zero_point); - vacc3x0123456789AB4567CDEF = _mm256_adds_epi16(vacc3x0123456789AB4567CDEF, voutput_zero_point); - vacc4x0123456789AB4567CDEF = _mm256_adds_epi16(vacc4x0123456789AB4567CDEF, voutput_zero_point); - vacc5x0123456789AB4567CDEF = _mm256_adds_epi16(vacc5x0123456789AB4567CDEF, voutput_zero_point); - vacc6x0123456789AB4567CDEF = _mm256_adds_epi16(vacc6x0123456789AB4567CDEF, voutput_zero_point); + vacc0x012389AB4567CDEF = _mm256_adds_epi16(vacc0x012389AB4567CDEF, voutput_zero_point); + vacc1x012389AB4567CDEF = _mm256_adds_epi16(vacc1x012389AB4567CDEF, voutput_zero_point); + vacc2x012389AB4567CDEF = _mm256_adds_epi16(vacc2x012389AB4567CDEF, voutput_zero_point); + vacc3x012389AB4567CDEF = _mm256_adds_epi16(vacc3x012389AB4567CDEF, voutput_zero_point); + vacc4x012389AB4567CDEF = _mm256_adds_epi16(vacc4x012389AB4567CDEF, voutput_zero_point); + vacc5x012389AB4567CDEF = _mm256_adds_epi16(vacc5x012389AB4567CDEF, voutput_zero_point); + vacc6x012389AB4567CDEF = _mm256_adds_epi16(vacc6x012389AB4567CDEF, voutput_zero_point); - const __m128i vout0x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc0x0123456789AB4567CDEF, 1)); - const __m128i vout1x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc1x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc1x0123456789AB4567CDEF, 1)); - const __m128i vout2x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc2x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc2x0123456789AB4567CDEF, 1)); - const __m128i vout3x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc3x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc3x0123456789AB4567CDEF, 1)); - const __m128i vout4x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc4x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc4x0123456789AB4567CDEF, 1)); - const __m128i vout5x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc5x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc5x0123456789AB4567CDEF, 1)); - const __m128i vout6x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc6x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc6x0123456789AB4567CDEF, 1)); + const __m128i vout0x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x012389AB4567CDEF, 1)); + const __m128i vout1x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc1x012389AB4567CDEF), _mm256_extracti128_si256(vacc1x012389AB4567CDEF, 1)); + const __m128i vout2x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc2x012389AB4567CDEF), _mm256_extracti128_si256(vacc2x012389AB4567CDEF, 1)); + const __m128i vout3x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc3x012389AB4567CDEF), _mm256_extracti128_si256(vacc3x012389AB4567CDEF, 1)); + const __m128i vout4x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc4x012389AB4567CDEF), _mm256_extracti128_si256(vacc4x012389AB4567CDEF, 1)); + const __m128i vout5x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc5x012389AB4567CDEF), _mm256_extracti128_si256(vacc5x012389AB4567CDEF, 1)); + const __m128i vout6x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc6x012389AB4567CDEF), _mm256_extracti128_si256(vacc6x012389AB4567CDEF, 1)); - __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi32(vout0x0123456789AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); - __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi32(vout1x0123456789AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); - __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi32(vout2x0123456789AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); - __m128i vout3x0123456789ABCDEF = _mm_shuffle_epi32(vout3x0123456789AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); - __m128i vout4x0123456789ABCDEF = _mm_shuffle_epi32(vout4x0123456789AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); - __m128i vout5x0123456789ABCDEF = _mm_shuffle_epi32(vout5x0123456789AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); - __m128i vout6x0123456789ABCDEF = _mm_shuffle_epi32(vout6x0123456789AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi32(vout0x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi32(vout1x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi32(vout2x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout3x0123456789ABCDEF = _mm_shuffle_epi32(vout3x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout4x0123456789ABCDEF = _mm_shuffle_epi32(vout4x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout5x0123456789ABCDEF = _mm_shuffle_epi32(vout5x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout6x0123456789ABCDEF = _mm_shuffle_epi32(vout6x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); vout0x0123456789ABCDEF = _mm_max_epi8(vout0x0123456789ABCDEF, voutput_min); vout1x0123456789ABCDEF = _mm_max_epi8(vout1x0123456789ABCDEF, voutput_min); diff --git a/src/qs8-igemm/gen/qs8-igemm-7x16c8-minmax-fp32-avx512skx.c b/src/qs8-igemm/gen/qs8-igemm-7x16c8-minmax-fp32-avx512skx.c index a1f43d65a58..643deafed46 100644 --- a/src/qs8-igemm/gen/qs8-igemm-7x16c8-minmax-fp32-avx512skx.c +++ b/src/qs8-igemm/gen/qs8-igemm-7x16c8-minmax-fp32-avx512skx.c @@ -258,37 +258,37 @@ void xnn_qs8_igemm_minmax_fp32_ukernel_7x16c8__avx512skx( vacc5x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled5x0123456789ABCDEF); vacc6x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled6x0123456789ABCDEF); - __m256i vacc0x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); - __m256i vacc1x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); - __m256i vacc2x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc2x0123456789ABCDEF, 1)); - __m256i vacc3x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc3x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc3x0123456789ABCDEF, 1)); - __m256i vacc4x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc4x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc4x0123456789ABCDEF, 1)); - __m256i vacc5x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc5x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc5x0123456789ABCDEF, 1)); - __m256i vacc6x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc6x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc6x0123456789ABCDEF, 1)); + __m256i vacc0x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); + __m256i vacc1x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); + __m256i vacc2x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc2x0123456789ABCDEF, 1)); + __m256i vacc3x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc3x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc3x0123456789ABCDEF, 1)); + __m256i vacc4x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc4x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc4x0123456789ABCDEF, 1)); + __m256i vacc5x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc5x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc5x0123456789ABCDEF, 1)); + __m256i vacc6x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc6x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc6x0123456789ABCDEF, 1)); - vacc0x0123456789AB4567CDEF = _mm256_adds_epi16(vacc0x0123456789AB4567CDEF, voutput_zero_point); - vacc1x0123456789AB4567CDEF = _mm256_adds_epi16(vacc1x0123456789AB4567CDEF, voutput_zero_point); - vacc2x0123456789AB4567CDEF = _mm256_adds_epi16(vacc2x0123456789AB4567CDEF, voutput_zero_point); - vacc3x0123456789AB4567CDEF = _mm256_adds_epi16(vacc3x0123456789AB4567CDEF, voutput_zero_point); - vacc4x0123456789AB4567CDEF = _mm256_adds_epi16(vacc4x0123456789AB4567CDEF, voutput_zero_point); - vacc5x0123456789AB4567CDEF = _mm256_adds_epi16(vacc5x0123456789AB4567CDEF, voutput_zero_point); - vacc6x0123456789AB4567CDEF = _mm256_adds_epi16(vacc6x0123456789AB4567CDEF, voutput_zero_point); + vacc0x012389AB4567CDEF = _mm256_adds_epi16(vacc0x012389AB4567CDEF, voutput_zero_point); + vacc1x012389AB4567CDEF = _mm256_adds_epi16(vacc1x012389AB4567CDEF, voutput_zero_point); + vacc2x012389AB4567CDEF = _mm256_adds_epi16(vacc2x012389AB4567CDEF, voutput_zero_point); + vacc3x012389AB4567CDEF = _mm256_adds_epi16(vacc3x012389AB4567CDEF, voutput_zero_point); + vacc4x012389AB4567CDEF = _mm256_adds_epi16(vacc4x012389AB4567CDEF, voutput_zero_point); + vacc5x012389AB4567CDEF = _mm256_adds_epi16(vacc5x012389AB4567CDEF, voutput_zero_point); + vacc6x012389AB4567CDEF = _mm256_adds_epi16(vacc6x012389AB4567CDEF, voutput_zero_point); - const __m128i vout0x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc0x0123456789AB4567CDEF, 1)); - const __m128i vout1x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc1x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc1x0123456789AB4567CDEF, 1)); - const __m128i vout2x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc2x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc2x0123456789AB4567CDEF, 1)); - const __m128i vout3x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc3x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc3x0123456789AB4567CDEF, 1)); - const __m128i vout4x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc4x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc4x0123456789AB4567CDEF, 1)); - const __m128i vout5x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc5x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc5x0123456789AB4567CDEF, 1)); - const __m128i vout6x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc6x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc6x0123456789AB4567CDEF, 1)); + const __m128i vout0x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x012389AB4567CDEF, 1)); + const __m128i vout1x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc1x012389AB4567CDEF), _mm256_extracti128_si256(vacc1x012389AB4567CDEF, 1)); + const __m128i vout2x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc2x012389AB4567CDEF), _mm256_extracti128_si256(vacc2x012389AB4567CDEF, 1)); + const __m128i vout3x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc3x012389AB4567CDEF), _mm256_extracti128_si256(vacc3x012389AB4567CDEF, 1)); + const __m128i vout4x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc4x012389AB4567CDEF), _mm256_extracti128_si256(vacc4x012389AB4567CDEF, 1)); + const __m128i vout5x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc5x012389AB4567CDEF), _mm256_extracti128_si256(vacc5x012389AB4567CDEF, 1)); + const __m128i vout6x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc6x012389AB4567CDEF), _mm256_extracti128_si256(vacc6x012389AB4567CDEF, 1)); - __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi32(vout0x0123456789AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); - __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi32(vout1x0123456789AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); - __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi32(vout2x0123456789AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); - __m128i vout3x0123456789ABCDEF = _mm_shuffle_epi32(vout3x0123456789AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); - __m128i vout4x0123456789ABCDEF = _mm_shuffle_epi32(vout4x0123456789AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); - __m128i vout5x0123456789ABCDEF = _mm_shuffle_epi32(vout5x0123456789AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); - __m128i vout6x0123456789ABCDEF = _mm_shuffle_epi32(vout6x0123456789AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi32(vout0x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi32(vout1x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi32(vout2x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout3x0123456789ABCDEF = _mm_shuffle_epi32(vout3x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout4x0123456789ABCDEF = _mm_shuffle_epi32(vout4x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout5x0123456789ABCDEF = _mm_shuffle_epi32(vout5x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout6x0123456789ABCDEF = _mm_shuffle_epi32(vout6x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); vout0x0123456789ABCDEF = _mm_max_epi8(vout0x0123456789ABCDEF, voutput_min); vout1x0123456789ABCDEF = _mm_max_epi8(vout1x0123456789ABCDEF, voutput_min); diff --git a/src/qs8-igemm/gen/qs8-igemm-7x16c8-minmax-fp32-avx512vnni-prfm.c b/src/qs8-igemm/gen/qs8-igemm-7x16c8-minmax-fp32-avx512vnni-prfm.c index dd7ed0dbf66..cc2c4cf909a 100644 --- a/src/qs8-igemm/gen/qs8-igemm-7x16c8-minmax-fp32-avx512vnni-prfm.c +++ b/src/qs8-igemm/gen/qs8-igemm-7x16c8-minmax-fp32-avx512vnni-prfm.c @@ -73,7 +73,6 @@ void xnn_qs8_igemm_minmax_fp32_ukernel_7x16c8__avx512vnni_prfm( const __m512 voutput_max_less_zero_point = _mm512_load_ps(params->fp32_avx512vnni.output_max_less_zero_point); const __m256i voutput_zero_point = _mm256_load_si256((const __m256i*) params->fp32_avx512vnni.output_zero_point); const __m128i voutput_min = _mm_load_si128((const __m128i*) params->fp32_avx512vnni.output_min); - const __m128i vshuffle_control_mask = _mm_loadu_si128((const __m128i*) params->fp32_avx512vnni.shuffle_control_mask); do { __m512i vacc0x01234567 = _mm512_cvtepu32_epi64(_mm256_load_si256((const __m256i*) w)); __m512i vacc0x89ABCDEF = _mm512_cvtepu32_epi64(_mm256_load_si256((const __m256i*) ((const int32_t*) w + 8))); @@ -288,37 +287,37 @@ void xnn_qs8_igemm_minmax_fp32_ukernel_7x16c8__avx512vnni_prfm( vacc5x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled5x0123456789ABCDEF); vacc6x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled6x0123456789ABCDEF); - __m256i vacc0x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); - __m256i vacc1x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); - __m256i vacc2x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc2x0123456789ABCDEF, 1)); - __m256i vacc3x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc3x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc3x0123456789ABCDEF, 1)); - __m256i vacc4x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc4x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc4x0123456789ABCDEF, 1)); - __m256i vacc5x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc5x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc5x0123456789ABCDEF, 1)); - __m256i vacc6x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc6x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc6x0123456789ABCDEF, 1)); + __m256i vacc0x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); + __m256i vacc1x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); + __m256i vacc2x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc2x0123456789ABCDEF, 1)); + __m256i vacc3x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc3x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc3x0123456789ABCDEF, 1)); + __m256i vacc4x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc4x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc4x0123456789ABCDEF, 1)); + __m256i vacc5x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc5x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc5x0123456789ABCDEF, 1)); + __m256i vacc6x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc6x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc6x0123456789ABCDEF, 1)); - vacc0x0123456789AB4567CDEF = _mm256_adds_epi16(vacc0x0123456789AB4567CDEF, voutput_zero_point); - vacc1x0123456789AB4567CDEF = _mm256_adds_epi16(vacc1x0123456789AB4567CDEF, voutput_zero_point); - vacc2x0123456789AB4567CDEF = _mm256_adds_epi16(vacc2x0123456789AB4567CDEF, voutput_zero_point); - vacc3x0123456789AB4567CDEF = _mm256_adds_epi16(vacc3x0123456789AB4567CDEF, voutput_zero_point); - vacc4x0123456789AB4567CDEF = _mm256_adds_epi16(vacc4x0123456789AB4567CDEF, voutput_zero_point); - vacc5x0123456789AB4567CDEF = _mm256_adds_epi16(vacc5x0123456789AB4567CDEF, voutput_zero_point); - vacc6x0123456789AB4567CDEF = _mm256_adds_epi16(vacc6x0123456789AB4567CDEF, voutput_zero_point); + vacc0x012389AB4567CDEF = _mm256_adds_epi16(vacc0x012389AB4567CDEF, voutput_zero_point); + vacc1x012389AB4567CDEF = _mm256_adds_epi16(vacc1x012389AB4567CDEF, voutput_zero_point); + vacc2x012389AB4567CDEF = _mm256_adds_epi16(vacc2x012389AB4567CDEF, voutput_zero_point); + vacc3x012389AB4567CDEF = _mm256_adds_epi16(vacc3x012389AB4567CDEF, voutput_zero_point); + vacc4x012389AB4567CDEF = _mm256_adds_epi16(vacc4x012389AB4567CDEF, voutput_zero_point); + vacc5x012389AB4567CDEF = _mm256_adds_epi16(vacc5x012389AB4567CDEF, voutput_zero_point); + vacc6x012389AB4567CDEF = _mm256_adds_epi16(vacc6x012389AB4567CDEF, voutput_zero_point); - const __m128i vout0x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc0x0123456789AB4567CDEF, 1)); - const __m128i vout1x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc1x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc1x0123456789AB4567CDEF, 1)); - const __m128i vout2x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc2x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc2x0123456789AB4567CDEF, 1)); - const __m128i vout3x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc3x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc3x0123456789AB4567CDEF, 1)); - const __m128i vout4x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc4x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc4x0123456789AB4567CDEF, 1)); - const __m128i vout5x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc5x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc5x0123456789AB4567CDEF, 1)); - const __m128i vout6x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc6x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc6x0123456789AB4567CDEF, 1)); + const __m128i vout0x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x012389AB4567CDEF, 1)); + const __m128i vout1x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc1x012389AB4567CDEF), _mm256_extracti128_si256(vacc1x012389AB4567CDEF, 1)); + const __m128i vout2x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc2x012389AB4567CDEF), _mm256_extracti128_si256(vacc2x012389AB4567CDEF, 1)); + const __m128i vout3x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc3x012389AB4567CDEF), _mm256_extracti128_si256(vacc3x012389AB4567CDEF, 1)); + const __m128i vout4x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc4x012389AB4567CDEF), _mm256_extracti128_si256(vacc4x012389AB4567CDEF, 1)); + const __m128i vout5x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc5x012389AB4567CDEF), _mm256_extracti128_si256(vacc5x012389AB4567CDEF, 1)); + const __m128i vout6x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc6x012389AB4567CDEF), _mm256_extracti128_si256(vacc6x012389AB4567CDEF, 1)); - __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi8(vout0x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi8(vout1x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi8(vout2x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout3x0123456789ABCDEF = _mm_shuffle_epi8(vout3x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout4x0123456789ABCDEF = _mm_shuffle_epi8(vout4x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout5x0123456789ABCDEF = _mm_shuffle_epi8(vout5x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout6x0123456789ABCDEF = _mm_shuffle_epi8(vout6x0123456789AB4567CDEF, vshuffle_control_mask); + __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi32(vout0x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi32(vout1x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi32(vout2x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout3x0123456789ABCDEF = _mm_shuffle_epi32(vout3x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout4x0123456789ABCDEF = _mm_shuffle_epi32(vout4x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout5x0123456789ABCDEF = _mm_shuffle_epi32(vout5x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout6x0123456789ABCDEF = _mm_shuffle_epi32(vout6x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); vout0x0123456789ABCDEF = _mm_max_epi8(vout0x0123456789ABCDEF, voutput_min); vout1x0123456789ABCDEF = _mm_max_epi8(vout1x0123456789ABCDEF, voutput_min); diff --git a/src/qs8-igemm/gen/qs8-igemm-7x16c8-minmax-fp32-avx512vnni.c b/src/qs8-igemm/gen/qs8-igemm-7x16c8-minmax-fp32-avx512vnni.c index 916fe9b8a50..304ebf978fd 100644 --- a/src/qs8-igemm/gen/qs8-igemm-7x16c8-minmax-fp32-avx512vnni.c +++ b/src/qs8-igemm/gen/qs8-igemm-7x16c8-minmax-fp32-avx512vnni.c @@ -72,7 +72,6 @@ void xnn_qs8_igemm_minmax_fp32_ukernel_7x16c8__avx512vnni( const __m512 voutput_max_less_zero_point = _mm512_load_ps(params->fp32_avx512vnni.output_max_less_zero_point); const __m256i voutput_zero_point = _mm256_load_si256((const __m256i*) params->fp32_avx512vnni.output_zero_point); const __m128i voutput_min = _mm_load_si128((const __m128i*) params->fp32_avx512vnni.output_min); - const __m128i vshuffle_control_mask = _mm_loadu_si128((const __m128i*) params->fp32_avx512vnni.shuffle_control_mask); do { __m512i vacc0x01234567 = _mm512_cvtepu32_epi64(_mm256_load_si256((const __m256i*) w)); __m512i vacc0x89ABCDEF = _mm512_cvtepu32_epi64(_mm256_load_si256((const __m256i*) ((const int32_t*) w + 8))); @@ -281,37 +280,37 @@ void xnn_qs8_igemm_minmax_fp32_ukernel_7x16c8__avx512vnni( vacc5x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled5x0123456789ABCDEF); vacc6x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled6x0123456789ABCDEF); - __m256i vacc0x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); - __m256i vacc1x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); - __m256i vacc2x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc2x0123456789ABCDEF, 1)); - __m256i vacc3x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc3x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc3x0123456789ABCDEF, 1)); - __m256i vacc4x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc4x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc4x0123456789ABCDEF, 1)); - __m256i vacc5x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc5x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc5x0123456789ABCDEF, 1)); - __m256i vacc6x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc6x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc6x0123456789ABCDEF, 1)); + __m256i vacc0x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); + __m256i vacc1x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); + __m256i vacc2x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc2x0123456789ABCDEF, 1)); + __m256i vacc3x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc3x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc3x0123456789ABCDEF, 1)); + __m256i vacc4x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc4x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc4x0123456789ABCDEF, 1)); + __m256i vacc5x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc5x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc5x0123456789ABCDEF, 1)); + __m256i vacc6x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc6x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc6x0123456789ABCDEF, 1)); - vacc0x0123456789AB4567CDEF = _mm256_adds_epi16(vacc0x0123456789AB4567CDEF, voutput_zero_point); - vacc1x0123456789AB4567CDEF = _mm256_adds_epi16(vacc1x0123456789AB4567CDEF, voutput_zero_point); - vacc2x0123456789AB4567CDEF = _mm256_adds_epi16(vacc2x0123456789AB4567CDEF, voutput_zero_point); - vacc3x0123456789AB4567CDEF = _mm256_adds_epi16(vacc3x0123456789AB4567CDEF, voutput_zero_point); - vacc4x0123456789AB4567CDEF = _mm256_adds_epi16(vacc4x0123456789AB4567CDEF, voutput_zero_point); - vacc5x0123456789AB4567CDEF = _mm256_adds_epi16(vacc5x0123456789AB4567CDEF, voutput_zero_point); - vacc6x0123456789AB4567CDEF = _mm256_adds_epi16(vacc6x0123456789AB4567CDEF, voutput_zero_point); + vacc0x012389AB4567CDEF = _mm256_adds_epi16(vacc0x012389AB4567CDEF, voutput_zero_point); + vacc1x012389AB4567CDEF = _mm256_adds_epi16(vacc1x012389AB4567CDEF, voutput_zero_point); + vacc2x012389AB4567CDEF = _mm256_adds_epi16(vacc2x012389AB4567CDEF, voutput_zero_point); + vacc3x012389AB4567CDEF = _mm256_adds_epi16(vacc3x012389AB4567CDEF, voutput_zero_point); + vacc4x012389AB4567CDEF = _mm256_adds_epi16(vacc4x012389AB4567CDEF, voutput_zero_point); + vacc5x012389AB4567CDEF = _mm256_adds_epi16(vacc5x012389AB4567CDEF, voutput_zero_point); + vacc6x012389AB4567CDEF = _mm256_adds_epi16(vacc6x012389AB4567CDEF, voutput_zero_point); - const __m128i vout0x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc0x0123456789AB4567CDEF, 1)); - const __m128i vout1x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc1x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc1x0123456789AB4567CDEF, 1)); - const __m128i vout2x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc2x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc2x0123456789AB4567CDEF, 1)); - const __m128i vout3x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc3x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc3x0123456789AB4567CDEF, 1)); - const __m128i vout4x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc4x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc4x0123456789AB4567CDEF, 1)); - const __m128i vout5x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc5x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc5x0123456789AB4567CDEF, 1)); - const __m128i vout6x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc6x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc6x0123456789AB4567CDEF, 1)); + const __m128i vout0x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x012389AB4567CDEF, 1)); + const __m128i vout1x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc1x012389AB4567CDEF), _mm256_extracti128_si256(vacc1x012389AB4567CDEF, 1)); + const __m128i vout2x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc2x012389AB4567CDEF), _mm256_extracti128_si256(vacc2x012389AB4567CDEF, 1)); + const __m128i vout3x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc3x012389AB4567CDEF), _mm256_extracti128_si256(vacc3x012389AB4567CDEF, 1)); + const __m128i vout4x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc4x012389AB4567CDEF), _mm256_extracti128_si256(vacc4x012389AB4567CDEF, 1)); + const __m128i vout5x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc5x012389AB4567CDEF), _mm256_extracti128_si256(vacc5x012389AB4567CDEF, 1)); + const __m128i vout6x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc6x012389AB4567CDEF), _mm256_extracti128_si256(vacc6x012389AB4567CDEF, 1)); - __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi8(vout0x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi8(vout1x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi8(vout2x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout3x0123456789ABCDEF = _mm_shuffle_epi8(vout3x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout4x0123456789ABCDEF = _mm_shuffle_epi8(vout4x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout5x0123456789ABCDEF = _mm_shuffle_epi8(vout5x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout6x0123456789ABCDEF = _mm_shuffle_epi8(vout6x0123456789AB4567CDEF, vshuffle_control_mask); + __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi32(vout0x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi32(vout1x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi32(vout2x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout3x0123456789ABCDEF = _mm_shuffle_epi32(vout3x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout4x0123456789ABCDEF = _mm_shuffle_epi32(vout4x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout5x0123456789ABCDEF = _mm_shuffle_epi32(vout5x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout6x0123456789ABCDEF = _mm_shuffle_epi32(vout6x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); vout0x0123456789ABCDEF = _mm_max_epi8(vout0x0123456789ABCDEF, voutput_min); vout1x0123456789ABCDEF = _mm_max_epi8(vout1x0123456789ABCDEF, voutput_min); diff --git a/src/qs8-igemm/gen/qs8-igemm-8x16c4-minmax-fp32-avx512vnni-prfm.c b/src/qs8-igemm/gen/qs8-igemm-8x16c4-minmax-fp32-avx512vnni-prfm.c index a3287ac9369..35d6e33fd78 100644 --- a/src/qs8-igemm/gen/qs8-igemm-8x16c4-minmax-fp32-avx512vnni-prfm.c +++ b/src/qs8-igemm/gen/qs8-igemm-8x16c4-minmax-fp32-avx512vnni-prfm.c @@ -77,7 +77,6 @@ void xnn_qs8_igemm_minmax_fp32_ukernel_8x16c4__avx512vnni_prfm( const __m512 voutput_max_less_zero_point = _mm512_load_ps(params->fp32_avx512vnni.output_max_less_zero_point); const __m256i voutput_zero_point = _mm256_load_si256((const __m256i*) params->fp32_avx512vnni.output_zero_point); const __m128i voutput_min = _mm_load_si128((const __m128i*) params->fp32_avx512vnni.output_min); - const __m128i vshuffle_control_mask = _mm_loadu_si128((const __m128i*) params->fp32_avx512vnni.shuffle_control_mask); do { __m512i vacc0x0123456789ABCDEF = _mm512_load_epi32(w); __m512i vacc1x0123456789ABCDEF = vacc0x0123456789ABCDEF; @@ -248,41 +247,41 @@ void xnn_qs8_igemm_minmax_fp32_ukernel_8x16c4__avx512vnni_prfm( vacc6x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled6x0123456789ABCDEF); vacc7x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled7x0123456789ABCDEF); - __m256i vacc0x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); - __m256i vacc1x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); - __m256i vacc2x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc2x0123456789ABCDEF, 1)); - __m256i vacc3x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc3x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc3x0123456789ABCDEF, 1)); - __m256i vacc4x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc4x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc4x0123456789ABCDEF, 1)); - __m256i vacc5x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc5x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc5x0123456789ABCDEF, 1)); - __m256i vacc6x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc6x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc6x0123456789ABCDEF, 1)); - __m256i vacc7x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc7x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc7x0123456789ABCDEF, 1)); + __m256i vacc0x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); + __m256i vacc1x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); + __m256i vacc2x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc2x0123456789ABCDEF, 1)); + __m256i vacc3x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc3x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc3x0123456789ABCDEF, 1)); + __m256i vacc4x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc4x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc4x0123456789ABCDEF, 1)); + __m256i vacc5x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc5x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc5x0123456789ABCDEF, 1)); + __m256i vacc6x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc6x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc6x0123456789ABCDEF, 1)); + __m256i vacc7x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc7x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc7x0123456789ABCDEF, 1)); - vacc0x0123456789AB4567CDEF = _mm256_adds_epi16(vacc0x0123456789AB4567CDEF, voutput_zero_point); - vacc1x0123456789AB4567CDEF = _mm256_adds_epi16(vacc1x0123456789AB4567CDEF, voutput_zero_point); - vacc2x0123456789AB4567CDEF = _mm256_adds_epi16(vacc2x0123456789AB4567CDEF, voutput_zero_point); - vacc3x0123456789AB4567CDEF = _mm256_adds_epi16(vacc3x0123456789AB4567CDEF, voutput_zero_point); - vacc4x0123456789AB4567CDEF = _mm256_adds_epi16(vacc4x0123456789AB4567CDEF, voutput_zero_point); - vacc5x0123456789AB4567CDEF = _mm256_adds_epi16(vacc5x0123456789AB4567CDEF, voutput_zero_point); - vacc6x0123456789AB4567CDEF = _mm256_adds_epi16(vacc6x0123456789AB4567CDEF, voutput_zero_point); - vacc7x0123456789AB4567CDEF = _mm256_adds_epi16(vacc7x0123456789AB4567CDEF, voutput_zero_point); + vacc0x012389AB4567CDEF = _mm256_adds_epi16(vacc0x012389AB4567CDEF, voutput_zero_point); + vacc1x012389AB4567CDEF = _mm256_adds_epi16(vacc1x012389AB4567CDEF, voutput_zero_point); + vacc2x012389AB4567CDEF = _mm256_adds_epi16(vacc2x012389AB4567CDEF, voutput_zero_point); + vacc3x012389AB4567CDEF = _mm256_adds_epi16(vacc3x012389AB4567CDEF, voutput_zero_point); + vacc4x012389AB4567CDEF = _mm256_adds_epi16(vacc4x012389AB4567CDEF, voutput_zero_point); + vacc5x012389AB4567CDEF = _mm256_adds_epi16(vacc5x012389AB4567CDEF, voutput_zero_point); + vacc6x012389AB4567CDEF = _mm256_adds_epi16(vacc6x012389AB4567CDEF, voutput_zero_point); + vacc7x012389AB4567CDEF = _mm256_adds_epi16(vacc7x012389AB4567CDEF, voutput_zero_point); - const __m128i vout0x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc0x0123456789AB4567CDEF, 1)); - const __m128i vout1x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc1x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc1x0123456789AB4567CDEF, 1)); - const __m128i vout2x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc2x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc2x0123456789AB4567CDEF, 1)); - const __m128i vout3x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc3x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc3x0123456789AB4567CDEF, 1)); - const __m128i vout4x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc4x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc4x0123456789AB4567CDEF, 1)); - const __m128i vout5x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc5x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc5x0123456789AB4567CDEF, 1)); - const __m128i vout6x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc6x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc6x0123456789AB4567CDEF, 1)); - const __m128i vout7x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc7x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc7x0123456789AB4567CDEF, 1)); + const __m128i vout0x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x012389AB4567CDEF, 1)); + const __m128i vout1x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc1x012389AB4567CDEF), _mm256_extracti128_si256(vacc1x012389AB4567CDEF, 1)); + const __m128i vout2x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc2x012389AB4567CDEF), _mm256_extracti128_si256(vacc2x012389AB4567CDEF, 1)); + const __m128i vout3x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc3x012389AB4567CDEF), _mm256_extracti128_si256(vacc3x012389AB4567CDEF, 1)); + const __m128i vout4x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc4x012389AB4567CDEF), _mm256_extracti128_si256(vacc4x012389AB4567CDEF, 1)); + const __m128i vout5x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc5x012389AB4567CDEF), _mm256_extracti128_si256(vacc5x012389AB4567CDEF, 1)); + const __m128i vout6x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc6x012389AB4567CDEF), _mm256_extracti128_si256(vacc6x012389AB4567CDEF, 1)); + const __m128i vout7x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc7x012389AB4567CDEF), _mm256_extracti128_si256(vacc7x012389AB4567CDEF, 1)); - __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi8(vout0x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi8(vout1x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi8(vout2x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout3x0123456789ABCDEF = _mm_shuffle_epi8(vout3x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout4x0123456789ABCDEF = _mm_shuffle_epi8(vout4x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout5x0123456789ABCDEF = _mm_shuffle_epi8(vout5x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout6x0123456789ABCDEF = _mm_shuffle_epi8(vout6x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout7x0123456789ABCDEF = _mm_shuffle_epi8(vout7x0123456789AB4567CDEF, vshuffle_control_mask); + __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi32(vout0x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi32(vout1x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi32(vout2x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout3x0123456789ABCDEF = _mm_shuffle_epi32(vout3x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout4x0123456789ABCDEF = _mm_shuffle_epi32(vout4x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout5x0123456789ABCDEF = _mm_shuffle_epi32(vout5x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout6x0123456789ABCDEF = _mm_shuffle_epi32(vout6x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout7x0123456789ABCDEF = _mm_shuffle_epi32(vout7x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); vout0x0123456789ABCDEF = _mm_max_epi8(vout0x0123456789ABCDEF, voutput_min); vout1x0123456789ABCDEF = _mm_max_epi8(vout1x0123456789ABCDEF, voutput_min); diff --git a/src/qs8-igemm/gen/qs8-igemm-8x16c4-minmax-fp32-avx512vnni.c b/src/qs8-igemm/gen/qs8-igemm-8x16c4-minmax-fp32-avx512vnni.c index 6656dac3563..f0d6bf0e04e 100644 --- a/src/qs8-igemm/gen/qs8-igemm-8x16c4-minmax-fp32-avx512vnni.c +++ b/src/qs8-igemm/gen/qs8-igemm-8x16c4-minmax-fp32-avx512vnni.c @@ -76,7 +76,6 @@ void xnn_qs8_igemm_minmax_fp32_ukernel_8x16c4__avx512vnni( const __m512 voutput_max_less_zero_point = _mm512_load_ps(params->fp32_avx512vnni.output_max_less_zero_point); const __m256i voutput_zero_point = _mm256_load_si256((const __m256i*) params->fp32_avx512vnni.output_zero_point); const __m128i voutput_min = _mm_load_si128((const __m128i*) params->fp32_avx512vnni.output_min); - const __m128i vshuffle_control_mask = _mm_loadu_si128((const __m128i*) params->fp32_avx512vnni.shuffle_control_mask); do { __m512i vacc0x0123456789ABCDEF = _mm512_load_epi32(w); __m512i vacc1x0123456789ABCDEF = vacc0x0123456789ABCDEF; @@ -244,41 +243,41 @@ void xnn_qs8_igemm_minmax_fp32_ukernel_8x16c4__avx512vnni( vacc6x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled6x0123456789ABCDEF); vacc7x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled7x0123456789ABCDEF); - __m256i vacc0x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); - __m256i vacc1x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); - __m256i vacc2x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc2x0123456789ABCDEF, 1)); - __m256i vacc3x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc3x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc3x0123456789ABCDEF, 1)); - __m256i vacc4x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc4x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc4x0123456789ABCDEF, 1)); - __m256i vacc5x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc5x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc5x0123456789ABCDEF, 1)); - __m256i vacc6x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc6x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc6x0123456789ABCDEF, 1)); - __m256i vacc7x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc7x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc7x0123456789ABCDEF, 1)); + __m256i vacc0x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); + __m256i vacc1x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); + __m256i vacc2x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc2x0123456789ABCDEF, 1)); + __m256i vacc3x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc3x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc3x0123456789ABCDEF, 1)); + __m256i vacc4x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc4x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc4x0123456789ABCDEF, 1)); + __m256i vacc5x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc5x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc5x0123456789ABCDEF, 1)); + __m256i vacc6x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc6x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc6x0123456789ABCDEF, 1)); + __m256i vacc7x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc7x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc7x0123456789ABCDEF, 1)); - vacc0x0123456789AB4567CDEF = _mm256_adds_epi16(vacc0x0123456789AB4567CDEF, voutput_zero_point); - vacc1x0123456789AB4567CDEF = _mm256_adds_epi16(vacc1x0123456789AB4567CDEF, voutput_zero_point); - vacc2x0123456789AB4567CDEF = _mm256_adds_epi16(vacc2x0123456789AB4567CDEF, voutput_zero_point); - vacc3x0123456789AB4567CDEF = _mm256_adds_epi16(vacc3x0123456789AB4567CDEF, voutput_zero_point); - vacc4x0123456789AB4567CDEF = _mm256_adds_epi16(vacc4x0123456789AB4567CDEF, voutput_zero_point); - vacc5x0123456789AB4567CDEF = _mm256_adds_epi16(vacc5x0123456789AB4567CDEF, voutput_zero_point); - vacc6x0123456789AB4567CDEF = _mm256_adds_epi16(vacc6x0123456789AB4567CDEF, voutput_zero_point); - vacc7x0123456789AB4567CDEF = _mm256_adds_epi16(vacc7x0123456789AB4567CDEF, voutput_zero_point); + vacc0x012389AB4567CDEF = _mm256_adds_epi16(vacc0x012389AB4567CDEF, voutput_zero_point); + vacc1x012389AB4567CDEF = _mm256_adds_epi16(vacc1x012389AB4567CDEF, voutput_zero_point); + vacc2x012389AB4567CDEF = _mm256_adds_epi16(vacc2x012389AB4567CDEF, voutput_zero_point); + vacc3x012389AB4567CDEF = _mm256_adds_epi16(vacc3x012389AB4567CDEF, voutput_zero_point); + vacc4x012389AB4567CDEF = _mm256_adds_epi16(vacc4x012389AB4567CDEF, voutput_zero_point); + vacc5x012389AB4567CDEF = _mm256_adds_epi16(vacc5x012389AB4567CDEF, voutput_zero_point); + vacc6x012389AB4567CDEF = _mm256_adds_epi16(vacc6x012389AB4567CDEF, voutput_zero_point); + vacc7x012389AB4567CDEF = _mm256_adds_epi16(vacc7x012389AB4567CDEF, voutput_zero_point); - const __m128i vout0x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc0x0123456789AB4567CDEF, 1)); - const __m128i vout1x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc1x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc1x0123456789AB4567CDEF, 1)); - const __m128i vout2x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc2x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc2x0123456789AB4567CDEF, 1)); - const __m128i vout3x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc3x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc3x0123456789AB4567CDEF, 1)); - const __m128i vout4x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc4x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc4x0123456789AB4567CDEF, 1)); - const __m128i vout5x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc5x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc5x0123456789AB4567CDEF, 1)); - const __m128i vout6x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc6x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc6x0123456789AB4567CDEF, 1)); - const __m128i vout7x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc7x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc7x0123456789AB4567CDEF, 1)); + const __m128i vout0x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x012389AB4567CDEF, 1)); + const __m128i vout1x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc1x012389AB4567CDEF), _mm256_extracti128_si256(vacc1x012389AB4567CDEF, 1)); + const __m128i vout2x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc2x012389AB4567CDEF), _mm256_extracti128_si256(vacc2x012389AB4567CDEF, 1)); + const __m128i vout3x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc3x012389AB4567CDEF), _mm256_extracti128_si256(vacc3x012389AB4567CDEF, 1)); + const __m128i vout4x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc4x012389AB4567CDEF), _mm256_extracti128_si256(vacc4x012389AB4567CDEF, 1)); + const __m128i vout5x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc5x012389AB4567CDEF), _mm256_extracti128_si256(vacc5x012389AB4567CDEF, 1)); + const __m128i vout6x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc6x012389AB4567CDEF), _mm256_extracti128_si256(vacc6x012389AB4567CDEF, 1)); + const __m128i vout7x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc7x012389AB4567CDEF), _mm256_extracti128_si256(vacc7x012389AB4567CDEF, 1)); - __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi8(vout0x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi8(vout1x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi8(vout2x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout3x0123456789ABCDEF = _mm_shuffle_epi8(vout3x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout4x0123456789ABCDEF = _mm_shuffle_epi8(vout4x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout5x0123456789ABCDEF = _mm_shuffle_epi8(vout5x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout6x0123456789ABCDEF = _mm_shuffle_epi8(vout6x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout7x0123456789ABCDEF = _mm_shuffle_epi8(vout7x0123456789AB4567CDEF, vshuffle_control_mask); + __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi32(vout0x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi32(vout1x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi32(vout2x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout3x0123456789ABCDEF = _mm_shuffle_epi32(vout3x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout4x0123456789ABCDEF = _mm_shuffle_epi32(vout4x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout5x0123456789ABCDEF = _mm_shuffle_epi32(vout5x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout6x0123456789ABCDEF = _mm_shuffle_epi32(vout6x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout7x0123456789ABCDEF = _mm_shuffle_epi32(vout7x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); vout0x0123456789ABCDEF = _mm_max_epi8(vout0x0123456789ABCDEF, voutput_min); vout1x0123456789ABCDEF = _mm_max_epi8(vout1x0123456789ABCDEF, voutput_min); diff --git a/src/qs8-igemm/gen/qs8-igemm-8x16c8-minmax-fp32-avx512skx-prfm.c b/src/qs8-igemm/gen/qs8-igemm-8x16c8-minmax-fp32-avx512skx-prfm.c index d56db358e4b..66854d2f0af 100644 --- a/src/qs8-igemm/gen/qs8-igemm-8x16c8-minmax-fp32-avx512skx-prfm.c +++ b/src/qs8-igemm/gen/qs8-igemm-8x16c8-minmax-fp32-avx512skx-prfm.c @@ -287,41 +287,41 @@ void xnn_qs8_igemm_minmax_fp32_ukernel_8x16c8__avx512skx_prfm( vacc6x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled6x0123456789ABCDEF); vacc7x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled7x0123456789ABCDEF); - __m256i vacc0x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); - __m256i vacc1x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); - __m256i vacc2x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc2x0123456789ABCDEF, 1)); - __m256i vacc3x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc3x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc3x0123456789ABCDEF, 1)); - __m256i vacc4x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc4x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc4x0123456789ABCDEF, 1)); - __m256i vacc5x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc5x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc5x0123456789ABCDEF, 1)); - __m256i vacc6x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc6x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc6x0123456789ABCDEF, 1)); - __m256i vacc7x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc7x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc7x0123456789ABCDEF, 1)); + __m256i vacc0x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); + __m256i vacc1x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); + __m256i vacc2x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc2x0123456789ABCDEF, 1)); + __m256i vacc3x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc3x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc3x0123456789ABCDEF, 1)); + __m256i vacc4x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc4x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc4x0123456789ABCDEF, 1)); + __m256i vacc5x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc5x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc5x0123456789ABCDEF, 1)); + __m256i vacc6x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc6x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc6x0123456789ABCDEF, 1)); + __m256i vacc7x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc7x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc7x0123456789ABCDEF, 1)); - vacc0x0123456789AB4567CDEF = _mm256_adds_epi16(vacc0x0123456789AB4567CDEF, voutput_zero_point); - vacc1x0123456789AB4567CDEF = _mm256_adds_epi16(vacc1x0123456789AB4567CDEF, voutput_zero_point); - vacc2x0123456789AB4567CDEF = _mm256_adds_epi16(vacc2x0123456789AB4567CDEF, voutput_zero_point); - vacc3x0123456789AB4567CDEF = _mm256_adds_epi16(vacc3x0123456789AB4567CDEF, voutput_zero_point); - vacc4x0123456789AB4567CDEF = _mm256_adds_epi16(vacc4x0123456789AB4567CDEF, voutput_zero_point); - vacc5x0123456789AB4567CDEF = _mm256_adds_epi16(vacc5x0123456789AB4567CDEF, voutput_zero_point); - vacc6x0123456789AB4567CDEF = _mm256_adds_epi16(vacc6x0123456789AB4567CDEF, voutput_zero_point); - vacc7x0123456789AB4567CDEF = _mm256_adds_epi16(vacc7x0123456789AB4567CDEF, voutput_zero_point); + vacc0x012389AB4567CDEF = _mm256_adds_epi16(vacc0x012389AB4567CDEF, voutput_zero_point); + vacc1x012389AB4567CDEF = _mm256_adds_epi16(vacc1x012389AB4567CDEF, voutput_zero_point); + vacc2x012389AB4567CDEF = _mm256_adds_epi16(vacc2x012389AB4567CDEF, voutput_zero_point); + vacc3x012389AB4567CDEF = _mm256_adds_epi16(vacc3x012389AB4567CDEF, voutput_zero_point); + vacc4x012389AB4567CDEF = _mm256_adds_epi16(vacc4x012389AB4567CDEF, voutput_zero_point); + vacc5x012389AB4567CDEF = _mm256_adds_epi16(vacc5x012389AB4567CDEF, voutput_zero_point); + vacc6x012389AB4567CDEF = _mm256_adds_epi16(vacc6x012389AB4567CDEF, voutput_zero_point); + vacc7x012389AB4567CDEF = _mm256_adds_epi16(vacc7x012389AB4567CDEF, voutput_zero_point); - const __m128i vout0x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc0x0123456789AB4567CDEF, 1)); - const __m128i vout1x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc1x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc1x0123456789AB4567CDEF, 1)); - const __m128i vout2x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc2x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc2x0123456789AB4567CDEF, 1)); - const __m128i vout3x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc3x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc3x0123456789AB4567CDEF, 1)); - const __m128i vout4x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc4x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc4x0123456789AB4567CDEF, 1)); - const __m128i vout5x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc5x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc5x0123456789AB4567CDEF, 1)); - const __m128i vout6x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc6x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc6x0123456789AB4567CDEF, 1)); - const __m128i vout7x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc7x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc7x0123456789AB4567CDEF, 1)); + const __m128i vout0x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x012389AB4567CDEF, 1)); + const __m128i vout1x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc1x012389AB4567CDEF), _mm256_extracti128_si256(vacc1x012389AB4567CDEF, 1)); + const __m128i vout2x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc2x012389AB4567CDEF), _mm256_extracti128_si256(vacc2x012389AB4567CDEF, 1)); + const __m128i vout3x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc3x012389AB4567CDEF), _mm256_extracti128_si256(vacc3x012389AB4567CDEF, 1)); + const __m128i vout4x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc4x012389AB4567CDEF), _mm256_extracti128_si256(vacc4x012389AB4567CDEF, 1)); + const __m128i vout5x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc5x012389AB4567CDEF), _mm256_extracti128_si256(vacc5x012389AB4567CDEF, 1)); + const __m128i vout6x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc6x012389AB4567CDEF), _mm256_extracti128_si256(vacc6x012389AB4567CDEF, 1)); + const __m128i vout7x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc7x012389AB4567CDEF), _mm256_extracti128_si256(vacc7x012389AB4567CDEF, 1)); - __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi32(vout0x0123456789AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); - __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi32(vout1x0123456789AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); - __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi32(vout2x0123456789AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); - __m128i vout3x0123456789ABCDEF = _mm_shuffle_epi32(vout3x0123456789AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); - __m128i vout4x0123456789ABCDEF = _mm_shuffle_epi32(vout4x0123456789AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); - __m128i vout5x0123456789ABCDEF = _mm_shuffle_epi32(vout5x0123456789AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); - __m128i vout6x0123456789ABCDEF = _mm_shuffle_epi32(vout6x0123456789AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); - __m128i vout7x0123456789ABCDEF = _mm_shuffle_epi32(vout7x0123456789AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi32(vout0x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi32(vout1x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi32(vout2x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout3x0123456789ABCDEF = _mm_shuffle_epi32(vout3x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout4x0123456789ABCDEF = _mm_shuffle_epi32(vout4x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout5x0123456789ABCDEF = _mm_shuffle_epi32(vout5x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout6x0123456789ABCDEF = _mm_shuffle_epi32(vout6x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout7x0123456789ABCDEF = _mm_shuffle_epi32(vout7x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); vout0x0123456789ABCDEF = _mm_max_epi8(vout0x0123456789ABCDEF, voutput_min); vout1x0123456789ABCDEF = _mm_max_epi8(vout1x0123456789ABCDEF, voutput_min); diff --git a/src/qs8-igemm/gen/qs8-igemm-8x16c8-minmax-fp32-avx512skx.c b/src/qs8-igemm/gen/qs8-igemm-8x16c8-minmax-fp32-avx512skx.c index 2c7b0b4c115..68a5de91678 100644 --- a/src/qs8-igemm/gen/qs8-igemm-8x16c8-minmax-fp32-avx512skx.c +++ b/src/qs8-igemm/gen/qs8-igemm-8x16c8-minmax-fp32-avx512skx.c @@ -284,41 +284,41 @@ void xnn_qs8_igemm_minmax_fp32_ukernel_8x16c8__avx512skx( vacc6x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled6x0123456789ABCDEF); vacc7x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled7x0123456789ABCDEF); - __m256i vacc0x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); - __m256i vacc1x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); - __m256i vacc2x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc2x0123456789ABCDEF, 1)); - __m256i vacc3x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc3x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc3x0123456789ABCDEF, 1)); - __m256i vacc4x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc4x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc4x0123456789ABCDEF, 1)); - __m256i vacc5x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc5x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc5x0123456789ABCDEF, 1)); - __m256i vacc6x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc6x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc6x0123456789ABCDEF, 1)); - __m256i vacc7x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc7x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc7x0123456789ABCDEF, 1)); + __m256i vacc0x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); + __m256i vacc1x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); + __m256i vacc2x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc2x0123456789ABCDEF, 1)); + __m256i vacc3x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc3x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc3x0123456789ABCDEF, 1)); + __m256i vacc4x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc4x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc4x0123456789ABCDEF, 1)); + __m256i vacc5x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc5x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc5x0123456789ABCDEF, 1)); + __m256i vacc6x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc6x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc6x0123456789ABCDEF, 1)); + __m256i vacc7x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc7x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc7x0123456789ABCDEF, 1)); - vacc0x0123456789AB4567CDEF = _mm256_adds_epi16(vacc0x0123456789AB4567CDEF, voutput_zero_point); - vacc1x0123456789AB4567CDEF = _mm256_adds_epi16(vacc1x0123456789AB4567CDEF, voutput_zero_point); - vacc2x0123456789AB4567CDEF = _mm256_adds_epi16(vacc2x0123456789AB4567CDEF, voutput_zero_point); - vacc3x0123456789AB4567CDEF = _mm256_adds_epi16(vacc3x0123456789AB4567CDEF, voutput_zero_point); - vacc4x0123456789AB4567CDEF = _mm256_adds_epi16(vacc4x0123456789AB4567CDEF, voutput_zero_point); - vacc5x0123456789AB4567CDEF = _mm256_adds_epi16(vacc5x0123456789AB4567CDEF, voutput_zero_point); - vacc6x0123456789AB4567CDEF = _mm256_adds_epi16(vacc6x0123456789AB4567CDEF, voutput_zero_point); - vacc7x0123456789AB4567CDEF = _mm256_adds_epi16(vacc7x0123456789AB4567CDEF, voutput_zero_point); + vacc0x012389AB4567CDEF = _mm256_adds_epi16(vacc0x012389AB4567CDEF, voutput_zero_point); + vacc1x012389AB4567CDEF = _mm256_adds_epi16(vacc1x012389AB4567CDEF, voutput_zero_point); + vacc2x012389AB4567CDEF = _mm256_adds_epi16(vacc2x012389AB4567CDEF, voutput_zero_point); + vacc3x012389AB4567CDEF = _mm256_adds_epi16(vacc3x012389AB4567CDEF, voutput_zero_point); + vacc4x012389AB4567CDEF = _mm256_adds_epi16(vacc4x012389AB4567CDEF, voutput_zero_point); + vacc5x012389AB4567CDEF = _mm256_adds_epi16(vacc5x012389AB4567CDEF, voutput_zero_point); + vacc6x012389AB4567CDEF = _mm256_adds_epi16(vacc6x012389AB4567CDEF, voutput_zero_point); + vacc7x012389AB4567CDEF = _mm256_adds_epi16(vacc7x012389AB4567CDEF, voutput_zero_point); - const __m128i vout0x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc0x0123456789AB4567CDEF, 1)); - const __m128i vout1x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc1x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc1x0123456789AB4567CDEF, 1)); - const __m128i vout2x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc2x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc2x0123456789AB4567CDEF, 1)); - const __m128i vout3x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc3x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc3x0123456789AB4567CDEF, 1)); - const __m128i vout4x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc4x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc4x0123456789AB4567CDEF, 1)); - const __m128i vout5x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc5x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc5x0123456789AB4567CDEF, 1)); - const __m128i vout6x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc6x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc6x0123456789AB4567CDEF, 1)); - const __m128i vout7x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc7x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc7x0123456789AB4567CDEF, 1)); + const __m128i vout0x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x012389AB4567CDEF, 1)); + const __m128i vout1x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc1x012389AB4567CDEF), _mm256_extracti128_si256(vacc1x012389AB4567CDEF, 1)); + const __m128i vout2x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc2x012389AB4567CDEF), _mm256_extracti128_si256(vacc2x012389AB4567CDEF, 1)); + const __m128i vout3x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc3x012389AB4567CDEF), _mm256_extracti128_si256(vacc3x012389AB4567CDEF, 1)); + const __m128i vout4x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc4x012389AB4567CDEF), _mm256_extracti128_si256(vacc4x012389AB4567CDEF, 1)); + const __m128i vout5x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc5x012389AB4567CDEF), _mm256_extracti128_si256(vacc5x012389AB4567CDEF, 1)); + const __m128i vout6x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc6x012389AB4567CDEF), _mm256_extracti128_si256(vacc6x012389AB4567CDEF, 1)); + const __m128i vout7x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc7x012389AB4567CDEF), _mm256_extracti128_si256(vacc7x012389AB4567CDEF, 1)); - __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi32(vout0x0123456789AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); - __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi32(vout1x0123456789AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); - __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi32(vout2x0123456789AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); - __m128i vout3x0123456789ABCDEF = _mm_shuffle_epi32(vout3x0123456789AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); - __m128i vout4x0123456789ABCDEF = _mm_shuffle_epi32(vout4x0123456789AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); - __m128i vout5x0123456789ABCDEF = _mm_shuffle_epi32(vout5x0123456789AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); - __m128i vout6x0123456789ABCDEF = _mm_shuffle_epi32(vout6x0123456789AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); - __m128i vout7x0123456789ABCDEF = _mm_shuffle_epi32(vout7x0123456789AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi32(vout0x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi32(vout1x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi32(vout2x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout3x0123456789ABCDEF = _mm_shuffle_epi32(vout3x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout4x0123456789ABCDEF = _mm_shuffle_epi32(vout4x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout5x0123456789ABCDEF = _mm_shuffle_epi32(vout5x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout6x0123456789ABCDEF = _mm_shuffle_epi32(vout6x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout7x0123456789ABCDEF = _mm_shuffle_epi32(vout7x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); vout0x0123456789ABCDEF = _mm_max_epi8(vout0x0123456789ABCDEF, voutput_min); vout1x0123456789ABCDEF = _mm_max_epi8(vout1x0123456789ABCDEF, voutput_min); diff --git a/src/qs8-igemm/gen/qs8-igemm-8x16c8-minmax-fp32-avx512vnni-prfm.c b/src/qs8-igemm/gen/qs8-igemm-8x16c8-minmax-fp32-avx512vnni-prfm.c index 5d1c045bdf9..c11b6d4a26f 100644 --- a/src/qs8-igemm/gen/qs8-igemm-8x16c8-minmax-fp32-avx512vnni-prfm.c +++ b/src/qs8-igemm/gen/qs8-igemm-8x16c8-minmax-fp32-avx512vnni-prfm.c @@ -77,7 +77,6 @@ void xnn_qs8_igemm_minmax_fp32_ukernel_8x16c8__avx512vnni_prfm( const __m512 voutput_max_less_zero_point = _mm512_load_ps(params->fp32_avx512vnni.output_max_less_zero_point); const __m256i voutput_zero_point = _mm256_load_si256((const __m256i*) params->fp32_avx512vnni.output_zero_point); const __m128i voutput_min = _mm_load_si128((const __m128i*) params->fp32_avx512vnni.output_min); - const __m128i vshuffle_control_mask = _mm_loadu_si128((const __m128i*) params->fp32_avx512vnni.shuffle_control_mask); do { __m512i vacc0x01234567 = _mm512_cvtepu32_epi64(_mm256_load_si256((const __m256i*) w)); __m512i vacc0x89ABCDEF = _mm512_cvtepu32_epi64(_mm256_load_si256((const __m256i*) ((const int32_t*) w + 8))); @@ -316,41 +315,41 @@ void xnn_qs8_igemm_minmax_fp32_ukernel_8x16c8__avx512vnni_prfm( vacc6x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled6x0123456789ABCDEF); vacc7x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled7x0123456789ABCDEF); - __m256i vacc0x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); - __m256i vacc1x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); - __m256i vacc2x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc2x0123456789ABCDEF, 1)); - __m256i vacc3x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc3x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc3x0123456789ABCDEF, 1)); - __m256i vacc4x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc4x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc4x0123456789ABCDEF, 1)); - __m256i vacc5x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc5x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc5x0123456789ABCDEF, 1)); - __m256i vacc6x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc6x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc6x0123456789ABCDEF, 1)); - __m256i vacc7x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc7x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc7x0123456789ABCDEF, 1)); + __m256i vacc0x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); + __m256i vacc1x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); + __m256i vacc2x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc2x0123456789ABCDEF, 1)); + __m256i vacc3x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc3x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc3x0123456789ABCDEF, 1)); + __m256i vacc4x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc4x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc4x0123456789ABCDEF, 1)); + __m256i vacc5x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc5x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc5x0123456789ABCDEF, 1)); + __m256i vacc6x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc6x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc6x0123456789ABCDEF, 1)); + __m256i vacc7x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc7x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc7x0123456789ABCDEF, 1)); - vacc0x0123456789AB4567CDEF = _mm256_adds_epi16(vacc0x0123456789AB4567CDEF, voutput_zero_point); - vacc1x0123456789AB4567CDEF = _mm256_adds_epi16(vacc1x0123456789AB4567CDEF, voutput_zero_point); - vacc2x0123456789AB4567CDEF = _mm256_adds_epi16(vacc2x0123456789AB4567CDEF, voutput_zero_point); - vacc3x0123456789AB4567CDEF = _mm256_adds_epi16(vacc3x0123456789AB4567CDEF, voutput_zero_point); - vacc4x0123456789AB4567CDEF = _mm256_adds_epi16(vacc4x0123456789AB4567CDEF, voutput_zero_point); - vacc5x0123456789AB4567CDEF = _mm256_adds_epi16(vacc5x0123456789AB4567CDEF, voutput_zero_point); - vacc6x0123456789AB4567CDEF = _mm256_adds_epi16(vacc6x0123456789AB4567CDEF, voutput_zero_point); - vacc7x0123456789AB4567CDEF = _mm256_adds_epi16(vacc7x0123456789AB4567CDEF, voutput_zero_point); + vacc0x012389AB4567CDEF = _mm256_adds_epi16(vacc0x012389AB4567CDEF, voutput_zero_point); + vacc1x012389AB4567CDEF = _mm256_adds_epi16(vacc1x012389AB4567CDEF, voutput_zero_point); + vacc2x012389AB4567CDEF = _mm256_adds_epi16(vacc2x012389AB4567CDEF, voutput_zero_point); + vacc3x012389AB4567CDEF = _mm256_adds_epi16(vacc3x012389AB4567CDEF, voutput_zero_point); + vacc4x012389AB4567CDEF = _mm256_adds_epi16(vacc4x012389AB4567CDEF, voutput_zero_point); + vacc5x012389AB4567CDEF = _mm256_adds_epi16(vacc5x012389AB4567CDEF, voutput_zero_point); + vacc6x012389AB4567CDEF = _mm256_adds_epi16(vacc6x012389AB4567CDEF, voutput_zero_point); + vacc7x012389AB4567CDEF = _mm256_adds_epi16(vacc7x012389AB4567CDEF, voutput_zero_point); - const __m128i vout0x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc0x0123456789AB4567CDEF, 1)); - const __m128i vout1x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc1x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc1x0123456789AB4567CDEF, 1)); - const __m128i vout2x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc2x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc2x0123456789AB4567CDEF, 1)); - const __m128i vout3x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc3x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc3x0123456789AB4567CDEF, 1)); - const __m128i vout4x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc4x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc4x0123456789AB4567CDEF, 1)); - const __m128i vout5x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc5x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc5x0123456789AB4567CDEF, 1)); - const __m128i vout6x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc6x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc6x0123456789AB4567CDEF, 1)); - const __m128i vout7x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc7x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc7x0123456789AB4567CDEF, 1)); + const __m128i vout0x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x012389AB4567CDEF, 1)); + const __m128i vout1x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc1x012389AB4567CDEF), _mm256_extracti128_si256(vacc1x012389AB4567CDEF, 1)); + const __m128i vout2x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc2x012389AB4567CDEF), _mm256_extracti128_si256(vacc2x012389AB4567CDEF, 1)); + const __m128i vout3x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc3x012389AB4567CDEF), _mm256_extracti128_si256(vacc3x012389AB4567CDEF, 1)); + const __m128i vout4x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc4x012389AB4567CDEF), _mm256_extracti128_si256(vacc4x012389AB4567CDEF, 1)); + const __m128i vout5x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc5x012389AB4567CDEF), _mm256_extracti128_si256(vacc5x012389AB4567CDEF, 1)); + const __m128i vout6x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc6x012389AB4567CDEF), _mm256_extracti128_si256(vacc6x012389AB4567CDEF, 1)); + const __m128i vout7x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc7x012389AB4567CDEF), _mm256_extracti128_si256(vacc7x012389AB4567CDEF, 1)); - __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi8(vout0x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi8(vout1x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi8(vout2x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout3x0123456789ABCDEF = _mm_shuffle_epi8(vout3x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout4x0123456789ABCDEF = _mm_shuffle_epi8(vout4x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout5x0123456789ABCDEF = _mm_shuffle_epi8(vout5x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout6x0123456789ABCDEF = _mm_shuffle_epi8(vout6x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout7x0123456789ABCDEF = _mm_shuffle_epi8(vout7x0123456789AB4567CDEF, vshuffle_control_mask); + __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi32(vout0x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi32(vout1x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi32(vout2x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout3x0123456789ABCDEF = _mm_shuffle_epi32(vout3x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout4x0123456789ABCDEF = _mm_shuffle_epi32(vout4x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout5x0123456789ABCDEF = _mm_shuffle_epi32(vout5x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout6x0123456789ABCDEF = _mm_shuffle_epi32(vout6x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout7x0123456789ABCDEF = _mm_shuffle_epi32(vout7x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); vout0x0123456789ABCDEF = _mm_max_epi8(vout0x0123456789ABCDEF, voutput_min); vout1x0123456789ABCDEF = _mm_max_epi8(vout1x0123456789ABCDEF, voutput_min); diff --git a/src/qs8-igemm/gen/qs8-igemm-8x16c8-minmax-fp32-avx512vnni.c b/src/qs8-igemm/gen/qs8-igemm-8x16c8-minmax-fp32-avx512vnni.c index a4db399caa7..acf32549ee0 100644 --- a/src/qs8-igemm/gen/qs8-igemm-8x16c8-minmax-fp32-avx512vnni.c +++ b/src/qs8-igemm/gen/qs8-igemm-8x16c8-minmax-fp32-avx512vnni.c @@ -76,7 +76,6 @@ void xnn_qs8_igemm_minmax_fp32_ukernel_8x16c8__avx512vnni( const __m512 voutput_max_less_zero_point = _mm512_load_ps(params->fp32_avx512vnni.output_max_less_zero_point); const __m256i voutput_zero_point = _mm256_load_si256((const __m256i*) params->fp32_avx512vnni.output_zero_point); const __m128i voutput_min = _mm_load_si128((const __m128i*) params->fp32_avx512vnni.output_min); - const __m128i vshuffle_control_mask = _mm_loadu_si128((const __m128i*) params->fp32_avx512vnni.shuffle_control_mask); do { __m512i vacc0x01234567 = _mm512_cvtepu32_epi64(_mm256_load_si256((const __m256i*) w)); __m512i vacc0x89ABCDEF = _mm512_cvtepu32_epi64(_mm256_load_si256((const __m256i*) ((const int32_t*) w + 8))); @@ -309,41 +308,41 @@ void xnn_qs8_igemm_minmax_fp32_ukernel_8x16c8__avx512vnni( vacc6x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled6x0123456789ABCDEF); vacc7x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled7x0123456789ABCDEF); - __m256i vacc0x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); - __m256i vacc1x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); - __m256i vacc2x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc2x0123456789ABCDEF, 1)); - __m256i vacc3x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc3x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc3x0123456789ABCDEF, 1)); - __m256i vacc4x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc4x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc4x0123456789ABCDEF, 1)); - __m256i vacc5x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc5x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc5x0123456789ABCDEF, 1)); - __m256i vacc6x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc6x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc6x0123456789ABCDEF, 1)); - __m256i vacc7x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc7x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc7x0123456789ABCDEF, 1)); + __m256i vacc0x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); + __m256i vacc1x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); + __m256i vacc2x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc2x0123456789ABCDEF, 1)); + __m256i vacc3x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc3x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc3x0123456789ABCDEF, 1)); + __m256i vacc4x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc4x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc4x0123456789ABCDEF, 1)); + __m256i vacc5x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc5x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc5x0123456789ABCDEF, 1)); + __m256i vacc6x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc6x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc6x0123456789ABCDEF, 1)); + __m256i vacc7x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc7x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc7x0123456789ABCDEF, 1)); - vacc0x0123456789AB4567CDEF = _mm256_adds_epi16(vacc0x0123456789AB4567CDEF, voutput_zero_point); - vacc1x0123456789AB4567CDEF = _mm256_adds_epi16(vacc1x0123456789AB4567CDEF, voutput_zero_point); - vacc2x0123456789AB4567CDEF = _mm256_adds_epi16(vacc2x0123456789AB4567CDEF, voutput_zero_point); - vacc3x0123456789AB4567CDEF = _mm256_adds_epi16(vacc3x0123456789AB4567CDEF, voutput_zero_point); - vacc4x0123456789AB4567CDEF = _mm256_adds_epi16(vacc4x0123456789AB4567CDEF, voutput_zero_point); - vacc5x0123456789AB4567CDEF = _mm256_adds_epi16(vacc5x0123456789AB4567CDEF, voutput_zero_point); - vacc6x0123456789AB4567CDEF = _mm256_adds_epi16(vacc6x0123456789AB4567CDEF, voutput_zero_point); - vacc7x0123456789AB4567CDEF = _mm256_adds_epi16(vacc7x0123456789AB4567CDEF, voutput_zero_point); + vacc0x012389AB4567CDEF = _mm256_adds_epi16(vacc0x012389AB4567CDEF, voutput_zero_point); + vacc1x012389AB4567CDEF = _mm256_adds_epi16(vacc1x012389AB4567CDEF, voutput_zero_point); + vacc2x012389AB4567CDEF = _mm256_adds_epi16(vacc2x012389AB4567CDEF, voutput_zero_point); + vacc3x012389AB4567CDEF = _mm256_adds_epi16(vacc3x012389AB4567CDEF, voutput_zero_point); + vacc4x012389AB4567CDEF = _mm256_adds_epi16(vacc4x012389AB4567CDEF, voutput_zero_point); + vacc5x012389AB4567CDEF = _mm256_adds_epi16(vacc5x012389AB4567CDEF, voutput_zero_point); + vacc6x012389AB4567CDEF = _mm256_adds_epi16(vacc6x012389AB4567CDEF, voutput_zero_point); + vacc7x012389AB4567CDEF = _mm256_adds_epi16(vacc7x012389AB4567CDEF, voutput_zero_point); - const __m128i vout0x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc0x0123456789AB4567CDEF, 1)); - const __m128i vout1x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc1x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc1x0123456789AB4567CDEF, 1)); - const __m128i vout2x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc2x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc2x0123456789AB4567CDEF, 1)); - const __m128i vout3x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc3x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc3x0123456789AB4567CDEF, 1)); - const __m128i vout4x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc4x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc4x0123456789AB4567CDEF, 1)); - const __m128i vout5x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc5x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc5x0123456789AB4567CDEF, 1)); - const __m128i vout6x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc6x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc6x0123456789AB4567CDEF, 1)); - const __m128i vout7x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc7x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc7x0123456789AB4567CDEF, 1)); + const __m128i vout0x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x012389AB4567CDEF, 1)); + const __m128i vout1x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc1x012389AB4567CDEF), _mm256_extracti128_si256(vacc1x012389AB4567CDEF, 1)); + const __m128i vout2x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc2x012389AB4567CDEF), _mm256_extracti128_si256(vacc2x012389AB4567CDEF, 1)); + const __m128i vout3x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc3x012389AB4567CDEF), _mm256_extracti128_si256(vacc3x012389AB4567CDEF, 1)); + const __m128i vout4x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc4x012389AB4567CDEF), _mm256_extracti128_si256(vacc4x012389AB4567CDEF, 1)); + const __m128i vout5x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc5x012389AB4567CDEF), _mm256_extracti128_si256(vacc5x012389AB4567CDEF, 1)); + const __m128i vout6x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc6x012389AB4567CDEF), _mm256_extracti128_si256(vacc6x012389AB4567CDEF, 1)); + const __m128i vout7x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc7x012389AB4567CDEF), _mm256_extracti128_si256(vacc7x012389AB4567CDEF, 1)); - __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi8(vout0x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi8(vout1x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi8(vout2x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout3x0123456789ABCDEF = _mm_shuffle_epi8(vout3x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout4x0123456789ABCDEF = _mm_shuffle_epi8(vout4x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout5x0123456789ABCDEF = _mm_shuffle_epi8(vout5x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout6x0123456789ABCDEF = _mm_shuffle_epi8(vout6x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout7x0123456789ABCDEF = _mm_shuffle_epi8(vout7x0123456789AB4567CDEF, vshuffle_control_mask); + __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi32(vout0x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi32(vout1x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi32(vout2x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout3x0123456789ABCDEF = _mm_shuffle_epi32(vout3x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout4x0123456789ABCDEF = _mm_shuffle_epi32(vout4x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout5x0123456789ABCDEF = _mm_shuffle_epi32(vout5x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout6x0123456789ABCDEF = _mm_shuffle_epi32(vout6x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout7x0123456789ABCDEF = _mm_shuffle_epi32(vout7x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); vout0x0123456789ABCDEF = _mm_max_epi8(vout0x0123456789ABCDEF, voutput_min); vout1x0123456789ABCDEF = _mm_max_epi8(vout1x0123456789ABCDEF, voutput_min); diff --git a/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-1x16c4-minmax-fp32-avx512vnni-prfm.c b/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-1x16c4-minmax-fp32-avx512vnni-prfm.c index 70df9572a8b..5b0d39b97f6 100644 --- a/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-1x16c4-minmax-fp32-avx512vnni-prfm.c +++ b/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-1x16c4-minmax-fp32-avx512vnni-prfm.c @@ -47,7 +47,6 @@ void xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_1x16c4__avx512vnni_prfm( const __m512 voutput_max_less_zero_point = _mm512_load_ps(params->fp32_avx512vnni.output_max_less_zero_point); const __m256i voutput_zero_point = _mm256_load_si256((const __m256i*) params->fp32_avx512vnni.output_zero_point); const __m128i voutput_min = _mm_load_si128((const __m128i*) params->fp32_avx512vnni.output_min); - const __m128i vshuffle_control_mask = _mm_loadu_si128((const __m128i*) params->fp32_avx512vnni.shuffle_control_mask); do { __m512i vacc0x0x0123456789ABCDEF = _mm512_load_epi32(w); __m512i vacc1x0x0123456789ABCDEF = _mm512_setzero_epi32(); @@ -70,7 +69,7 @@ void xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_1x16c4__avx512vnni_prfm( w = (const int8_t*) w + 128; k -= 8 * sizeof(int8_t); } - vacc0x0x0123456789ABCDEF = _mm512_add_epi32(vacc0x0x0123456789ABCDEF, vacc1x0x0123456789ABCDEF); + __m512i vacc0x0123456789ABCDEF = _mm512_add_epi32(vacc0x0x0123456789ABCDEF, vacc1x0x0123456789ABCDEF); if (k != 0) { const __m512i va0x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a0)), vsign_mask); @@ -79,13 +78,13 @@ void xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_1x16c4__avx512vnni_prfm( const __m512i vb0123456789ABCDEF = _mm512_load_si512(w); xnn_prefetch_to_l1((const int8_t*) w + 960); - vacc0x0x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x0x0123456789ABCDEF, va0x0123, vb0123456789ABCDEF); + vacc0x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x0123456789ABCDEF, va0x0123, vb0123456789ABCDEF); w = (const int8_t*) w + 64; k -= 4 * sizeof(int8_t); } - __m512 vscaled0x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x0x0123456789ABCDEF); + __m512 vscaled0x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x0123456789ABCDEF); const __m512 vscale012345678ABCDEF = _mm512_load_ps(w); w = (const float*) w + 16; @@ -93,24 +92,22 @@ void xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_1x16c4__avx512vnni_prfm( vscaled0x0123456789ABCDEF = _mm512_min_ps(vscaled0x0123456789ABCDEF, voutput_max_less_zero_point); - vacc0x0x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled0x0123456789ABCDEF); + vacc0x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled0x0123456789ABCDEF); - __m256i vacc0x0x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0x0123456789ABCDEF, 1)); + __m256i vacc0x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); - vacc0x0x012389AB4567CDEF = _mm256_adds_epi16(vacc0x0x012389AB4567CDEF, voutput_zero_point); + vacc0x012389AB4567CDEF = _mm256_adds_epi16(vacc0x012389AB4567CDEF, voutput_zero_point); - const __m128i vout0x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x0x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x0x012389AB4567CDEF, 1)); + const __m128i vout0x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x012389AB4567CDEF, 1)); - __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi8(vout0x012389AB4567CDEF, vshuffle_control_mask); + __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi32(vout0x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); vout0x0123456789ABCDEF = _mm_max_epi8(vout0x0123456789ABCDEF, voutput_min); if (nc >= 16) { _mm_storeu_si128((__m128i*) c0, vout0x0123456789ABCDEF); - - a0 = (const int8_t*) ((uintptr_t) a0 - kc); - c0 = (int8_t*) ((uintptr_t) c0 + cn_stride); + a0 = (const int8_t*) ((uintptr_t) a0 - kc); nc -= 16; } else { diff --git a/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-1x16c4-minmax-fp32-avx512vnni.c b/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-1x16c4-minmax-fp32-avx512vnni.c index faec11a1cab..27887c0193c 100644 --- a/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-1x16c4-minmax-fp32-avx512vnni.c +++ b/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-1x16c4-minmax-fp32-avx512vnni.c @@ -46,7 +46,6 @@ void xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_1x16c4__avx512vnni( const __m512 voutput_max_less_zero_point = _mm512_load_ps(params->fp32_avx512vnni.output_max_less_zero_point); const __m256i voutput_zero_point = _mm256_load_si256((const __m256i*) params->fp32_avx512vnni.output_zero_point); const __m128i voutput_min = _mm_load_si128((const __m128i*) params->fp32_avx512vnni.output_min); - const __m128i vshuffle_control_mask = _mm_loadu_si128((const __m128i*) params->fp32_avx512vnni.shuffle_control_mask); do { __m512i vacc0x0x0123456789ABCDEF = _mm512_load_epi32(w); __m512i vacc1x0x0123456789ABCDEF = _mm512_setzero_epi32(); @@ -67,7 +66,7 @@ void xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_1x16c4__avx512vnni( w = (const int8_t*) w + 128; k -= 8 * sizeof(int8_t); } - vacc0x0x0123456789ABCDEF = _mm512_add_epi32(vacc0x0x0123456789ABCDEF, vacc1x0x0123456789ABCDEF); + __m512i vacc0x0123456789ABCDEF = _mm512_add_epi32(vacc0x0x0123456789ABCDEF, vacc1x0x0123456789ABCDEF); if (k != 0) { const __m512i va0x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a0)), vsign_mask); @@ -75,13 +74,13 @@ void xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_1x16c4__avx512vnni( const __m512i vb0123456789ABCDEF = _mm512_load_si512(w); - vacc0x0x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x0x0123456789ABCDEF, va0x0123, vb0123456789ABCDEF); + vacc0x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x0123456789ABCDEF, va0x0123, vb0123456789ABCDEF); w = (const int8_t*) w + 64; k -= 4 * sizeof(int8_t); } - __m512 vscaled0x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x0x0123456789ABCDEF); + __m512 vscaled0x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x0123456789ABCDEF); const __m512 vscale012345678ABCDEF = _mm512_load_ps(w); w = (const float*) w + 16; @@ -89,24 +88,22 @@ void xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_1x16c4__avx512vnni( vscaled0x0123456789ABCDEF = _mm512_min_ps(vscaled0x0123456789ABCDEF, voutput_max_less_zero_point); - vacc0x0x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled0x0123456789ABCDEF); + vacc0x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled0x0123456789ABCDEF); - __m256i vacc0x0x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0x0123456789ABCDEF, 1)); + __m256i vacc0x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); - vacc0x0x012389AB4567CDEF = _mm256_adds_epi16(vacc0x0x012389AB4567CDEF, voutput_zero_point); + vacc0x012389AB4567CDEF = _mm256_adds_epi16(vacc0x012389AB4567CDEF, voutput_zero_point); - const __m128i vout0x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x0x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x0x012389AB4567CDEF, 1)); + const __m128i vout0x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x012389AB4567CDEF, 1)); - __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi8(vout0x012389AB4567CDEF, vshuffle_control_mask); + __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi32(vout0x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); vout0x0123456789ABCDEF = _mm_max_epi8(vout0x0123456789ABCDEF, voutput_min); if (nc >= 16) { _mm_storeu_si128((__m128i*) c0, vout0x0123456789ABCDEF); - - a0 = (const int8_t*) ((uintptr_t) a0 - kc); - c0 = (int8_t*) ((uintptr_t) c0 + cn_stride); + a0 = (const int8_t*) ((uintptr_t) a0 - kc); nc -= 16; } else { diff --git a/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-1x16c8-minmax-fp32-avx512skx-prfm.c b/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-1x16c8-minmax-fp32-avx512skx-prfm.c index 10da17f0d70..b33e94bbbfd 100644 --- a/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-1x16c8-minmax-fp32-avx512skx-prfm.c +++ b/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-1x16c8-minmax-fp32-avx512skx-prfm.c @@ -45,7 +45,6 @@ void xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_1x16c8__avx512skx_prfm( const __m512 voutput_max_less_zero_point = _mm512_load_ps(params->fp32_avx512.output_max_less_zero_point); const __m256i voutput_zero_point = _mm256_load_si256((const __m256i*) params->fp32_avx512.output_zero_point); const __m128i voutput_min = _mm_load_si128((const __m128i*) params->fp32_avx512.output_min); - const __m128i vshuffle_control_mask = _mm_set_epi8(15, 14, 13, 12, 7, 6, 5, 4, 11, 10, 9, 8, 3, 2, 1, 0); do { __m512i vacc0x0123 = _mm512_maskz_expandloadu_epi32(vbias_mask, w); @@ -98,22 +97,20 @@ void xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_1x16c8__avx512skx_prfm( vacc0x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled0x0123456789ABCDEF); - __m256i vacc0x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); + __m256i vacc0x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); - vacc0x0123456789AB4567CDEF = _mm256_adds_epi16(vacc0x0123456789AB4567CDEF, voutput_zero_point); + vacc0x012389AB4567CDEF = _mm256_adds_epi16(vacc0x012389AB4567CDEF, voutput_zero_point); - const __m128i vout0x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc0x0123456789AB4567CDEF, 1)); + const __m128i vout0x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x012389AB4567CDEF, 1)); - __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi8(vout0x0123456789AB4567CDEF, vshuffle_control_mask); + __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi32(vout0x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); vout0x0123456789ABCDEF = _mm_max_epi8(vout0x0123456789ABCDEF, voutput_min); if (nc >= 16) { _mm_storeu_si128((__m128i*) c0, vout0x0123456789ABCDEF); - - a0 = (const int8_t*) ((uintptr_t) a0 - kc); - c0 = (int8_t*) ((uintptr_t) c0 + cn_stride); + a0 = (const int8_t*) ((uintptr_t) a0 - kc); nc -= 16; } else { diff --git a/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-1x16c8-minmax-fp32-avx512skx.c b/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-1x16c8-minmax-fp32-avx512skx.c index b6b8566b99b..12bda7869e0 100644 --- a/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-1x16c8-minmax-fp32-avx512skx.c +++ b/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-1x16c8-minmax-fp32-avx512skx.c @@ -44,7 +44,6 @@ void xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_1x16c8__avx512skx( const __m512 voutput_max_less_zero_point = _mm512_load_ps(params->fp32_avx512.output_max_less_zero_point); const __m256i voutput_zero_point = _mm256_load_si256((const __m256i*) params->fp32_avx512.output_zero_point); const __m128i voutput_min = _mm_load_si128((const __m128i*) params->fp32_avx512.output_min); - const __m128i vshuffle_control_mask = _mm_set_epi8(15, 14, 13, 12, 7, 6, 5, 4, 11, 10, 9, 8, 3, 2, 1, 0); do { __m512i vacc0x0123 = _mm512_maskz_expandloadu_epi32(vbias_mask, w); @@ -95,22 +94,20 @@ void xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_1x16c8__avx512skx( vacc0x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled0x0123456789ABCDEF); - __m256i vacc0x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); + __m256i vacc0x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); - vacc0x0123456789AB4567CDEF = _mm256_adds_epi16(vacc0x0123456789AB4567CDEF, voutput_zero_point); + vacc0x012389AB4567CDEF = _mm256_adds_epi16(vacc0x012389AB4567CDEF, voutput_zero_point); - const __m128i vout0x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc0x0123456789AB4567CDEF, 1)); + const __m128i vout0x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x012389AB4567CDEF, 1)); - __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi8(vout0x0123456789AB4567CDEF, vshuffle_control_mask); + __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi32(vout0x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); vout0x0123456789ABCDEF = _mm_max_epi8(vout0x0123456789ABCDEF, voutput_min); if (nc >= 16) { _mm_storeu_si128((__m128i*) c0, vout0x0123456789ABCDEF); - - a0 = (const int8_t*) ((uintptr_t) a0 - kc); - c0 = (int8_t*) ((uintptr_t) c0 + cn_stride); + a0 = (const int8_t*) ((uintptr_t) a0 - kc); nc -= 16; } else { diff --git a/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-1x16c8-minmax-fp32-avx512vnni-prfm.c b/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-1x16c8-minmax-fp32-avx512vnni-prfm.c index a7cdf62efe2..613b9c105cd 100644 --- a/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-1x16c8-minmax-fp32-avx512vnni-prfm.c +++ b/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-1x16c8-minmax-fp32-avx512vnni-prfm.c @@ -47,7 +47,6 @@ void xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_1x16c8__avx512vnni_prfm( const __m512 voutput_max_less_zero_point = _mm512_load_ps(params->fp32_avx512vnni.output_max_less_zero_point); const __m256i voutput_zero_point = _mm256_load_si256((const __m256i*) params->fp32_avx512vnni.output_zero_point); const __m128i voutput_min = _mm_load_si128((const __m128i*) params->fp32_avx512vnni.output_min); - const __m128i vshuffle_control_mask = _mm_loadu_si128((const __m128i*) params->fp32_avx512vnni.shuffle_control_mask); do { __m512i vacc0x01234567 = _mm512_cvtepu32_epi64(_mm256_load_si256((const __m256i*) w)); __m512i vacc0x89ABCDEF = _mm512_cvtepu32_epi64(_mm256_load_si256((const __m256i*) ((const int32_t*) w + 8))); @@ -109,22 +108,20 @@ void xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_1x16c8__avx512vnni_prfm( vacc0x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled0x0123456789ABCDEF); - __m256i vacc0x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); + __m256i vacc0x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); - vacc0x0123456789AB4567CDEF = _mm256_adds_epi16(vacc0x0123456789AB4567CDEF, voutput_zero_point); + vacc0x012389AB4567CDEF = _mm256_adds_epi16(vacc0x012389AB4567CDEF, voutput_zero_point); - const __m128i vout0x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc0x0123456789AB4567CDEF, 1)); + const __m128i vout0x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x012389AB4567CDEF, 1)); - __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi8(vout0x0123456789AB4567CDEF, vshuffle_control_mask); + __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi32(vout0x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); vout0x0123456789ABCDEF = _mm_max_epi8(vout0x0123456789ABCDEF, voutput_min); if (nc >= 16) { _mm_storeu_si128((__m128i*) c0, vout0x0123456789ABCDEF); - - a0 = (const int8_t*) ((uintptr_t) a0 - kc); - c0 = (int8_t*) ((uintptr_t) c0 + cn_stride); + a0 = (const int8_t*) ((uintptr_t) a0 - kc); nc -= 16; } else { diff --git a/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-1x16c8-minmax-fp32-avx512vnni.c b/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-1x16c8-minmax-fp32-avx512vnni.c index e896287ad5a..b97d3691ef7 100644 --- a/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-1x16c8-minmax-fp32-avx512vnni.c +++ b/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-1x16c8-minmax-fp32-avx512vnni.c @@ -46,7 +46,6 @@ void xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_1x16c8__avx512vnni( const __m512 voutput_max_less_zero_point = _mm512_load_ps(params->fp32_avx512vnni.output_max_less_zero_point); const __m256i voutput_zero_point = _mm256_load_si256((const __m256i*) params->fp32_avx512vnni.output_zero_point); const __m128i voutput_min = _mm_load_si128((const __m128i*) params->fp32_avx512vnni.output_min); - const __m128i vshuffle_control_mask = _mm_loadu_si128((const __m128i*) params->fp32_avx512vnni.shuffle_control_mask); do { __m512i vacc0x01234567 = _mm512_cvtepu32_epi64(_mm256_load_si256((const __m256i*) w)); __m512i vacc0x89ABCDEF = _mm512_cvtepu32_epi64(_mm256_load_si256((const __m256i*) ((const int32_t*) w + 8))); @@ -102,22 +101,20 @@ void xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_1x16c8__avx512vnni( vacc0x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled0x0123456789ABCDEF); - __m256i vacc0x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); + __m256i vacc0x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); - vacc0x0123456789AB4567CDEF = _mm256_adds_epi16(vacc0x0123456789AB4567CDEF, voutput_zero_point); + vacc0x012389AB4567CDEF = _mm256_adds_epi16(vacc0x012389AB4567CDEF, voutput_zero_point); - const __m128i vout0x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc0x0123456789AB4567CDEF, 1)); + const __m128i vout0x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x012389AB4567CDEF, 1)); - __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi8(vout0x0123456789AB4567CDEF, vshuffle_control_mask); + __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi32(vout0x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); vout0x0123456789ABCDEF = _mm_max_epi8(vout0x0123456789ABCDEF, voutput_min); if (nc >= 16) { _mm_storeu_si128((__m128i*) c0, vout0x0123456789ABCDEF); - - a0 = (const int8_t*) ((uintptr_t) a0 - kc); - c0 = (int8_t*) ((uintptr_t) c0 + cn_stride); + a0 = (const int8_t*) ((uintptr_t) a0 - kc); nc -= 16; } else { diff --git a/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-2x16c4-minmax-fp32-avx512vnni-prfm.c b/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-2x16c4-minmax-fp32-avx512vnni-prfm.c index 1fb46f89a6c..12757627fd7 100644 --- a/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-2x16c4-minmax-fp32-avx512vnni-prfm.c +++ b/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-2x16c4-minmax-fp32-avx512vnni-prfm.c @@ -53,7 +53,6 @@ void xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_2x16c4__avx512vnni_prfm( const __m512 voutput_max_less_zero_point = _mm512_load_ps(params->fp32_avx512vnni.output_max_less_zero_point); const __m256i voutput_zero_point = _mm256_load_si256((const __m256i*) params->fp32_avx512vnni.output_zero_point); const __m128i voutput_min = _mm_load_si128((const __m128i*) params->fp32_avx512vnni.output_min); - const __m128i vshuffle_control_mask = _mm_loadu_si128((const __m128i*) params->fp32_avx512vnni.shuffle_control_mask); do { __m512i vacc0x0x0123456789ABCDEF = _mm512_load_epi32(w); __m512i vacc1x0x0123456789ABCDEF = _mm512_setzero_epi32(); @@ -83,8 +82,8 @@ void xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_2x16c4__avx512vnni_prfm( w = (const int8_t*) w + 128; k -= 8 * sizeof(int8_t); } - vacc0x0x0123456789ABCDEF = _mm512_add_epi32(vacc0x0x0123456789ABCDEF, vacc1x0x0123456789ABCDEF); - vacc0x1x0123456789ABCDEF = _mm512_add_epi32(vacc0x1x0123456789ABCDEF, vacc1x1x0123456789ABCDEF); + __m512i vacc0x0123456789ABCDEF = _mm512_add_epi32(vacc0x0x0123456789ABCDEF, vacc1x0x0123456789ABCDEF); + __m512i vacc1x0123456789ABCDEF = _mm512_add_epi32(vacc0x1x0123456789ABCDEF, vacc1x1x0123456789ABCDEF); if (k != 0) { const __m512i va0x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a0)), vsign_mask); @@ -95,15 +94,15 @@ void xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_2x16c4__avx512vnni_prfm( const __m512i vb0123456789ABCDEF = _mm512_load_si512(w); xnn_prefetch_to_l1((const int8_t*) w + 960); - vacc0x0x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x0x0123456789ABCDEF, va0x0123, vb0123456789ABCDEF); - vacc0x1x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x1x0123456789ABCDEF, va1x0123, vb0123456789ABCDEF); + vacc0x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x0123456789ABCDEF, va0x0123, vb0123456789ABCDEF); + vacc1x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc1x0123456789ABCDEF, va1x0123, vb0123456789ABCDEF); w = (const int8_t*) w + 64; k -= 4 * sizeof(int8_t); } - __m512 vscaled0x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x0x0123456789ABCDEF); - __m512 vscaled1x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x1x0123456789ABCDEF); + __m512 vscaled0x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x0123456789ABCDEF); + __m512 vscaled1x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc1x0123456789ABCDEF); const __m512 vscale012345678ABCDEF = _mm512_load_ps(w); w = (const float*) w + 16; @@ -113,33 +112,31 @@ void xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_2x16c4__avx512vnni_prfm( vscaled0x0123456789ABCDEF = _mm512_min_ps(vscaled0x0123456789ABCDEF, voutput_max_less_zero_point); vscaled1x0123456789ABCDEF = _mm512_min_ps(vscaled1x0123456789ABCDEF, voutput_max_less_zero_point); - vacc0x0x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled0x0123456789ABCDEF); - vacc0x1x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled1x0123456789ABCDEF); + vacc0x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled0x0123456789ABCDEF); + vacc1x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled1x0123456789ABCDEF); - __m256i vacc0x0x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0x0123456789ABCDEF, 1)); - __m256i vacc0x1x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x1x0123456789ABCDEF, 1)); + __m256i vacc0x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); + __m256i vacc1x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); - vacc0x0x012389AB4567CDEF = _mm256_adds_epi16(vacc0x0x012389AB4567CDEF, voutput_zero_point); - vacc0x1x012389AB4567CDEF = _mm256_adds_epi16(vacc0x1x012389AB4567CDEF, voutput_zero_point); + vacc0x012389AB4567CDEF = _mm256_adds_epi16(vacc0x012389AB4567CDEF, voutput_zero_point); + vacc1x012389AB4567CDEF = _mm256_adds_epi16(vacc1x012389AB4567CDEF, voutput_zero_point); - const __m128i vout0x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x0x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x0x012389AB4567CDEF, 1)); - const __m128i vout1x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x1x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x1x012389AB4567CDEF, 1)); + const __m128i vout0x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x012389AB4567CDEF, 1)); + const __m128i vout1x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc1x012389AB4567CDEF), _mm256_extracti128_si256(vacc1x012389AB4567CDEF, 1)); - __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi8(vout0x012389AB4567CDEF, vshuffle_control_mask); - __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi8(vout1x012389AB4567CDEF, vshuffle_control_mask); + __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi32(vout0x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi32(vout1x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); vout0x0123456789ABCDEF = _mm_max_epi8(vout0x0123456789ABCDEF, voutput_min); vout1x0123456789ABCDEF = _mm_max_epi8(vout1x0123456789ABCDEF, voutput_min); if (nc >= 16) { _mm_storeu_si128((__m128i*) c0, vout0x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c1, vout1x0123456789ABCDEF); - - a0 = (const int8_t*) ((uintptr_t) a0 - kc); - a1 = (const int8_t*) ((uintptr_t) a1 - kc); - c0 = (int8_t*) ((uintptr_t) c0 + cn_stride); + a0 = (const int8_t*) ((uintptr_t) a0 - kc); + _mm_storeu_si128((__m128i*) c1, vout1x0123456789ABCDEF); c1 = (int8_t*) ((uintptr_t) c1 + cn_stride); + a1 = (const int8_t*) ((uintptr_t) a1 - kc); nc -= 16; } else { diff --git a/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-2x16c4-minmax-fp32-avx512vnni.c b/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-2x16c4-minmax-fp32-avx512vnni.c index 1f637e3ab42..1dd66285f8e 100644 --- a/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-2x16c4-minmax-fp32-avx512vnni.c +++ b/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-2x16c4-minmax-fp32-avx512vnni.c @@ -52,7 +52,6 @@ void xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_2x16c4__avx512vnni( const __m512 voutput_max_less_zero_point = _mm512_load_ps(params->fp32_avx512vnni.output_max_less_zero_point); const __m256i voutput_zero_point = _mm256_load_si256((const __m256i*) params->fp32_avx512vnni.output_zero_point); const __m128i voutput_min = _mm_load_si128((const __m128i*) params->fp32_avx512vnni.output_min); - const __m128i vshuffle_control_mask = _mm_loadu_si128((const __m128i*) params->fp32_avx512vnni.shuffle_control_mask); do { __m512i vacc0x0x0123456789ABCDEF = _mm512_load_epi32(w); __m512i vacc1x0x0123456789ABCDEF = _mm512_setzero_epi32(); @@ -80,8 +79,8 @@ void xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_2x16c4__avx512vnni( w = (const int8_t*) w + 128; k -= 8 * sizeof(int8_t); } - vacc0x0x0123456789ABCDEF = _mm512_add_epi32(vacc0x0x0123456789ABCDEF, vacc1x0x0123456789ABCDEF); - vacc0x1x0123456789ABCDEF = _mm512_add_epi32(vacc0x1x0123456789ABCDEF, vacc1x1x0123456789ABCDEF); + __m512i vacc0x0123456789ABCDEF = _mm512_add_epi32(vacc0x0x0123456789ABCDEF, vacc1x0x0123456789ABCDEF); + __m512i vacc1x0123456789ABCDEF = _mm512_add_epi32(vacc0x1x0123456789ABCDEF, vacc1x1x0123456789ABCDEF); if (k != 0) { const __m512i va0x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a0)), vsign_mask); @@ -91,15 +90,15 @@ void xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_2x16c4__avx512vnni( const __m512i vb0123456789ABCDEF = _mm512_load_si512(w); - vacc0x0x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x0x0123456789ABCDEF, va0x0123, vb0123456789ABCDEF); - vacc0x1x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x1x0123456789ABCDEF, va1x0123, vb0123456789ABCDEF); + vacc0x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x0123456789ABCDEF, va0x0123, vb0123456789ABCDEF); + vacc1x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc1x0123456789ABCDEF, va1x0123, vb0123456789ABCDEF); w = (const int8_t*) w + 64; k -= 4 * sizeof(int8_t); } - __m512 vscaled0x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x0x0123456789ABCDEF); - __m512 vscaled1x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x1x0123456789ABCDEF); + __m512 vscaled0x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x0123456789ABCDEF); + __m512 vscaled1x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc1x0123456789ABCDEF); const __m512 vscale012345678ABCDEF = _mm512_load_ps(w); w = (const float*) w + 16; @@ -109,33 +108,31 @@ void xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_2x16c4__avx512vnni( vscaled0x0123456789ABCDEF = _mm512_min_ps(vscaled0x0123456789ABCDEF, voutput_max_less_zero_point); vscaled1x0123456789ABCDEF = _mm512_min_ps(vscaled1x0123456789ABCDEF, voutput_max_less_zero_point); - vacc0x0x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled0x0123456789ABCDEF); - vacc0x1x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled1x0123456789ABCDEF); + vacc0x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled0x0123456789ABCDEF); + vacc1x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled1x0123456789ABCDEF); - __m256i vacc0x0x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0x0123456789ABCDEF, 1)); - __m256i vacc0x1x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x1x0123456789ABCDEF, 1)); + __m256i vacc0x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); + __m256i vacc1x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); - vacc0x0x012389AB4567CDEF = _mm256_adds_epi16(vacc0x0x012389AB4567CDEF, voutput_zero_point); - vacc0x1x012389AB4567CDEF = _mm256_adds_epi16(vacc0x1x012389AB4567CDEF, voutput_zero_point); + vacc0x012389AB4567CDEF = _mm256_adds_epi16(vacc0x012389AB4567CDEF, voutput_zero_point); + vacc1x012389AB4567CDEF = _mm256_adds_epi16(vacc1x012389AB4567CDEF, voutput_zero_point); - const __m128i vout0x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x0x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x0x012389AB4567CDEF, 1)); - const __m128i vout1x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x1x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x1x012389AB4567CDEF, 1)); + const __m128i vout0x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x012389AB4567CDEF, 1)); + const __m128i vout1x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc1x012389AB4567CDEF), _mm256_extracti128_si256(vacc1x012389AB4567CDEF, 1)); - __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi8(vout0x012389AB4567CDEF, vshuffle_control_mask); - __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi8(vout1x012389AB4567CDEF, vshuffle_control_mask); + __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi32(vout0x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi32(vout1x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); vout0x0123456789ABCDEF = _mm_max_epi8(vout0x0123456789ABCDEF, voutput_min); vout1x0123456789ABCDEF = _mm_max_epi8(vout1x0123456789ABCDEF, voutput_min); if (nc >= 16) { _mm_storeu_si128((__m128i*) c0, vout0x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c1, vout1x0123456789ABCDEF); - - a0 = (const int8_t*) ((uintptr_t) a0 - kc); - a1 = (const int8_t*) ((uintptr_t) a1 - kc); - c0 = (int8_t*) ((uintptr_t) c0 + cn_stride); + a0 = (const int8_t*) ((uintptr_t) a0 - kc); + _mm_storeu_si128((__m128i*) c1, vout1x0123456789ABCDEF); c1 = (int8_t*) ((uintptr_t) c1 + cn_stride); + a1 = (const int8_t*) ((uintptr_t) a1 - kc); nc -= 16; } else { diff --git a/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-2x16c8-minmax-fp32-avx512skx-prfm.c b/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-2x16c8-minmax-fp32-avx512skx-prfm.c index 00ac45d0b6d..46b9d26d418 100644 --- a/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-2x16c8-minmax-fp32-avx512skx-prfm.c +++ b/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-2x16c8-minmax-fp32-avx512skx-prfm.c @@ -51,7 +51,6 @@ void xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_2x16c8__avx512skx_prfm( const __m512 voutput_max_less_zero_point = _mm512_load_ps(params->fp32_avx512.output_max_less_zero_point); const __m256i voutput_zero_point = _mm256_load_si256((const __m256i*) params->fp32_avx512.output_zero_point); const __m128i voutput_min = _mm_load_si128((const __m128i*) params->fp32_avx512.output_min); - const __m128i vshuffle_control_mask = _mm_set_epi8(15, 14, 13, 12, 7, 6, 5, 4, 11, 10, 9, 8, 3, 2, 1, 0); do { __m512i vacc0x0123 = _mm512_maskz_expandloadu_epi32(vbias_mask, w); @@ -122,30 +121,28 @@ void xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_2x16c8__avx512skx_prfm( vacc0x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled0x0123456789ABCDEF); vacc1x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled1x0123456789ABCDEF); - __m256i vacc0x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); - __m256i vacc1x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); + __m256i vacc0x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); + __m256i vacc1x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); - vacc0x0123456789AB4567CDEF = _mm256_adds_epi16(vacc0x0123456789AB4567CDEF, voutput_zero_point); - vacc1x0123456789AB4567CDEF = _mm256_adds_epi16(vacc1x0123456789AB4567CDEF, voutput_zero_point); + vacc0x012389AB4567CDEF = _mm256_adds_epi16(vacc0x012389AB4567CDEF, voutput_zero_point); + vacc1x012389AB4567CDEF = _mm256_adds_epi16(vacc1x012389AB4567CDEF, voutput_zero_point); - const __m128i vout0x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc0x0123456789AB4567CDEF, 1)); - const __m128i vout1x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc1x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc1x0123456789AB4567CDEF, 1)); + const __m128i vout0x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x012389AB4567CDEF, 1)); + const __m128i vout1x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc1x012389AB4567CDEF), _mm256_extracti128_si256(vacc1x012389AB4567CDEF, 1)); - __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi8(vout0x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi8(vout1x0123456789AB4567CDEF, vshuffle_control_mask); + __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi32(vout0x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi32(vout1x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); vout0x0123456789ABCDEF = _mm_max_epi8(vout0x0123456789ABCDEF, voutput_min); vout1x0123456789ABCDEF = _mm_max_epi8(vout1x0123456789ABCDEF, voutput_min); if (nc >= 16) { _mm_storeu_si128((__m128i*) c0, vout0x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c1, vout1x0123456789ABCDEF); - - a0 = (const int8_t*) ((uintptr_t) a0 - kc); - a1 = (const int8_t*) ((uintptr_t) a1 - kc); - c0 = (int8_t*) ((uintptr_t) c0 + cn_stride); + a0 = (const int8_t*) ((uintptr_t) a0 - kc); + _mm_storeu_si128((__m128i*) c1, vout1x0123456789ABCDEF); c1 = (int8_t*) ((uintptr_t) c1 + cn_stride); + a1 = (const int8_t*) ((uintptr_t) a1 - kc); nc -= 16; } else { diff --git a/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-2x16c8-minmax-fp32-avx512skx.c b/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-2x16c8-minmax-fp32-avx512skx.c index eb382d5d3b6..eddadb5bf18 100644 --- a/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-2x16c8-minmax-fp32-avx512skx.c +++ b/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-2x16c8-minmax-fp32-avx512skx.c @@ -50,7 +50,6 @@ void xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_2x16c8__avx512skx( const __m512 voutput_max_less_zero_point = _mm512_load_ps(params->fp32_avx512.output_max_less_zero_point); const __m256i voutput_zero_point = _mm256_load_si256((const __m256i*) params->fp32_avx512.output_zero_point); const __m128i voutput_min = _mm_load_si128((const __m128i*) params->fp32_avx512.output_min); - const __m128i vshuffle_control_mask = _mm_set_epi8(15, 14, 13, 12, 7, 6, 5, 4, 11, 10, 9, 8, 3, 2, 1, 0); do { __m512i vacc0x0123 = _mm512_maskz_expandloadu_epi32(vbias_mask, w); @@ -119,30 +118,28 @@ void xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_2x16c8__avx512skx( vacc0x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled0x0123456789ABCDEF); vacc1x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled1x0123456789ABCDEF); - __m256i vacc0x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); - __m256i vacc1x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); + __m256i vacc0x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); + __m256i vacc1x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); - vacc0x0123456789AB4567CDEF = _mm256_adds_epi16(vacc0x0123456789AB4567CDEF, voutput_zero_point); - vacc1x0123456789AB4567CDEF = _mm256_adds_epi16(vacc1x0123456789AB4567CDEF, voutput_zero_point); + vacc0x012389AB4567CDEF = _mm256_adds_epi16(vacc0x012389AB4567CDEF, voutput_zero_point); + vacc1x012389AB4567CDEF = _mm256_adds_epi16(vacc1x012389AB4567CDEF, voutput_zero_point); - const __m128i vout0x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc0x0123456789AB4567CDEF, 1)); - const __m128i vout1x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc1x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc1x0123456789AB4567CDEF, 1)); + const __m128i vout0x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x012389AB4567CDEF, 1)); + const __m128i vout1x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc1x012389AB4567CDEF), _mm256_extracti128_si256(vacc1x012389AB4567CDEF, 1)); - __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi8(vout0x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi8(vout1x0123456789AB4567CDEF, vshuffle_control_mask); + __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi32(vout0x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi32(vout1x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); vout0x0123456789ABCDEF = _mm_max_epi8(vout0x0123456789ABCDEF, voutput_min); vout1x0123456789ABCDEF = _mm_max_epi8(vout1x0123456789ABCDEF, voutput_min); if (nc >= 16) { _mm_storeu_si128((__m128i*) c0, vout0x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c1, vout1x0123456789ABCDEF); - - a0 = (const int8_t*) ((uintptr_t) a0 - kc); - a1 = (const int8_t*) ((uintptr_t) a1 - kc); - c0 = (int8_t*) ((uintptr_t) c0 + cn_stride); + a0 = (const int8_t*) ((uintptr_t) a0 - kc); + _mm_storeu_si128((__m128i*) c1, vout1x0123456789ABCDEF); c1 = (int8_t*) ((uintptr_t) c1 + cn_stride); + a1 = (const int8_t*) ((uintptr_t) a1 - kc); nc -= 16; } else { diff --git a/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-2x16c8-minmax-fp32-avx512vnni-prfm.c b/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-2x16c8-minmax-fp32-avx512vnni-prfm.c index cd1d3f46813..e368f75e7af 100644 --- a/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-2x16c8-minmax-fp32-avx512vnni-prfm.c +++ b/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-2x16c8-minmax-fp32-avx512vnni-prfm.c @@ -53,7 +53,6 @@ void xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_2x16c8__avx512vnni_prfm( const __m512 voutput_max_less_zero_point = _mm512_load_ps(params->fp32_avx512vnni.output_max_less_zero_point); const __m256i voutput_zero_point = _mm256_load_si256((const __m256i*) params->fp32_avx512vnni.output_zero_point); const __m128i voutput_min = _mm_load_si128((const __m128i*) params->fp32_avx512vnni.output_min); - const __m128i vshuffle_control_mask = _mm_loadu_si128((const __m128i*) params->fp32_avx512vnni.shuffle_control_mask); do { __m512i vacc0x01234567 = _mm512_cvtepu32_epi64(_mm256_load_si256((const __m256i*) w)); __m512i vacc0x89ABCDEF = _mm512_cvtepu32_epi64(_mm256_load_si256((const __m256i*) ((const int32_t*) w + 8))); @@ -135,30 +134,28 @@ void xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_2x16c8__avx512vnni_prfm( vacc0x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled0x0123456789ABCDEF); vacc1x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled1x0123456789ABCDEF); - __m256i vacc0x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); - __m256i vacc1x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); + __m256i vacc0x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); + __m256i vacc1x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); - vacc0x0123456789AB4567CDEF = _mm256_adds_epi16(vacc0x0123456789AB4567CDEF, voutput_zero_point); - vacc1x0123456789AB4567CDEF = _mm256_adds_epi16(vacc1x0123456789AB4567CDEF, voutput_zero_point); + vacc0x012389AB4567CDEF = _mm256_adds_epi16(vacc0x012389AB4567CDEF, voutput_zero_point); + vacc1x012389AB4567CDEF = _mm256_adds_epi16(vacc1x012389AB4567CDEF, voutput_zero_point); - const __m128i vout0x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc0x0123456789AB4567CDEF, 1)); - const __m128i vout1x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc1x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc1x0123456789AB4567CDEF, 1)); + const __m128i vout0x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x012389AB4567CDEF, 1)); + const __m128i vout1x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc1x012389AB4567CDEF), _mm256_extracti128_si256(vacc1x012389AB4567CDEF, 1)); - __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi8(vout0x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi8(vout1x0123456789AB4567CDEF, vshuffle_control_mask); + __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi32(vout0x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi32(vout1x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); vout0x0123456789ABCDEF = _mm_max_epi8(vout0x0123456789ABCDEF, voutput_min); vout1x0123456789ABCDEF = _mm_max_epi8(vout1x0123456789ABCDEF, voutput_min); if (nc >= 16) { _mm_storeu_si128((__m128i*) c0, vout0x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c1, vout1x0123456789ABCDEF); - - a0 = (const int8_t*) ((uintptr_t) a0 - kc); - a1 = (const int8_t*) ((uintptr_t) a1 - kc); - c0 = (int8_t*) ((uintptr_t) c0 + cn_stride); + a0 = (const int8_t*) ((uintptr_t) a0 - kc); + _mm_storeu_si128((__m128i*) c1, vout1x0123456789ABCDEF); c1 = (int8_t*) ((uintptr_t) c1 + cn_stride); + a1 = (const int8_t*) ((uintptr_t) a1 - kc); nc -= 16; } else { diff --git a/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-2x16c8-minmax-fp32-avx512vnni.c b/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-2x16c8-minmax-fp32-avx512vnni.c index c7fe916af20..6098ec97dc4 100644 --- a/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-2x16c8-minmax-fp32-avx512vnni.c +++ b/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-2x16c8-minmax-fp32-avx512vnni.c @@ -52,7 +52,6 @@ void xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_2x16c8__avx512vnni( const __m512 voutput_max_less_zero_point = _mm512_load_ps(params->fp32_avx512vnni.output_max_less_zero_point); const __m256i voutput_zero_point = _mm256_load_si256((const __m256i*) params->fp32_avx512vnni.output_zero_point); const __m128i voutput_min = _mm_load_si128((const __m128i*) params->fp32_avx512vnni.output_min); - const __m128i vshuffle_control_mask = _mm_loadu_si128((const __m128i*) params->fp32_avx512vnni.shuffle_control_mask); do { __m512i vacc0x01234567 = _mm512_cvtepu32_epi64(_mm256_load_si256((const __m256i*) w)); __m512i vacc0x89ABCDEF = _mm512_cvtepu32_epi64(_mm256_load_si256((const __m256i*) ((const int32_t*) w + 8))); @@ -128,30 +127,28 @@ void xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_2x16c8__avx512vnni( vacc0x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled0x0123456789ABCDEF); vacc1x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled1x0123456789ABCDEF); - __m256i vacc0x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); - __m256i vacc1x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); + __m256i vacc0x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); + __m256i vacc1x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); - vacc0x0123456789AB4567CDEF = _mm256_adds_epi16(vacc0x0123456789AB4567CDEF, voutput_zero_point); - vacc1x0123456789AB4567CDEF = _mm256_adds_epi16(vacc1x0123456789AB4567CDEF, voutput_zero_point); + vacc0x012389AB4567CDEF = _mm256_adds_epi16(vacc0x012389AB4567CDEF, voutput_zero_point); + vacc1x012389AB4567CDEF = _mm256_adds_epi16(vacc1x012389AB4567CDEF, voutput_zero_point); - const __m128i vout0x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc0x0123456789AB4567CDEF, 1)); - const __m128i vout1x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc1x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc1x0123456789AB4567CDEF, 1)); + const __m128i vout0x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x012389AB4567CDEF, 1)); + const __m128i vout1x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc1x012389AB4567CDEF), _mm256_extracti128_si256(vacc1x012389AB4567CDEF, 1)); - __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi8(vout0x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi8(vout1x0123456789AB4567CDEF, vshuffle_control_mask); + __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi32(vout0x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi32(vout1x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); vout0x0123456789ABCDEF = _mm_max_epi8(vout0x0123456789ABCDEF, voutput_min); vout1x0123456789ABCDEF = _mm_max_epi8(vout1x0123456789ABCDEF, voutput_min); if (nc >= 16) { _mm_storeu_si128((__m128i*) c0, vout0x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c1, vout1x0123456789ABCDEF); - - a0 = (const int8_t*) ((uintptr_t) a0 - kc); - a1 = (const int8_t*) ((uintptr_t) a1 - kc); - c0 = (int8_t*) ((uintptr_t) c0 + cn_stride); + a0 = (const int8_t*) ((uintptr_t) a0 - kc); + _mm_storeu_si128((__m128i*) c1, vout1x0123456789ABCDEF); c1 = (int8_t*) ((uintptr_t) c1 + cn_stride); + a1 = (const int8_t*) ((uintptr_t) a1 - kc); nc -= 16; } else { diff --git a/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-3x16c4-minmax-fp32-avx512vnni-prfm.c b/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-3x16c4-minmax-fp32-avx512vnni-prfm.c index 3a557467a14..46344bdbf0e 100644 --- a/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-3x16c4-minmax-fp32-avx512vnni-prfm.c +++ b/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-3x16c4-minmax-fp32-avx512vnni-prfm.c @@ -59,7 +59,6 @@ void xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_3x16c4__avx512vnni_prfm( const __m512 voutput_max_less_zero_point = _mm512_load_ps(params->fp32_avx512vnni.output_max_less_zero_point); const __m256i voutput_zero_point = _mm256_load_si256((const __m256i*) params->fp32_avx512vnni.output_zero_point); const __m128i voutput_min = _mm_load_si128((const __m128i*) params->fp32_avx512vnni.output_min); - const __m128i vshuffle_control_mask = _mm_loadu_si128((const __m128i*) params->fp32_avx512vnni.shuffle_control_mask); do { __m512i vacc0x0x0123456789ABCDEF = _mm512_load_epi32(w); __m512i vacc1x0x0123456789ABCDEF = _mm512_setzero_epi32(); @@ -96,9 +95,9 @@ void xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_3x16c4__avx512vnni_prfm( w = (const int8_t*) w + 128; k -= 8 * sizeof(int8_t); } - vacc0x0x0123456789ABCDEF = _mm512_add_epi32(vacc0x0x0123456789ABCDEF, vacc1x0x0123456789ABCDEF); - vacc0x1x0123456789ABCDEF = _mm512_add_epi32(vacc0x1x0123456789ABCDEF, vacc1x1x0123456789ABCDEF); - vacc0x2x0123456789ABCDEF = _mm512_add_epi32(vacc0x2x0123456789ABCDEF, vacc1x2x0123456789ABCDEF); + __m512i vacc0x0123456789ABCDEF = _mm512_add_epi32(vacc0x0x0123456789ABCDEF, vacc1x0x0123456789ABCDEF); + __m512i vacc1x0123456789ABCDEF = _mm512_add_epi32(vacc0x1x0123456789ABCDEF, vacc1x1x0123456789ABCDEF); + __m512i vacc2x0123456789ABCDEF = _mm512_add_epi32(vacc0x2x0123456789ABCDEF, vacc1x2x0123456789ABCDEF); if (k != 0) { const __m512i va0x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a0)), vsign_mask); @@ -111,17 +110,17 @@ void xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_3x16c4__avx512vnni_prfm( const __m512i vb0123456789ABCDEF = _mm512_load_si512(w); xnn_prefetch_to_l1((const int8_t*) w + 960); - vacc0x0x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x0x0123456789ABCDEF, va0x0123, vb0123456789ABCDEF); - vacc0x1x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x1x0123456789ABCDEF, va1x0123, vb0123456789ABCDEF); - vacc0x2x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x2x0123456789ABCDEF, va2x0123, vb0123456789ABCDEF); + vacc0x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x0123456789ABCDEF, va0x0123, vb0123456789ABCDEF); + vacc1x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc1x0123456789ABCDEF, va1x0123, vb0123456789ABCDEF); + vacc2x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc2x0123456789ABCDEF, va2x0123, vb0123456789ABCDEF); w = (const int8_t*) w + 64; k -= 4 * sizeof(int8_t); } - __m512 vscaled0x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x0x0123456789ABCDEF); - __m512 vscaled1x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x1x0123456789ABCDEF); - __m512 vscaled2x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x2x0123456789ABCDEF); + __m512 vscaled0x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x0123456789ABCDEF); + __m512 vscaled1x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc1x0123456789ABCDEF); + __m512 vscaled2x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc2x0123456789ABCDEF); const __m512 vscale012345678ABCDEF = _mm512_load_ps(w); w = (const float*) w + 16; @@ -133,25 +132,25 @@ void xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_3x16c4__avx512vnni_prfm( vscaled1x0123456789ABCDEF = _mm512_min_ps(vscaled1x0123456789ABCDEF, voutput_max_less_zero_point); vscaled2x0123456789ABCDEF = _mm512_min_ps(vscaled2x0123456789ABCDEF, voutput_max_less_zero_point); - vacc0x0x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled0x0123456789ABCDEF); - vacc0x1x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled1x0123456789ABCDEF); - vacc0x2x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled2x0123456789ABCDEF); + vacc0x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled0x0123456789ABCDEF); + vacc1x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled1x0123456789ABCDEF); + vacc2x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled2x0123456789ABCDEF); - __m256i vacc0x0x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0x0123456789ABCDEF, 1)); - __m256i vacc0x1x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x1x0123456789ABCDEF, 1)); - __m256i vacc0x2x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x2x0123456789ABCDEF, 1)); + __m256i vacc0x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); + __m256i vacc1x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); + __m256i vacc2x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc2x0123456789ABCDEF, 1)); - vacc0x0x012389AB4567CDEF = _mm256_adds_epi16(vacc0x0x012389AB4567CDEF, voutput_zero_point); - vacc0x1x012389AB4567CDEF = _mm256_adds_epi16(vacc0x1x012389AB4567CDEF, voutput_zero_point); - vacc0x2x012389AB4567CDEF = _mm256_adds_epi16(vacc0x2x012389AB4567CDEF, voutput_zero_point); + vacc0x012389AB4567CDEF = _mm256_adds_epi16(vacc0x012389AB4567CDEF, voutput_zero_point); + vacc1x012389AB4567CDEF = _mm256_adds_epi16(vacc1x012389AB4567CDEF, voutput_zero_point); + vacc2x012389AB4567CDEF = _mm256_adds_epi16(vacc2x012389AB4567CDEF, voutput_zero_point); - const __m128i vout0x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x0x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x0x012389AB4567CDEF, 1)); - const __m128i vout1x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x1x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x1x012389AB4567CDEF, 1)); - const __m128i vout2x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x2x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x2x012389AB4567CDEF, 1)); + const __m128i vout0x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x012389AB4567CDEF, 1)); + const __m128i vout1x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc1x012389AB4567CDEF), _mm256_extracti128_si256(vacc1x012389AB4567CDEF, 1)); + const __m128i vout2x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc2x012389AB4567CDEF), _mm256_extracti128_si256(vacc2x012389AB4567CDEF, 1)); - __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi8(vout0x012389AB4567CDEF, vshuffle_control_mask); - __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi8(vout1x012389AB4567CDEF, vshuffle_control_mask); - __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi8(vout2x012389AB4567CDEF, vshuffle_control_mask); + __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi32(vout0x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi32(vout1x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi32(vout2x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); vout0x0123456789ABCDEF = _mm_max_epi8(vout0x0123456789ABCDEF, voutput_min); vout1x0123456789ABCDEF = _mm_max_epi8(vout1x0123456789ABCDEF, voutput_min); @@ -159,16 +158,14 @@ void xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_3x16c4__avx512vnni_prfm( if (nc >= 16) { _mm_storeu_si128((__m128i*) c0, vout0x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c1, vout1x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c2, vout2x0123456789ABCDEF); - - a0 = (const int8_t*) ((uintptr_t) a0 - kc); - a1 = (const int8_t*) ((uintptr_t) a1 - kc); - a2 = (const int8_t*) ((uintptr_t) a2 - kc); - c0 = (int8_t*) ((uintptr_t) c0 + cn_stride); + a0 = (const int8_t*) ((uintptr_t) a0 - kc); + _mm_storeu_si128((__m128i*) c1, vout1x0123456789ABCDEF); c1 = (int8_t*) ((uintptr_t) c1 + cn_stride); + a1 = (const int8_t*) ((uintptr_t) a1 - kc); + _mm_storeu_si128((__m128i*) c2, vout2x0123456789ABCDEF); c2 = (int8_t*) ((uintptr_t) c2 + cn_stride); + a2 = (const int8_t*) ((uintptr_t) a2 - kc); nc -= 16; } else { diff --git a/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-3x16c4-minmax-fp32-avx512vnni.c b/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-3x16c4-minmax-fp32-avx512vnni.c index 5e276b5fbf4..c3196d61258 100644 --- a/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-3x16c4-minmax-fp32-avx512vnni.c +++ b/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-3x16c4-minmax-fp32-avx512vnni.c @@ -58,7 +58,6 @@ void xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_3x16c4__avx512vnni( const __m512 voutput_max_less_zero_point = _mm512_load_ps(params->fp32_avx512vnni.output_max_less_zero_point); const __m256i voutput_zero_point = _mm256_load_si256((const __m256i*) params->fp32_avx512vnni.output_zero_point); const __m128i voutput_min = _mm_load_si128((const __m128i*) params->fp32_avx512vnni.output_min); - const __m128i vshuffle_control_mask = _mm_loadu_si128((const __m128i*) params->fp32_avx512vnni.shuffle_control_mask); do { __m512i vacc0x0x0123456789ABCDEF = _mm512_load_epi32(w); __m512i vacc1x0x0123456789ABCDEF = _mm512_setzero_epi32(); @@ -93,9 +92,9 @@ void xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_3x16c4__avx512vnni( w = (const int8_t*) w + 128; k -= 8 * sizeof(int8_t); } - vacc0x0x0123456789ABCDEF = _mm512_add_epi32(vacc0x0x0123456789ABCDEF, vacc1x0x0123456789ABCDEF); - vacc0x1x0123456789ABCDEF = _mm512_add_epi32(vacc0x1x0123456789ABCDEF, vacc1x1x0123456789ABCDEF); - vacc0x2x0123456789ABCDEF = _mm512_add_epi32(vacc0x2x0123456789ABCDEF, vacc1x2x0123456789ABCDEF); + __m512i vacc0x0123456789ABCDEF = _mm512_add_epi32(vacc0x0x0123456789ABCDEF, vacc1x0x0123456789ABCDEF); + __m512i vacc1x0123456789ABCDEF = _mm512_add_epi32(vacc0x1x0123456789ABCDEF, vacc1x1x0123456789ABCDEF); + __m512i vacc2x0123456789ABCDEF = _mm512_add_epi32(vacc0x2x0123456789ABCDEF, vacc1x2x0123456789ABCDEF); if (k != 0) { const __m512i va0x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a0)), vsign_mask); @@ -107,17 +106,17 @@ void xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_3x16c4__avx512vnni( const __m512i vb0123456789ABCDEF = _mm512_load_si512(w); - vacc0x0x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x0x0123456789ABCDEF, va0x0123, vb0123456789ABCDEF); - vacc0x1x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x1x0123456789ABCDEF, va1x0123, vb0123456789ABCDEF); - vacc0x2x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x2x0123456789ABCDEF, va2x0123, vb0123456789ABCDEF); + vacc0x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x0123456789ABCDEF, va0x0123, vb0123456789ABCDEF); + vacc1x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc1x0123456789ABCDEF, va1x0123, vb0123456789ABCDEF); + vacc2x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc2x0123456789ABCDEF, va2x0123, vb0123456789ABCDEF); w = (const int8_t*) w + 64; k -= 4 * sizeof(int8_t); } - __m512 vscaled0x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x0x0123456789ABCDEF); - __m512 vscaled1x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x1x0123456789ABCDEF); - __m512 vscaled2x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x2x0123456789ABCDEF); + __m512 vscaled0x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x0123456789ABCDEF); + __m512 vscaled1x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc1x0123456789ABCDEF); + __m512 vscaled2x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc2x0123456789ABCDEF); const __m512 vscale012345678ABCDEF = _mm512_load_ps(w); w = (const float*) w + 16; @@ -129,25 +128,25 @@ void xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_3x16c4__avx512vnni( vscaled1x0123456789ABCDEF = _mm512_min_ps(vscaled1x0123456789ABCDEF, voutput_max_less_zero_point); vscaled2x0123456789ABCDEF = _mm512_min_ps(vscaled2x0123456789ABCDEF, voutput_max_less_zero_point); - vacc0x0x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled0x0123456789ABCDEF); - vacc0x1x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled1x0123456789ABCDEF); - vacc0x2x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled2x0123456789ABCDEF); + vacc0x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled0x0123456789ABCDEF); + vacc1x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled1x0123456789ABCDEF); + vacc2x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled2x0123456789ABCDEF); - __m256i vacc0x0x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0x0123456789ABCDEF, 1)); - __m256i vacc0x1x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x1x0123456789ABCDEF, 1)); - __m256i vacc0x2x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x2x0123456789ABCDEF, 1)); + __m256i vacc0x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); + __m256i vacc1x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); + __m256i vacc2x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc2x0123456789ABCDEF, 1)); - vacc0x0x012389AB4567CDEF = _mm256_adds_epi16(vacc0x0x012389AB4567CDEF, voutput_zero_point); - vacc0x1x012389AB4567CDEF = _mm256_adds_epi16(vacc0x1x012389AB4567CDEF, voutput_zero_point); - vacc0x2x012389AB4567CDEF = _mm256_adds_epi16(vacc0x2x012389AB4567CDEF, voutput_zero_point); + vacc0x012389AB4567CDEF = _mm256_adds_epi16(vacc0x012389AB4567CDEF, voutput_zero_point); + vacc1x012389AB4567CDEF = _mm256_adds_epi16(vacc1x012389AB4567CDEF, voutput_zero_point); + vacc2x012389AB4567CDEF = _mm256_adds_epi16(vacc2x012389AB4567CDEF, voutput_zero_point); - const __m128i vout0x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x0x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x0x012389AB4567CDEF, 1)); - const __m128i vout1x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x1x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x1x012389AB4567CDEF, 1)); - const __m128i vout2x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x2x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x2x012389AB4567CDEF, 1)); + const __m128i vout0x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x012389AB4567CDEF, 1)); + const __m128i vout1x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc1x012389AB4567CDEF), _mm256_extracti128_si256(vacc1x012389AB4567CDEF, 1)); + const __m128i vout2x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc2x012389AB4567CDEF), _mm256_extracti128_si256(vacc2x012389AB4567CDEF, 1)); - __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi8(vout0x012389AB4567CDEF, vshuffle_control_mask); - __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi8(vout1x012389AB4567CDEF, vshuffle_control_mask); - __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi8(vout2x012389AB4567CDEF, vshuffle_control_mask); + __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi32(vout0x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi32(vout1x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi32(vout2x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); vout0x0123456789ABCDEF = _mm_max_epi8(vout0x0123456789ABCDEF, voutput_min); vout1x0123456789ABCDEF = _mm_max_epi8(vout1x0123456789ABCDEF, voutput_min); @@ -155,16 +154,14 @@ void xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_3x16c4__avx512vnni( if (nc >= 16) { _mm_storeu_si128((__m128i*) c0, vout0x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c1, vout1x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c2, vout2x0123456789ABCDEF); - - a0 = (const int8_t*) ((uintptr_t) a0 - kc); - a1 = (const int8_t*) ((uintptr_t) a1 - kc); - a2 = (const int8_t*) ((uintptr_t) a2 - kc); - c0 = (int8_t*) ((uintptr_t) c0 + cn_stride); + a0 = (const int8_t*) ((uintptr_t) a0 - kc); + _mm_storeu_si128((__m128i*) c1, vout1x0123456789ABCDEF); c1 = (int8_t*) ((uintptr_t) c1 + cn_stride); + a1 = (const int8_t*) ((uintptr_t) a1 - kc); + _mm_storeu_si128((__m128i*) c2, vout2x0123456789ABCDEF); c2 = (int8_t*) ((uintptr_t) c2 + cn_stride); + a2 = (const int8_t*) ((uintptr_t) a2 - kc); nc -= 16; } else { diff --git a/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-3x16c8-minmax-fp32-avx512skx-prfm.c b/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-3x16c8-minmax-fp32-avx512skx-prfm.c index fa416ca80fd..1574dce8b74 100644 --- a/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-3x16c8-minmax-fp32-avx512skx-prfm.c +++ b/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-3x16c8-minmax-fp32-avx512skx-prfm.c @@ -57,7 +57,6 @@ void xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_3x16c8__avx512skx_prfm( const __m512 voutput_max_less_zero_point = _mm512_load_ps(params->fp32_avx512.output_max_less_zero_point); const __m256i voutput_zero_point = _mm256_load_si256((const __m256i*) params->fp32_avx512.output_zero_point); const __m128i voutput_min = _mm_load_si128((const __m128i*) params->fp32_avx512.output_min); - const __m128i vshuffle_control_mask = _mm_set_epi8(15, 14, 13, 12, 7, 6, 5, 4, 11, 10, 9, 8, 3, 2, 1, 0); do { __m512i vacc0x0123 = _mm512_maskz_expandloadu_epi32(vbias_mask, w); @@ -146,21 +145,21 @@ void xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_3x16c8__avx512skx_prfm( vacc1x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled1x0123456789ABCDEF); vacc2x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled2x0123456789ABCDEF); - __m256i vacc0x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); - __m256i vacc1x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); - __m256i vacc2x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc2x0123456789ABCDEF, 1)); + __m256i vacc0x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); + __m256i vacc1x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); + __m256i vacc2x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc2x0123456789ABCDEF, 1)); - vacc0x0123456789AB4567CDEF = _mm256_adds_epi16(vacc0x0123456789AB4567CDEF, voutput_zero_point); - vacc1x0123456789AB4567CDEF = _mm256_adds_epi16(vacc1x0123456789AB4567CDEF, voutput_zero_point); - vacc2x0123456789AB4567CDEF = _mm256_adds_epi16(vacc2x0123456789AB4567CDEF, voutput_zero_point); + vacc0x012389AB4567CDEF = _mm256_adds_epi16(vacc0x012389AB4567CDEF, voutput_zero_point); + vacc1x012389AB4567CDEF = _mm256_adds_epi16(vacc1x012389AB4567CDEF, voutput_zero_point); + vacc2x012389AB4567CDEF = _mm256_adds_epi16(vacc2x012389AB4567CDEF, voutput_zero_point); - const __m128i vout0x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc0x0123456789AB4567CDEF, 1)); - const __m128i vout1x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc1x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc1x0123456789AB4567CDEF, 1)); - const __m128i vout2x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc2x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc2x0123456789AB4567CDEF, 1)); + const __m128i vout0x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x012389AB4567CDEF, 1)); + const __m128i vout1x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc1x012389AB4567CDEF), _mm256_extracti128_si256(vacc1x012389AB4567CDEF, 1)); + const __m128i vout2x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc2x012389AB4567CDEF), _mm256_extracti128_si256(vacc2x012389AB4567CDEF, 1)); - __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi8(vout0x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi8(vout1x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi8(vout2x0123456789AB4567CDEF, vshuffle_control_mask); + __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi32(vout0x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi32(vout1x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi32(vout2x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); vout0x0123456789ABCDEF = _mm_max_epi8(vout0x0123456789ABCDEF, voutput_min); vout1x0123456789ABCDEF = _mm_max_epi8(vout1x0123456789ABCDEF, voutput_min); @@ -168,16 +167,14 @@ void xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_3x16c8__avx512skx_prfm( if (nc >= 16) { _mm_storeu_si128((__m128i*) c0, vout0x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c1, vout1x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c2, vout2x0123456789ABCDEF); - - a0 = (const int8_t*) ((uintptr_t) a0 - kc); - a1 = (const int8_t*) ((uintptr_t) a1 - kc); - a2 = (const int8_t*) ((uintptr_t) a2 - kc); - c0 = (int8_t*) ((uintptr_t) c0 + cn_stride); + a0 = (const int8_t*) ((uintptr_t) a0 - kc); + _mm_storeu_si128((__m128i*) c1, vout1x0123456789ABCDEF); c1 = (int8_t*) ((uintptr_t) c1 + cn_stride); + a1 = (const int8_t*) ((uintptr_t) a1 - kc); + _mm_storeu_si128((__m128i*) c2, vout2x0123456789ABCDEF); c2 = (int8_t*) ((uintptr_t) c2 + cn_stride); + a2 = (const int8_t*) ((uintptr_t) a2 - kc); nc -= 16; } else { diff --git a/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-3x16c8-minmax-fp32-avx512skx.c b/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-3x16c8-minmax-fp32-avx512skx.c index dc12ecb27c8..f2110bf8e74 100644 --- a/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-3x16c8-minmax-fp32-avx512skx.c +++ b/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-3x16c8-minmax-fp32-avx512skx.c @@ -56,7 +56,6 @@ void xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_3x16c8__avx512skx( const __m512 voutput_max_less_zero_point = _mm512_load_ps(params->fp32_avx512.output_max_less_zero_point); const __m256i voutput_zero_point = _mm256_load_si256((const __m256i*) params->fp32_avx512.output_zero_point); const __m128i voutput_min = _mm_load_si128((const __m128i*) params->fp32_avx512.output_min); - const __m128i vshuffle_control_mask = _mm_set_epi8(15, 14, 13, 12, 7, 6, 5, 4, 11, 10, 9, 8, 3, 2, 1, 0); do { __m512i vacc0x0123 = _mm512_maskz_expandloadu_epi32(vbias_mask, w); @@ -143,21 +142,21 @@ void xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_3x16c8__avx512skx( vacc1x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled1x0123456789ABCDEF); vacc2x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled2x0123456789ABCDEF); - __m256i vacc0x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); - __m256i vacc1x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); - __m256i vacc2x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc2x0123456789ABCDEF, 1)); + __m256i vacc0x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); + __m256i vacc1x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); + __m256i vacc2x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc2x0123456789ABCDEF, 1)); - vacc0x0123456789AB4567CDEF = _mm256_adds_epi16(vacc0x0123456789AB4567CDEF, voutput_zero_point); - vacc1x0123456789AB4567CDEF = _mm256_adds_epi16(vacc1x0123456789AB4567CDEF, voutput_zero_point); - vacc2x0123456789AB4567CDEF = _mm256_adds_epi16(vacc2x0123456789AB4567CDEF, voutput_zero_point); + vacc0x012389AB4567CDEF = _mm256_adds_epi16(vacc0x012389AB4567CDEF, voutput_zero_point); + vacc1x012389AB4567CDEF = _mm256_adds_epi16(vacc1x012389AB4567CDEF, voutput_zero_point); + vacc2x012389AB4567CDEF = _mm256_adds_epi16(vacc2x012389AB4567CDEF, voutput_zero_point); - const __m128i vout0x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc0x0123456789AB4567CDEF, 1)); - const __m128i vout1x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc1x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc1x0123456789AB4567CDEF, 1)); - const __m128i vout2x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc2x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc2x0123456789AB4567CDEF, 1)); + const __m128i vout0x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x012389AB4567CDEF, 1)); + const __m128i vout1x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc1x012389AB4567CDEF), _mm256_extracti128_si256(vacc1x012389AB4567CDEF, 1)); + const __m128i vout2x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc2x012389AB4567CDEF), _mm256_extracti128_si256(vacc2x012389AB4567CDEF, 1)); - __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi8(vout0x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi8(vout1x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi8(vout2x0123456789AB4567CDEF, vshuffle_control_mask); + __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi32(vout0x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi32(vout1x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi32(vout2x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); vout0x0123456789ABCDEF = _mm_max_epi8(vout0x0123456789ABCDEF, voutput_min); vout1x0123456789ABCDEF = _mm_max_epi8(vout1x0123456789ABCDEF, voutput_min); @@ -165,16 +164,14 @@ void xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_3x16c8__avx512skx( if (nc >= 16) { _mm_storeu_si128((__m128i*) c0, vout0x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c1, vout1x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c2, vout2x0123456789ABCDEF); - - a0 = (const int8_t*) ((uintptr_t) a0 - kc); - a1 = (const int8_t*) ((uintptr_t) a1 - kc); - a2 = (const int8_t*) ((uintptr_t) a2 - kc); - c0 = (int8_t*) ((uintptr_t) c0 + cn_stride); + a0 = (const int8_t*) ((uintptr_t) a0 - kc); + _mm_storeu_si128((__m128i*) c1, vout1x0123456789ABCDEF); c1 = (int8_t*) ((uintptr_t) c1 + cn_stride); + a1 = (const int8_t*) ((uintptr_t) a1 - kc); + _mm_storeu_si128((__m128i*) c2, vout2x0123456789ABCDEF); c2 = (int8_t*) ((uintptr_t) c2 + cn_stride); + a2 = (const int8_t*) ((uintptr_t) a2 - kc); nc -= 16; } else { diff --git a/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-3x16c8-minmax-fp32-avx512vnni-prfm.c b/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-3x16c8-minmax-fp32-avx512vnni-prfm.c index b0c8e3d7ec1..8bfc90e6d13 100644 --- a/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-3x16c8-minmax-fp32-avx512vnni-prfm.c +++ b/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-3x16c8-minmax-fp32-avx512vnni-prfm.c @@ -59,7 +59,6 @@ void xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_3x16c8__avx512vnni_prfm( const __m512 voutput_max_less_zero_point = _mm512_load_ps(params->fp32_avx512vnni.output_max_less_zero_point); const __m256i voutput_zero_point = _mm256_load_si256((const __m256i*) params->fp32_avx512vnni.output_zero_point); const __m128i voutput_min = _mm_load_si128((const __m128i*) params->fp32_avx512vnni.output_min); - const __m128i vshuffle_control_mask = _mm_loadu_si128((const __m128i*) params->fp32_avx512vnni.shuffle_control_mask); do { __m512i vacc0x01234567 = _mm512_cvtepu32_epi64(_mm256_load_si256((const __m256i*) w)); __m512i vacc0x89ABCDEF = _mm512_cvtepu32_epi64(_mm256_load_si256((const __m256i*) ((const int32_t*) w + 8))); @@ -161,21 +160,21 @@ void xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_3x16c8__avx512vnni_prfm( vacc1x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled1x0123456789ABCDEF); vacc2x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled2x0123456789ABCDEF); - __m256i vacc0x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); - __m256i vacc1x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); - __m256i vacc2x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc2x0123456789ABCDEF, 1)); + __m256i vacc0x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); + __m256i vacc1x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); + __m256i vacc2x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc2x0123456789ABCDEF, 1)); - vacc0x0123456789AB4567CDEF = _mm256_adds_epi16(vacc0x0123456789AB4567CDEF, voutput_zero_point); - vacc1x0123456789AB4567CDEF = _mm256_adds_epi16(vacc1x0123456789AB4567CDEF, voutput_zero_point); - vacc2x0123456789AB4567CDEF = _mm256_adds_epi16(vacc2x0123456789AB4567CDEF, voutput_zero_point); + vacc0x012389AB4567CDEF = _mm256_adds_epi16(vacc0x012389AB4567CDEF, voutput_zero_point); + vacc1x012389AB4567CDEF = _mm256_adds_epi16(vacc1x012389AB4567CDEF, voutput_zero_point); + vacc2x012389AB4567CDEF = _mm256_adds_epi16(vacc2x012389AB4567CDEF, voutput_zero_point); - const __m128i vout0x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc0x0123456789AB4567CDEF, 1)); - const __m128i vout1x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc1x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc1x0123456789AB4567CDEF, 1)); - const __m128i vout2x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc2x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc2x0123456789AB4567CDEF, 1)); + const __m128i vout0x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x012389AB4567CDEF, 1)); + const __m128i vout1x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc1x012389AB4567CDEF), _mm256_extracti128_si256(vacc1x012389AB4567CDEF, 1)); + const __m128i vout2x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc2x012389AB4567CDEF), _mm256_extracti128_si256(vacc2x012389AB4567CDEF, 1)); - __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi8(vout0x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi8(vout1x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi8(vout2x0123456789AB4567CDEF, vshuffle_control_mask); + __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi32(vout0x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi32(vout1x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi32(vout2x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); vout0x0123456789ABCDEF = _mm_max_epi8(vout0x0123456789ABCDEF, voutput_min); vout1x0123456789ABCDEF = _mm_max_epi8(vout1x0123456789ABCDEF, voutput_min); @@ -183,16 +182,14 @@ void xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_3x16c8__avx512vnni_prfm( if (nc >= 16) { _mm_storeu_si128((__m128i*) c0, vout0x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c1, vout1x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c2, vout2x0123456789ABCDEF); - - a0 = (const int8_t*) ((uintptr_t) a0 - kc); - a1 = (const int8_t*) ((uintptr_t) a1 - kc); - a2 = (const int8_t*) ((uintptr_t) a2 - kc); - c0 = (int8_t*) ((uintptr_t) c0 + cn_stride); + a0 = (const int8_t*) ((uintptr_t) a0 - kc); + _mm_storeu_si128((__m128i*) c1, vout1x0123456789ABCDEF); c1 = (int8_t*) ((uintptr_t) c1 + cn_stride); + a1 = (const int8_t*) ((uintptr_t) a1 - kc); + _mm_storeu_si128((__m128i*) c2, vout2x0123456789ABCDEF); c2 = (int8_t*) ((uintptr_t) c2 + cn_stride); + a2 = (const int8_t*) ((uintptr_t) a2 - kc); nc -= 16; } else { diff --git a/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-3x16c8-minmax-fp32-avx512vnni.c b/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-3x16c8-minmax-fp32-avx512vnni.c index 5c3da1a2337..b8a8f9faf0c 100644 --- a/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-3x16c8-minmax-fp32-avx512vnni.c +++ b/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-3x16c8-minmax-fp32-avx512vnni.c @@ -58,7 +58,6 @@ void xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_3x16c8__avx512vnni( const __m512 voutput_max_less_zero_point = _mm512_load_ps(params->fp32_avx512vnni.output_max_less_zero_point); const __m256i voutput_zero_point = _mm256_load_si256((const __m256i*) params->fp32_avx512vnni.output_zero_point); const __m128i voutput_min = _mm_load_si128((const __m128i*) params->fp32_avx512vnni.output_min); - const __m128i vshuffle_control_mask = _mm_loadu_si128((const __m128i*) params->fp32_avx512vnni.shuffle_control_mask); do { __m512i vacc0x01234567 = _mm512_cvtepu32_epi64(_mm256_load_si256((const __m256i*) w)); __m512i vacc0x89ABCDEF = _mm512_cvtepu32_epi64(_mm256_load_si256((const __m256i*) ((const int32_t*) w + 8))); @@ -154,21 +153,21 @@ void xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_3x16c8__avx512vnni( vacc1x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled1x0123456789ABCDEF); vacc2x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled2x0123456789ABCDEF); - __m256i vacc0x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); - __m256i vacc1x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); - __m256i vacc2x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc2x0123456789ABCDEF, 1)); + __m256i vacc0x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); + __m256i vacc1x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); + __m256i vacc2x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc2x0123456789ABCDEF, 1)); - vacc0x0123456789AB4567CDEF = _mm256_adds_epi16(vacc0x0123456789AB4567CDEF, voutput_zero_point); - vacc1x0123456789AB4567CDEF = _mm256_adds_epi16(vacc1x0123456789AB4567CDEF, voutput_zero_point); - vacc2x0123456789AB4567CDEF = _mm256_adds_epi16(vacc2x0123456789AB4567CDEF, voutput_zero_point); + vacc0x012389AB4567CDEF = _mm256_adds_epi16(vacc0x012389AB4567CDEF, voutput_zero_point); + vacc1x012389AB4567CDEF = _mm256_adds_epi16(vacc1x012389AB4567CDEF, voutput_zero_point); + vacc2x012389AB4567CDEF = _mm256_adds_epi16(vacc2x012389AB4567CDEF, voutput_zero_point); - const __m128i vout0x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc0x0123456789AB4567CDEF, 1)); - const __m128i vout1x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc1x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc1x0123456789AB4567CDEF, 1)); - const __m128i vout2x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc2x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc2x0123456789AB4567CDEF, 1)); + const __m128i vout0x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x012389AB4567CDEF, 1)); + const __m128i vout1x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc1x012389AB4567CDEF), _mm256_extracti128_si256(vacc1x012389AB4567CDEF, 1)); + const __m128i vout2x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc2x012389AB4567CDEF), _mm256_extracti128_si256(vacc2x012389AB4567CDEF, 1)); - __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi8(vout0x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi8(vout1x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi8(vout2x0123456789AB4567CDEF, vshuffle_control_mask); + __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi32(vout0x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi32(vout1x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi32(vout2x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); vout0x0123456789ABCDEF = _mm_max_epi8(vout0x0123456789ABCDEF, voutput_min); vout1x0123456789ABCDEF = _mm_max_epi8(vout1x0123456789ABCDEF, voutput_min); @@ -176,16 +175,14 @@ void xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_3x16c8__avx512vnni( if (nc >= 16) { _mm_storeu_si128((__m128i*) c0, vout0x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c1, vout1x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c2, vout2x0123456789ABCDEF); - - a0 = (const int8_t*) ((uintptr_t) a0 - kc); - a1 = (const int8_t*) ((uintptr_t) a1 - kc); - a2 = (const int8_t*) ((uintptr_t) a2 - kc); - c0 = (int8_t*) ((uintptr_t) c0 + cn_stride); + a0 = (const int8_t*) ((uintptr_t) a0 - kc); + _mm_storeu_si128((__m128i*) c1, vout1x0123456789ABCDEF); c1 = (int8_t*) ((uintptr_t) c1 + cn_stride); + a1 = (const int8_t*) ((uintptr_t) a1 - kc); + _mm_storeu_si128((__m128i*) c2, vout2x0123456789ABCDEF); c2 = (int8_t*) ((uintptr_t) c2 + cn_stride); + a2 = (const int8_t*) ((uintptr_t) a2 - kc); nc -= 16; } else { diff --git a/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-4x16c4-minmax-fp32-avx512vnni-prfm.c b/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-4x16c4-minmax-fp32-avx512vnni-prfm.c index dbfbf71bca7..730ca76b2d0 100644 --- a/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-4x16c4-minmax-fp32-avx512vnni-prfm.c +++ b/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-4x16c4-minmax-fp32-avx512vnni-prfm.c @@ -65,7 +65,6 @@ void xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_4x16c4__avx512vnni_prfm( const __m512 voutput_max_less_zero_point = _mm512_load_ps(params->fp32_avx512vnni.output_max_less_zero_point); const __m256i voutput_zero_point = _mm256_load_si256((const __m256i*) params->fp32_avx512vnni.output_zero_point); const __m128i voutput_min = _mm_load_si128((const __m128i*) params->fp32_avx512vnni.output_min); - const __m128i vshuffle_control_mask = _mm_loadu_si128((const __m128i*) params->fp32_avx512vnni.shuffle_control_mask); do { __m512i vacc0x0x0123456789ABCDEF = _mm512_load_epi32(w); __m512i vacc1x0x0123456789ABCDEF = _mm512_setzero_epi32(); @@ -109,10 +108,10 @@ void xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_4x16c4__avx512vnni_prfm( w = (const int8_t*) w + 128; k -= 8 * sizeof(int8_t); } - vacc0x0x0123456789ABCDEF = _mm512_add_epi32(vacc0x0x0123456789ABCDEF, vacc1x0x0123456789ABCDEF); - vacc0x1x0123456789ABCDEF = _mm512_add_epi32(vacc0x1x0123456789ABCDEF, vacc1x1x0123456789ABCDEF); - vacc0x2x0123456789ABCDEF = _mm512_add_epi32(vacc0x2x0123456789ABCDEF, vacc1x2x0123456789ABCDEF); - vacc0x3x0123456789ABCDEF = _mm512_add_epi32(vacc0x3x0123456789ABCDEF, vacc1x3x0123456789ABCDEF); + __m512i vacc0x0123456789ABCDEF = _mm512_add_epi32(vacc0x0x0123456789ABCDEF, vacc1x0x0123456789ABCDEF); + __m512i vacc1x0123456789ABCDEF = _mm512_add_epi32(vacc0x1x0123456789ABCDEF, vacc1x1x0123456789ABCDEF); + __m512i vacc2x0123456789ABCDEF = _mm512_add_epi32(vacc0x2x0123456789ABCDEF, vacc1x2x0123456789ABCDEF); + __m512i vacc3x0123456789ABCDEF = _mm512_add_epi32(vacc0x3x0123456789ABCDEF, vacc1x3x0123456789ABCDEF); if (k != 0) { const __m512i va0x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a0)), vsign_mask); @@ -127,19 +126,19 @@ void xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_4x16c4__avx512vnni_prfm( const __m512i vb0123456789ABCDEF = _mm512_load_si512(w); xnn_prefetch_to_l1((const int8_t*) w + 960); - vacc0x0x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x0x0123456789ABCDEF, va0x0123, vb0123456789ABCDEF); - vacc0x1x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x1x0123456789ABCDEF, va1x0123, vb0123456789ABCDEF); - vacc0x2x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x2x0123456789ABCDEF, va2x0123, vb0123456789ABCDEF); - vacc0x3x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x3x0123456789ABCDEF, va3x0123, vb0123456789ABCDEF); + vacc0x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x0123456789ABCDEF, va0x0123, vb0123456789ABCDEF); + vacc1x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc1x0123456789ABCDEF, va1x0123, vb0123456789ABCDEF); + vacc2x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc2x0123456789ABCDEF, va2x0123, vb0123456789ABCDEF); + vacc3x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc3x0123456789ABCDEF, va3x0123, vb0123456789ABCDEF); w = (const int8_t*) w + 64; k -= 4 * sizeof(int8_t); } - __m512 vscaled0x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x0x0123456789ABCDEF); - __m512 vscaled1x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x1x0123456789ABCDEF); - __m512 vscaled2x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x2x0123456789ABCDEF); - __m512 vscaled3x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x3x0123456789ABCDEF); + __m512 vscaled0x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x0123456789ABCDEF); + __m512 vscaled1x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc1x0123456789ABCDEF); + __m512 vscaled2x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc2x0123456789ABCDEF); + __m512 vscaled3x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc3x0123456789ABCDEF); const __m512 vscale012345678ABCDEF = _mm512_load_ps(w); w = (const float*) w + 16; @@ -153,30 +152,30 @@ void xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_4x16c4__avx512vnni_prfm( vscaled2x0123456789ABCDEF = _mm512_min_ps(vscaled2x0123456789ABCDEF, voutput_max_less_zero_point); vscaled3x0123456789ABCDEF = _mm512_min_ps(vscaled3x0123456789ABCDEF, voutput_max_less_zero_point); - vacc0x0x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled0x0123456789ABCDEF); - vacc0x1x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled1x0123456789ABCDEF); - vacc0x2x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled2x0123456789ABCDEF); - vacc0x3x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled3x0123456789ABCDEF); + vacc0x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled0x0123456789ABCDEF); + vacc1x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled1x0123456789ABCDEF); + vacc2x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled2x0123456789ABCDEF); + vacc3x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled3x0123456789ABCDEF); - __m256i vacc0x0x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0x0123456789ABCDEF, 1)); - __m256i vacc0x1x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x1x0123456789ABCDEF, 1)); - __m256i vacc0x2x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x2x0123456789ABCDEF, 1)); - __m256i vacc0x3x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x3x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x3x0123456789ABCDEF, 1)); + __m256i vacc0x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); + __m256i vacc1x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); + __m256i vacc2x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc2x0123456789ABCDEF, 1)); + __m256i vacc3x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc3x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc3x0123456789ABCDEF, 1)); - vacc0x0x012389AB4567CDEF = _mm256_adds_epi16(vacc0x0x012389AB4567CDEF, voutput_zero_point); - vacc0x1x012389AB4567CDEF = _mm256_adds_epi16(vacc0x1x012389AB4567CDEF, voutput_zero_point); - vacc0x2x012389AB4567CDEF = _mm256_adds_epi16(vacc0x2x012389AB4567CDEF, voutput_zero_point); - vacc0x3x012389AB4567CDEF = _mm256_adds_epi16(vacc0x3x012389AB4567CDEF, voutput_zero_point); + vacc0x012389AB4567CDEF = _mm256_adds_epi16(vacc0x012389AB4567CDEF, voutput_zero_point); + vacc1x012389AB4567CDEF = _mm256_adds_epi16(vacc1x012389AB4567CDEF, voutput_zero_point); + vacc2x012389AB4567CDEF = _mm256_adds_epi16(vacc2x012389AB4567CDEF, voutput_zero_point); + vacc3x012389AB4567CDEF = _mm256_adds_epi16(vacc3x012389AB4567CDEF, voutput_zero_point); - const __m128i vout0x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x0x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x0x012389AB4567CDEF, 1)); - const __m128i vout1x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x1x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x1x012389AB4567CDEF, 1)); - const __m128i vout2x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x2x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x2x012389AB4567CDEF, 1)); - const __m128i vout3x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x3x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x3x012389AB4567CDEF, 1)); + const __m128i vout0x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x012389AB4567CDEF, 1)); + const __m128i vout1x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc1x012389AB4567CDEF), _mm256_extracti128_si256(vacc1x012389AB4567CDEF, 1)); + const __m128i vout2x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc2x012389AB4567CDEF), _mm256_extracti128_si256(vacc2x012389AB4567CDEF, 1)); + const __m128i vout3x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc3x012389AB4567CDEF), _mm256_extracti128_si256(vacc3x012389AB4567CDEF, 1)); - __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi8(vout0x012389AB4567CDEF, vshuffle_control_mask); - __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi8(vout1x012389AB4567CDEF, vshuffle_control_mask); - __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi8(vout2x012389AB4567CDEF, vshuffle_control_mask); - __m128i vout3x0123456789ABCDEF = _mm_shuffle_epi8(vout3x012389AB4567CDEF, vshuffle_control_mask); + __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi32(vout0x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi32(vout1x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi32(vout2x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout3x0123456789ABCDEF = _mm_shuffle_epi32(vout3x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); vout0x0123456789ABCDEF = _mm_max_epi8(vout0x0123456789ABCDEF, voutput_min); vout1x0123456789ABCDEF = _mm_max_epi8(vout1x0123456789ABCDEF, voutput_min); @@ -185,19 +184,17 @@ void xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_4x16c4__avx512vnni_prfm( if (nc >= 16) { _mm_storeu_si128((__m128i*) c0, vout0x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c1, vout1x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c2, vout2x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c3, vout3x0123456789ABCDEF); - - a0 = (const int8_t*) ((uintptr_t) a0 - kc); - a1 = (const int8_t*) ((uintptr_t) a1 - kc); - a2 = (const int8_t*) ((uintptr_t) a2 - kc); - a3 = (const int8_t*) ((uintptr_t) a3 - kc); - c0 = (int8_t*) ((uintptr_t) c0 + cn_stride); + a0 = (const int8_t*) ((uintptr_t) a0 - kc); + _mm_storeu_si128((__m128i*) c1, vout1x0123456789ABCDEF); c1 = (int8_t*) ((uintptr_t) c1 + cn_stride); + a1 = (const int8_t*) ((uintptr_t) a1 - kc); + _mm_storeu_si128((__m128i*) c2, vout2x0123456789ABCDEF); c2 = (int8_t*) ((uintptr_t) c2 + cn_stride); + a2 = (const int8_t*) ((uintptr_t) a2 - kc); + _mm_storeu_si128((__m128i*) c3, vout3x0123456789ABCDEF); c3 = (int8_t*) ((uintptr_t) c3 + cn_stride); + a3 = (const int8_t*) ((uintptr_t) a3 - kc); nc -= 16; } else { diff --git a/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-4x16c4-minmax-fp32-avx512vnni.c b/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-4x16c4-minmax-fp32-avx512vnni.c index 27aa9bd7b0b..fb756924472 100644 --- a/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-4x16c4-minmax-fp32-avx512vnni.c +++ b/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-4x16c4-minmax-fp32-avx512vnni.c @@ -64,7 +64,6 @@ void xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_4x16c4__avx512vnni( const __m512 voutput_max_less_zero_point = _mm512_load_ps(params->fp32_avx512vnni.output_max_less_zero_point); const __m256i voutput_zero_point = _mm256_load_si256((const __m256i*) params->fp32_avx512vnni.output_zero_point); const __m128i voutput_min = _mm_load_si128((const __m128i*) params->fp32_avx512vnni.output_min); - const __m128i vshuffle_control_mask = _mm_loadu_si128((const __m128i*) params->fp32_avx512vnni.shuffle_control_mask); do { __m512i vacc0x0x0123456789ABCDEF = _mm512_load_epi32(w); __m512i vacc1x0x0123456789ABCDEF = _mm512_setzero_epi32(); @@ -106,10 +105,10 @@ void xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_4x16c4__avx512vnni( w = (const int8_t*) w + 128; k -= 8 * sizeof(int8_t); } - vacc0x0x0123456789ABCDEF = _mm512_add_epi32(vacc0x0x0123456789ABCDEF, vacc1x0x0123456789ABCDEF); - vacc0x1x0123456789ABCDEF = _mm512_add_epi32(vacc0x1x0123456789ABCDEF, vacc1x1x0123456789ABCDEF); - vacc0x2x0123456789ABCDEF = _mm512_add_epi32(vacc0x2x0123456789ABCDEF, vacc1x2x0123456789ABCDEF); - vacc0x3x0123456789ABCDEF = _mm512_add_epi32(vacc0x3x0123456789ABCDEF, vacc1x3x0123456789ABCDEF); + __m512i vacc0x0123456789ABCDEF = _mm512_add_epi32(vacc0x0x0123456789ABCDEF, vacc1x0x0123456789ABCDEF); + __m512i vacc1x0123456789ABCDEF = _mm512_add_epi32(vacc0x1x0123456789ABCDEF, vacc1x1x0123456789ABCDEF); + __m512i vacc2x0123456789ABCDEF = _mm512_add_epi32(vacc0x2x0123456789ABCDEF, vacc1x2x0123456789ABCDEF); + __m512i vacc3x0123456789ABCDEF = _mm512_add_epi32(vacc0x3x0123456789ABCDEF, vacc1x3x0123456789ABCDEF); if (k != 0) { const __m512i va0x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a0)), vsign_mask); @@ -123,19 +122,19 @@ void xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_4x16c4__avx512vnni( const __m512i vb0123456789ABCDEF = _mm512_load_si512(w); - vacc0x0x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x0x0123456789ABCDEF, va0x0123, vb0123456789ABCDEF); - vacc0x1x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x1x0123456789ABCDEF, va1x0123, vb0123456789ABCDEF); - vacc0x2x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x2x0123456789ABCDEF, va2x0123, vb0123456789ABCDEF); - vacc0x3x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x3x0123456789ABCDEF, va3x0123, vb0123456789ABCDEF); + vacc0x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x0123456789ABCDEF, va0x0123, vb0123456789ABCDEF); + vacc1x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc1x0123456789ABCDEF, va1x0123, vb0123456789ABCDEF); + vacc2x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc2x0123456789ABCDEF, va2x0123, vb0123456789ABCDEF); + vacc3x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc3x0123456789ABCDEF, va3x0123, vb0123456789ABCDEF); w = (const int8_t*) w + 64; k -= 4 * sizeof(int8_t); } - __m512 vscaled0x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x0x0123456789ABCDEF); - __m512 vscaled1x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x1x0123456789ABCDEF); - __m512 vscaled2x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x2x0123456789ABCDEF); - __m512 vscaled3x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x3x0123456789ABCDEF); + __m512 vscaled0x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x0123456789ABCDEF); + __m512 vscaled1x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc1x0123456789ABCDEF); + __m512 vscaled2x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc2x0123456789ABCDEF); + __m512 vscaled3x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc3x0123456789ABCDEF); const __m512 vscale012345678ABCDEF = _mm512_load_ps(w); w = (const float*) w + 16; @@ -149,30 +148,30 @@ void xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_4x16c4__avx512vnni( vscaled2x0123456789ABCDEF = _mm512_min_ps(vscaled2x0123456789ABCDEF, voutput_max_less_zero_point); vscaled3x0123456789ABCDEF = _mm512_min_ps(vscaled3x0123456789ABCDEF, voutput_max_less_zero_point); - vacc0x0x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled0x0123456789ABCDEF); - vacc0x1x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled1x0123456789ABCDEF); - vacc0x2x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled2x0123456789ABCDEF); - vacc0x3x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled3x0123456789ABCDEF); + vacc0x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled0x0123456789ABCDEF); + vacc1x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled1x0123456789ABCDEF); + vacc2x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled2x0123456789ABCDEF); + vacc3x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled3x0123456789ABCDEF); - __m256i vacc0x0x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0x0123456789ABCDEF, 1)); - __m256i vacc0x1x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x1x0123456789ABCDEF, 1)); - __m256i vacc0x2x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x2x0123456789ABCDEF, 1)); - __m256i vacc0x3x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x3x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x3x0123456789ABCDEF, 1)); + __m256i vacc0x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); + __m256i vacc1x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); + __m256i vacc2x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc2x0123456789ABCDEF, 1)); + __m256i vacc3x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc3x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc3x0123456789ABCDEF, 1)); - vacc0x0x012389AB4567CDEF = _mm256_adds_epi16(vacc0x0x012389AB4567CDEF, voutput_zero_point); - vacc0x1x012389AB4567CDEF = _mm256_adds_epi16(vacc0x1x012389AB4567CDEF, voutput_zero_point); - vacc0x2x012389AB4567CDEF = _mm256_adds_epi16(vacc0x2x012389AB4567CDEF, voutput_zero_point); - vacc0x3x012389AB4567CDEF = _mm256_adds_epi16(vacc0x3x012389AB4567CDEF, voutput_zero_point); + vacc0x012389AB4567CDEF = _mm256_adds_epi16(vacc0x012389AB4567CDEF, voutput_zero_point); + vacc1x012389AB4567CDEF = _mm256_adds_epi16(vacc1x012389AB4567CDEF, voutput_zero_point); + vacc2x012389AB4567CDEF = _mm256_adds_epi16(vacc2x012389AB4567CDEF, voutput_zero_point); + vacc3x012389AB4567CDEF = _mm256_adds_epi16(vacc3x012389AB4567CDEF, voutput_zero_point); - const __m128i vout0x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x0x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x0x012389AB4567CDEF, 1)); - const __m128i vout1x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x1x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x1x012389AB4567CDEF, 1)); - const __m128i vout2x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x2x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x2x012389AB4567CDEF, 1)); - const __m128i vout3x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x3x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x3x012389AB4567CDEF, 1)); + const __m128i vout0x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x012389AB4567CDEF, 1)); + const __m128i vout1x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc1x012389AB4567CDEF), _mm256_extracti128_si256(vacc1x012389AB4567CDEF, 1)); + const __m128i vout2x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc2x012389AB4567CDEF), _mm256_extracti128_si256(vacc2x012389AB4567CDEF, 1)); + const __m128i vout3x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc3x012389AB4567CDEF), _mm256_extracti128_si256(vacc3x012389AB4567CDEF, 1)); - __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi8(vout0x012389AB4567CDEF, vshuffle_control_mask); - __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi8(vout1x012389AB4567CDEF, vshuffle_control_mask); - __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi8(vout2x012389AB4567CDEF, vshuffle_control_mask); - __m128i vout3x0123456789ABCDEF = _mm_shuffle_epi8(vout3x012389AB4567CDEF, vshuffle_control_mask); + __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi32(vout0x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi32(vout1x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi32(vout2x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout3x0123456789ABCDEF = _mm_shuffle_epi32(vout3x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); vout0x0123456789ABCDEF = _mm_max_epi8(vout0x0123456789ABCDEF, voutput_min); vout1x0123456789ABCDEF = _mm_max_epi8(vout1x0123456789ABCDEF, voutput_min); @@ -181,19 +180,17 @@ void xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_4x16c4__avx512vnni( if (nc >= 16) { _mm_storeu_si128((__m128i*) c0, vout0x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c1, vout1x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c2, vout2x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c3, vout3x0123456789ABCDEF); - - a0 = (const int8_t*) ((uintptr_t) a0 - kc); - a1 = (const int8_t*) ((uintptr_t) a1 - kc); - a2 = (const int8_t*) ((uintptr_t) a2 - kc); - a3 = (const int8_t*) ((uintptr_t) a3 - kc); - c0 = (int8_t*) ((uintptr_t) c0 + cn_stride); + a0 = (const int8_t*) ((uintptr_t) a0 - kc); + _mm_storeu_si128((__m128i*) c1, vout1x0123456789ABCDEF); c1 = (int8_t*) ((uintptr_t) c1 + cn_stride); + a1 = (const int8_t*) ((uintptr_t) a1 - kc); + _mm_storeu_si128((__m128i*) c2, vout2x0123456789ABCDEF); c2 = (int8_t*) ((uintptr_t) c2 + cn_stride); + a2 = (const int8_t*) ((uintptr_t) a2 - kc); + _mm_storeu_si128((__m128i*) c3, vout3x0123456789ABCDEF); c3 = (int8_t*) ((uintptr_t) c3 + cn_stride); + a3 = (const int8_t*) ((uintptr_t) a3 - kc); nc -= 16; } else { diff --git a/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-4x16c8-minmax-fp32-avx512skx-prfm.c b/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-4x16c8-minmax-fp32-avx512skx-prfm.c index a2d9751196a..2c30705e0a7 100644 --- a/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-4x16c8-minmax-fp32-avx512skx-prfm.c +++ b/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-4x16c8-minmax-fp32-avx512skx-prfm.c @@ -63,7 +63,6 @@ void xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_4x16c8__avx512skx_prfm( const __m512 voutput_max_less_zero_point = _mm512_load_ps(params->fp32_avx512.output_max_less_zero_point); const __m256i voutput_zero_point = _mm256_load_si256((const __m256i*) params->fp32_avx512.output_zero_point); const __m128i voutput_min = _mm_load_si128((const __m128i*) params->fp32_avx512.output_min); - const __m128i vshuffle_control_mask = _mm_set_epi8(15, 14, 13, 12, 7, 6, 5, 4, 11, 10, 9, 8, 3, 2, 1, 0); do { __m512i vacc0x0123 = _mm512_maskz_expandloadu_epi32(vbias_mask, w); @@ -170,25 +169,25 @@ void xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_4x16c8__avx512skx_prfm( vacc2x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled2x0123456789ABCDEF); vacc3x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled3x0123456789ABCDEF); - __m256i vacc0x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); - __m256i vacc1x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); - __m256i vacc2x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc2x0123456789ABCDEF, 1)); - __m256i vacc3x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc3x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc3x0123456789ABCDEF, 1)); + __m256i vacc0x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); + __m256i vacc1x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); + __m256i vacc2x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc2x0123456789ABCDEF, 1)); + __m256i vacc3x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc3x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc3x0123456789ABCDEF, 1)); - vacc0x0123456789AB4567CDEF = _mm256_adds_epi16(vacc0x0123456789AB4567CDEF, voutput_zero_point); - vacc1x0123456789AB4567CDEF = _mm256_adds_epi16(vacc1x0123456789AB4567CDEF, voutput_zero_point); - vacc2x0123456789AB4567CDEF = _mm256_adds_epi16(vacc2x0123456789AB4567CDEF, voutput_zero_point); - vacc3x0123456789AB4567CDEF = _mm256_adds_epi16(vacc3x0123456789AB4567CDEF, voutput_zero_point); + vacc0x012389AB4567CDEF = _mm256_adds_epi16(vacc0x012389AB4567CDEF, voutput_zero_point); + vacc1x012389AB4567CDEF = _mm256_adds_epi16(vacc1x012389AB4567CDEF, voutput_zero_point); + vacc2x012389AB4567CDEF = _mm256_adds_epi16(vacc2x012389AB4567CDEF, voutput_zero_point); + vacc3x012389AB4567CDEF = _mm256_adds_epi16(vacc3x012389AB4567CDEF, voutput_zero_point); - const __m128i vout0x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc0x0123456789AB4567CDEF, 1)); - const __m128i vout1x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc1x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc1x0123456789AB4567CDEF, 1)); - const __m128i vout2x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc2x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc2x0123456789AB4567CDEF, 1)); - const __m128i vout3x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc3x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc3x0123456789AB4567CDEF, 1)); + const __m128i vout0x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x012389AB4567CDEF, 1)); + const __m128i vout1x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc1x012389AB4567CDEF), _mm256_extracti128_si256(vacc1x012389AB4567CDEF, 1)); + const __m128i vout2x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc2x012389AB4567CDEF), _mm256_extracti128_si256(vacc2x012389AB4567CDEF, 1)); + const __m128i vout3x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc3x012389AB4567CDEF), _mm256_extracti128_si256(vacc3x012389AB4567CDEF, 1)); - __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi8(vout0x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi8(vout1x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi8(vout2x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout3x0123456789ABCDEF = _mm_shuffle_epi8(vout3x0123456789AB4567CDEF, vshuffle_control_mask); + __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi32(vout0x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi32(vout1x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi32(vout2x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout3x0123456789ABCDEF = _mm_shuffle_epi32(vout3x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); vout0x0123456789ABCDEF = _mm_max_epi8(vout0x0123456789ABCDEF, voutput_min); vout1x0123456789ABCDEF = _mm_max_epi8(vout1x0123456789ABCDEF, voutput_min); @@ -197,19 +196,17 @@ void xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_4x16c8__avx512skx_prfm( if (nc >= 16) { _mm_storeu_si128((__m128i*) c0, vout0x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c1, vout1x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c2, vout2x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c3, vout3x0123456789ABCDEF); - - a0 = (const int8_t*) ((uintptr_t) a0 - kc); - a1 = (const int8_t*) ((uintptr_t) a1 - kc); - a2 = (const int8_t*) ((uintptr_t) a2 - kc); - a3 = (const int8_t*) ((uintptr_t) a3 - kc); - c0 = (int8_t*) ((uintptr_t) c0 + cn_stride); + a0 = (const int8_t*) ((uintptr_t) a0 - kc); + _mm_storeu_si128((__m128i*) c1, vout1x0123456789ABCDEF); c1 = (int8_t*) ((uintptr_t) c1 + cn_stride); + a1 = (const int8_t*) ((uintptr_t) a1 - kc); + _mm_storeu_si128((__m128i*) c2, vout2x0123456789ABCDEF); c2 = (int8_t*) ((uintptr_t) c2 + cn_stride); + a2 = (const int8_t*) ((uintptr_t) a2 - kc); + _mm_storeu_si128((__m128i*) c3, vout3x0123456789ABCDEF); c3 = (int8_t*) ((uintptr_t) c3 + cn_stride); + a3 = (const int8_t*) ((uintptr_t) a3 - kc); nc -= 16; } else { diff --git a/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-4x16c8-minmax-fp32-avx512skx.c b/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-4x16c8-minmax-fp32-avx512skx.c index ac66689859a..84ea2027064 100644 --- a/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-4x16c8-minmax-fp32-avx512skx.c +++ b/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-4x16c8-minmax-fp32-avx512skx.c @@ -62,7 +62,6 @@ void xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_4x16c8__avx512skx( const __m512 voutput_max_less_zero_point = _mm512_load_ps(params->fp32_avx512.output_max_less_zero_point); const __m256i voutput_zero_point = _mm256_load_si256((const __m256i*) params->fp32_avx512.output_zero_point); const __m128i voutput_min = _mm_load_si128((const __m128i*) params->fp32_avx512.output_min); - const __m128i vshuffle_control_mask = _mm_set_epi8(15, 14, 13, 12, 7, 6, 5, 4, 11, 10, 9, 8, 3, 2, 1, 0); do { __m512i vacc0x0123 = _mm512_maskz_expandloadu_epi32(vbias_mask, w); @@ -167,25 +166,25 @@ void xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_4x16c8__avx512skx( vacc2x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled2x0123456789ABCDEF); vacc3x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled3x0123456789ABCDEF); - __m256i vacc0x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); - __m256i vacc1x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); - __m256i vacc2x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc2x0123456789ABCDEF, 1)); - __m256i vacc3x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc3x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc3x0123456789ABCDEF, 1)); + __m256i vacc0x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); + __m256i vacc1x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); + __m256i vacc2x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc2x0123456789ABCDEF, 1)); + __m256i vacc3x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc3x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc3x0123456789ABCDEF, 1)); - vacc0x0123456789AB4567CDEF = _mm256_adds_epi16(vacc0x0123456789AB4567CDEF, voutput_zero_point); - vacc1x0123456789AB4567CDEF = _mm256_adds_epi16(vacc1x0123456789AB4567CDEF, voutput_zero_point); - vacc2x0123456789AB4567CDEF = _mm256_adds_epi16(vacc2x0123456789AB4567CDEF, voutput_zero_point); - vacc3x0123456789AB4567CDEF = _mm256_adds_epi16(vacc3x0123456789AB4567CDEF, voutput_zero_point); + vacc0x012389AB4567CDEF = _mm256_adds_epi16(vacc0x012389AB4567CDEF, voutput_zero_point); + vacc1x012389AB4567CDEF = _mm256_adds_epi16(vacc1x012389AB4567CDEF, voutput_zero_point); + vacc2x012389AB4567CDEF = _mm256_adds_epi16(vacc2x012389AB4567CDEF, voutput_zero_point); + vacc3x012389AB4567CDEF = _mm256_adds_epi16(vacc3x012389AB4567CDEF, voutput_zero_point); - const __m128i vout0x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc0x0123456789AB4567CDEF, 1)); - const __m128i vout1x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc1x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc1x0123456789AB4567CDEF, 1)); - const __m128i vout2x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc2x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc2x0123456789AB4567CDEF, 1)); - const __m128i vout3x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc3x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc3x0123456789AB4567CDEF, 1)); + const __m128i vout0x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x012389AB4567CDEF, 1)); + const __m128i vout1x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc1x012389AB4567CDEF), _mm256_extracti128_si256(vacc1x012389AB4567CDEF, 1)); + const __m128i vout2x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc2x012389AB4567CDEF), _mm256_extracti128_si256(vacc2x012389AB4567CDEF, 1)); + const __m128i vout3x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc3x012389AB4567CDEF), _mm256_extracti128_si256(vacc3x012389AB4567CDEF, 1)); - __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi8(vout0x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi8(vout1x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi8(vout2x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout3x0123456789ABCDEF = _mm_shuffle_epi8(vout3x0123456789AB4567CDEF, vshuffle_control_mask); + __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi32(vout0x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi32(vout1x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi32(vout2x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout3x0123456789ABCDEF = _mm_shuffle_epi32(vout3x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); vout0x0123456789ABCDEF = _mm_max_epi8(vout0x0123456789ABCDEF, voutput_min); vout1x0123456789ABCDEF = _mm_max_epi8(vout1x0123456789ABCDEF, voutput_min); @@ -194,19 +193,17 @@ void xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_4x16c8__avx512skx( if (nc >= 16) { _mm_storeu_si128((__m128i*) c0, vout0x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c1, vout1x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c2, vout2x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c3, vout3x0123456789ABCDEF); - - a0 = (const int8_t*) ((uintptr_t) a0 - kc); - a1 = (const int8_t*) ((uintptr_t) a1 - kc); - a2 = (const int8_t*) ((uintptr_t) a2 - kc); - a3 = (const int8_t*) ((uintptr_t) a3 - kc); - c0 = (int8_t*) ((uintptr_t) c0 + cn_stride); + a0 = (const int8_t*) ((uintptr_t) a0 - kc); + _mm_storeu_si128((__m128i*) c1, vout1x0123456789ABCDEF); c1 = (int8_t*) ((uintptr_t) c1 + cn_stride); + a1 = (const int8_t*) ((uintptr_t) a1 - kc); + _mm_storeu_si128((__m128i*) c2, vout2x0123456789ABCDEF); c2 = (int8_t*) ((uintptr_t) c2 + cn_stride); + a2 = (const int8_t*) ((uintptr_t) a2 - kc); + _mm_storeu_si128((__m128i*) c3, vout3x0123456789ABCDEF); c3 = (int8_t*) ((uintptr_t) c3 + cn_stride); + a3 = (const int8_t*) ((uintptr_t) a3 - kc); nc -= 16; } else { diff --git a/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-4x16c8-minmax-fp32-avx512vnni-prfm.c b/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-4x16c8-minmax-fp32-avx512vnni-prfm.c index 3595ebeb2c4..81a2e793dde 100644 --- a/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-4x16c8-minmax-fp32-avx512vnni-prfm.c +++ b/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-4x16c8-minmax-fp32-avx512vnni-prfm.c @@ -65,7 +65,6 @@ void xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_4x16c8__avx512vnni_prfm( const __m512 voutput_max_less_zero_point = _mm512_load_ps(params->fp32_avx512vnni.output_max_less_zero_point); const __m256i voutput_zero_point = _mm256_load_si256((const __m256i*) params->fp32_avx512vnni.output_zero_point); const __m128i voutput_min = _mm_load_si128((const __m128i*) params->fp32_avx512vnni.output_min); - const __m128i vshuffle_control_mask = _mm_loadu_si128((const __m128i*) params->fp32_avx512vnni.shuffle_control_mask); do { __m512i vacc0x01234567 = _mm512_cvtepu32_epi64(_mm256_load_si256((const __m256i*) w)); __m512i vacc0x89ABCDEF = _mm512_cvtepu32_epi64(_mm256_load_si256((const __m256i*) ((const int32_t*) w + 8))); @@ -187,25 +186,25 @@ void xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_4x16c8__avx512vnni_prfm( vacc2x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled2x0123456789ABCDEF); vacc3x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled3x0123456789ABCDEF); - __m256i vacc0x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); - __m256i vacc1x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); - __m256i vacc2x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc2x0123456789ABCDEF, 1)); - __m256i vacc3x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc3x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc3x0123456789ABCDEF, 1)); + __m256i vacc0x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); + __m256i vacc1x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); + __m256i vacc2x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc2x0123456789ABCDEF, 1)); + __m256i vacc3x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc3x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc3x0123456789ABCDEF, 1)); - vacc0x0123456789AB4567CDEF = _mm256_adds_epi16(vacc0x0123456789AB4567CDEF, voutput_zero_point); - vacc1x0123456789AB4567CDEF = _mm256_adds_epi16(vacc1x0123456789AB4567CDEF, voutput_zero_point); - vacc2x0123456789AB4567CDEF = _mm256_adds_epi16(vacc2x0123456789AB4567CDEF, voutput_zero_point); - vacc3x0123456789AB4567CDEF = _mm256_adds_epi16(vacc3x0123456789AB4567CDEF, voutput_zero_point); + vacc0x012389AB4567CDEF = _mm256_adds_epi16(vacc0x012389AB4567CDEF, voutput_zero_point); + vacc1x012389AB4567CDEF = _mm256_adds_epi16(vacc1x012389AB4567CDEF, voutput_zero_point); + vacc2x012389AB4567CDEF = _mm256_adds_epi16(vacc2x012389AB4567CDEF, voutput_zero_point); + vacc3x012389AB4567CDEF = _mm256_adds_epi16(vacc3x012389AB4567CDEF, voutput_zero_point); - const __m128i vout0x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc0x0123456789AB4567CDEF, 1)); - const __m128i vout1x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc1x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc1x0123456789AB4567CDEF, 1)); - const __m128i vout2x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc2x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc2x0123456789AB4567CDEF, 1)); - const __m128i vout3x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc3x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc3x0123456789AB4567CDEF, 1)); + const __m128i vout0x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x012389AB4567CDEF, 1)); + const __m128i vout1x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc1x012389AB4567CDEF), _mm256_extracti128_si256(vacc1x012389AB4567CDEF, 1)); + const __m128i vout2x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc2x012389AB4567CDEF), _mm256_extracti128_si256(vacc2x012389AB4567CDEF, 1)); + const __m128i vout3x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc3x012389AB4567CDEF), _mm256_extracti128_si256(vacc3x012389AB4567CDEF, 1)); - __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi8(vout0x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi8(vout1x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi8(vout2x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout3x0123456789ABCDEF = _mm_shuffle_epi8(vout3x0123456789AB4567CDEF, vshuffle_control_mask); + __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi32(vout0x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi32(vout1x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi32(vout2x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout3x0123456789ABCDEF = _mm_shuffle_epi32(vout3x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); vout0x0123456789ABCDEF = _mm_max_epi8(vout0x0123456789ABCDEF, voutput_min); vout1x0123456789ABCDEF = _mm_max_epi8(vout1x0123456789ABCDEF, voutput_min); @@ -214,19 +213,17 @@ void xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_4x16c8__avx512vnni_prfm( if (nc >= 16) { _mm_storeu_si128((__m128i*) c0, vout0x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c1, vout1x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c2, vout2x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c3, vout3x0123456789ABCDEF); - - a0 = (const int8_t*) ((uintptr_t) a0 - kc); - a1 = (const int8_t*) ((uintptr_t) a1 - kc); - a2 = (const int8_t*) ((uintptr_t) a2 - kc); - a3 = (const int8_t*) ((uintptr_t) a3 - kc); - c0 = (int8_t*) ((uintptr_t) c0 + cn_stride); + a0 = (const int8_t*) ((uintptr_t) a0 - kc); + _mm_storeu_si128((__m128i*) c1, vout1x0123456789ABCDEF); c1 = (int8_t*) ((uintptr_t) c1 + cn_stride); + a1 = (const int8_t*) ((uintptr_t) a1 - kc); + _mm_storeu_si128((__m128i*) c2, vout2x0123456789ABCDEF); c2 = (int8_t*) ((uintptr_t) c2 + cn_stride); + a2 = (const int8_t*) ((uintptr_t) a2 - kc); + _mm_storeu_si128((__m128i*) c3, vout3x0123456789ABCDEF); c3 = (int8_t*) ((uintptr_t) c3 + cn_stride); + a3 = (const int8_t*) ((uintptr_t) a3 - kc); nc -= 16; } else { diff --git a/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-4x16c8-minmax-fp32-avx512vnni.c b/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-4x16c8-minmax-fp32-avx512vnni.c index 9c9783a88a0..3531a416cd9 100644 --- a/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-4x16c8-minmax-fp32-avx512vnni.c +++ b/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-4x16c8-minmax-fp32-avx512vnni.c @@ -64,7 +64,6 @@ void xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_4x16c8__avx512vnni( const __m512 voutput_max_less_zero_point = _mm512_load_ps(params->fp32_avx512vnni.output_max_less_zero_point); const __m256i voutput_zero_point = _mm256_load_si256((const __m256i*) params->fp32_avx512vnni.output_zero_point); const __m128i voutput_min = _mm_load_si128((const __m128i*) params->fp32_avx512vnni.output_min); - const __m128i vshuffle_control_mask = _mm_loadu_si128((const __m128i*) params->fp32_avx512vnni.shuffle_control_mask); do { __m512i vacc0x01234567 = _mm512_cvtepu32_epi64(_mm256_load_si256((const __m256i*) w)); __m512i vacc0x89ABCDEF = _mm512_cvtepu32_epi64(_mm256_load_si256((const __m256i*) ((const int32_t*) w + 8))); @@ -180,25 +179,25 @@ void xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_4x16c8__avx512vnni( vacc2x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled2x0123456789ABCDEF); vacc3x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled3x0123456789ABCDEF); - __m256i vacc0x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); - __m256i vacc1x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); - __m256i vacc2x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc2x0123456789ABCDEF, 1)); - __m256i vacc3x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc3x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc3x0123456789ABCDEF, 1)); + __m256i vacc0x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); + __m256i vacc1x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); + __m256i vacc2x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc2x0123456789ABCDEF, 1)); + __m256i vacc3x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc3x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc3x0123456789ABCDEF, 1)); - vacc0x0123456789AB4567CDEF = _mm256_adds_epi16(vacc0x0123456789AB4567CDEF, voutput_zero_point); - vacc1x0123456789AB4567CDEF = _mm256_adds_epi16(vacc1x0123456789AB4567CDEF, voutput_zero_point); - vacc2x0123456789AB4567CDEF = _mm256_adds_epi16(vacc2x0123456789AB4567CDEF, voutput_zero_point); - vacc3x0123456789AB4567CDEF = _mm256_adds_epi16(vacc3x0123456789AB4567CDEF, voutput_zero_point); + vacc0x012389AB4567CDEF = _mm256_adds_epi16(vacc0x012389AB4567CDEF, voutput_zero_point); + vacc1x012389AB4567CDEF = _mm256_adds_epi16(vacc1x012389AB4567CDEF, voutput_zero_point); + vacc2x012389AB4567CDEF = _mm256_adds_epi16(vacc2x012389AB4567CDEF, voutput_zero_point); + vacc3x012389AB4567CDEF = _mm256_adds_epi16(vacc3x012389AB4567CDEF, voutput_zero_point); - const __m128i vout0x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc0x0123456789AB4567CDEF, 1)); - const __m128i vout1x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc1x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc1x0123456789AB4567CDEF, 1)); - const __m128i vout2x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc2x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc2x0123456789AB4567CDEF, 1)); - const __m128i vout3x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc3x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc3x0123456789AB4567CDEF, 1)); + const __m128i vout0x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x012389AB4567CDEF, 1)); + const __m128i vout1x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc1x012389AB4567CDEF), _mm256_extracti128_si256(vacc1x012389AB4567CDEF, 1)); + const __m128i vout2x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc2x012389AB4567CDEF), _mm256_extracti128_si256(vacc2x012389AB4567CDEF, 1)); + const __m128i vout3x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc3x012389AB4567CDEF), _mm256_extracti128_si256(vacc3x012389AB4567CDEF, 1)); - __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi8(vout0x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi8(vout1x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi8(vout2x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout3x0123456789ABCDEF = _mm_shuffle_epi8(vout3x0123456789AB4567CDEF, vshuffle_control_mask); + __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi32(vout0x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi32(vout1x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi32(vout2x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout3x0123456789ABCDEF = _mm_shuffle_epi32(vout3x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); vout0x0123456789ABCDEF = _mm_max_epi8(vout0x0123456789ABCDEF, voutput_min); vout1x0123456789ABCDEF = _mm_max_epi8(vout1x0123456789ABCDEF, voutput_min); @@ -207,19 +206,17 @@ void xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_4x16c8__avx512vnni( if (nc >= 16) { _mm_storeu_si128((__m128i*) c0, vout0x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c1, vout1x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c2, vout2x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c3, vout3x0123456789ABCDEF); - - a0 = (const int8_t*) ((uintptr_t) a0 - kc); - a1 = (const int8_t*) ((uintptr_t) a1 - kc); - a2 = (const int8_t*) ((uintptr_t) a2 - kc); - a3 = (const int8_t*) ((uintptr_t) a3 - kc); - c0 = (int8_t*) ((uintptr_t) c0 + cn_stride); + a0 = (const int8_t*) ((uintptr_t) a0 - kc); + _mm_storeu_si128((__m128i*) c1, vout1x0123456789ABCDEF); c1 = (int8_t*) ((uintptr_t) c1 + cn_stride); + a1 = (const int8_t*) ((uintptr_t) a1 - kc); + _mm_storeu_si128((__m128i*) c2, vout2x0123456789ABCDEF); c2 = (int8_t*) ((uintptr_t) c2 + cn_stride); + a2 = (const int8_t*) ((uintptr_t) a2 - kc); + _mm_storeu_si128((__m128i*) c3, vout3x0123456789ABCDEF); c3 = (int8_t*) ((uintptr_t) c3 + cn_stride); + a3 = (const int8_t*) ((uintptr_t) a3 - kc); nc -= 16; } else { diff --git a/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-5x16c4-minmax-fp32-avx512vnni-prfm.c b/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-5x16c4-minmax-fp32-avx512vnni-prfm.c index 90357286458..dc0591ac900 100644 --- a/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-5x16c4-minmax-fp32-avx512vnni-prfm.c +++ b/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-5x16c4-minmax-fp32-avx512vnni-prfm.c @@ -71,7 +71,6 @@ void xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_5x16c4__avx512vnni_prfm( const __m512 voutput_max_less_zero_point = _mm512_load_ps(params->fp32_avx512vnni.output_max_less_zero_point); const __m256i voutput_zero_point = _mm256_load_si256((const __m256i*) params->fp32_avx512vnni.output_zero_point); const __m128i voutput_min = _mm_load_si128((const __m128i*) params->fp32_avx512vnni.output_min); - const __m128i vshuffle_control_mask = _mm_loadu_si128((const __m128i*) params->fp32_avx512vnni.shuffle_control_mask); do { __m512i vacc0x0x0123456789ABCDEF = _mm512_load_epi32(w); __m512i vacc1x0x0123456789ABCDEF = _mm512_setzero_epi32(); @@ -122,11 +121,11 @@ void xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_5x16c4__avx512vnni_prfm( w = (const int8_t*) w + 128; k -= 8 * sizeof(int8_t); } - vacc0x0x0123456789ABCDEF = _mm512_add_epi32(vacc0x0x0123456789ABCDEF, vacc1x0x0123456789ABCDEF); - vacc0x1x0123456789ABCDEF = _mm512_add_epi32(vacc0x1x0123456789ABCDEF, vacc1x1x0123456789ABCDEF); - vacc0x2x0123456789ABCDEF = _mm512_add_epi32(vacc0x2x0123456789ABCDEF, vacc1x2x0123456789ABCDEF); - vacc0x3x0123456789ABCDEF = _mm512_add_epi32(vacc0x3x0123456789ABCDEF, vacc1x3x0123456789ABCDEF); - vacc0x4x0123456789ABCDEF = _mm512_add_epi32(vacc0x4x0123456789ABCDEF, vacc1x4x0123456789ABCDEF); + __m512i vacc0x0123456789ABCDEF = _mm512_add_epi32(vacc0x0x0123456789ABCDEF, vacc1x0x0123456789ABCDEF); + __m512i vacc1x0123456789ABCDEF = _mm512_add_epi32(vacc0x1x0123456789ABCDEF, vacc1x1x0123456789ABCDEF); + __m512i vacc2x0123456789ABCDEF = _mm512_add_epi32(vacc0x2x0123456789ABCDEF, vacc1x2x0123456789ABCDEF); + __m512i vacc3x0123456789ABCDEF = _mm512_add_epi32(vacc0x3x0123456789ABCDEF, vacc1x3x0123456789ABCDEF); + __m512i vacc4x0123456789ABCDEF = _mm512_add_epi32(vacc0x4x0123456789ABCDEF, vacc1x4x0123456789ABCDEF); if (k != 0) { const __m512i va0x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a0)), vsign_mask); @@ -143,21 +142,21 @@ void xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_5x16c4__avx512vnni_prfm( const __m512i vb0123456789ABCDEF = _mm512_load_si512(w); xnn_prefetch_to_l1((const int8_t*) w + 960); - vacc0x0x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x0x0123456789ABCDEF, va0x0123, vb0123456789ABCDEF); - vacc0x1x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x1x0123456789ABCDEF, va1x0123, vb0123456789ABCDEF); - vacc0x2x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x2x0123456789ABCDEF, va2x0123, vb0123456789ABCDEF); - vacc0x3x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x3x0123456789ABCDEF, va3x0123, vb0123456789ABCDEF); - vacc0x4x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x4x0123456789ABCDEF, va4x0123, vb0123456789ABCDEF); + vacc0x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x0123456789ABCDEF, va0x0123, vb0123456789ABCDEF); + vacc1x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc1x0123456789ABCDEF, va1x0123, vb0123456789ABCDEF); + vacc2x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc2x0123456789ABCDEF, va2x0123, vb0123456789ABCDEF); + vacc3x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc3x0123456789ABCDEF, va3x0123, vb0123456789ABCDEF); + vacc4x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc4x0123456789ABCDEF, va4x0123, vb0123456789ABCDEF); w = (const int8_t*) w + 64; k -= 4 * sizeof(int8_t); } - __m512 vscaled0x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x0x0123456789ABCDEF); - __m512 vscaled1x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x1x0123456789ABCDEF); - __m512 vscaled2x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x2x0123456789ABCDEF); - __m512 vscaled3x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x3x0123456789ABCDEF); - __m512 vscaled4x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x4x0123456789ABCDEF); + __m512 vscaled0x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x0123456789ABCDEF); + __m512 vscaled1x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc1x0123456789ABCDEF); + __m512 vscaled2x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc2x0123456789ABCDEF); + __m512 vscaled3x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc3x0123456789ABCDEF); + __m512 vscaled4x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc4x0123456789ABCDEF); const __m512 vscale012345678ABCDEF = _mm512_load_ps(w); w = (const float*) w + 16; @@ -173,35 +172,35 @@ void xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_5x16c4__avx512vnni_prfm( vscaled3x0123456789ABCDEF = _mm512_min_ps(vscaled3x0123456789ABCDEF, voutput_max_less_zero_point); vscaled4x0123456789ABCDEF = _mm512_min_ps(vscaled4x0123456789ABCDEF, voutput_max_less_zero_point); - vacc0x0x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled0x0123456789ABCDEF); - vacc0x1x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled1x0123456789ABCDEF); - vacc0x2x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled2x0123456789ABCDEF); - vacc0x3x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled3x0123456789ABCDEF); - vacc0x4x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled4x0123456789ABCDEF); - - __m256i vacc0x0x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0x0123456789ABCDEF, 1)); - __m256i vacc0x1x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x1x0123456789ABCDEF, 1)); - __m256i vacc0x2x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x2x0123456789ABCDEF, 1)); - __m256i vacc0x3x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x3x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x3x0123456789ABCDEF, 1)); - __m256i vacc0x4x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x4x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x4x0123456789ABCDEF, 1)); - - vacc0x0x012389AB4567CDEF = _mm256_adds_epi16(vacc0x0x012389AB4567CDEF, voutput_zero_point); - vacc0x1x012389AB4567CDEF = _mm256_adds_epi16(vacc0x1x012389AB4567CDEF, voutput_zero_point); - vacc0x2x012389AB4567CDEF = _mm256_adds_epi16(vacc0x2x012389AB4567CDEF, voutput_zero_point); - vacc0x3x012389AB4567CDEF = _mm256_adds_epi16(vacc0x3x012389AB4567CDEF, voutput_zero_point); - vacc0x4x012389AB4567CDEF = _mm256_adds_epi16(vacc0x4x012389AB4567CDEF, voutput_zero_point); - - const __m128i vout0x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x0x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x0x012389AB4567CDEF, 1)); - const __m128i vout1x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x1x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x1x012389AB4567CDEF, 1)); - const __m128i vout2x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x2x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x2x012389AB4567CDEF, 1)); - const __m128i vout3x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x3x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x3x012389AB4567CDEF, 1)); - const __m128i vout4x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x4x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x4x012389AB4567CDEF, 1)); - - __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi8(vout0x012389AB4567CDEF, vshuffle_control_mask); - __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi8(vout1x012389AB4567CDEF, vshuffle_control_mask); - __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi8(vout2x012389AB4567CDEF, vshuffle_control_mask); - __m128i vout3x0123456789ABCDEF = _mm_shuffle_epi8(vout3x012389AB4567CDEF, vshuffle_control_mask); - __m128i vout4x0123456789ABCDEF = _mm_shuffle_epi8(vout4x012389AB4567CDEF, vshuffle_control_mask); + vacc0x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled0x0123456789ABCDEF); + vacc1x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled1x0123456789ABCDEF); + vacc2x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled2x0123456789ABCDEF); + vacc3x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled3x0123456789ABCDEF); + vacc4x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled4x0123456789ABCDEF); + + __m256i vacc0x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); + __m256i vacc1x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); + __m256i vacc2x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc2x0123456789ABCDEF, 1)); + __m256i vacc3x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc3x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc3x0123456789ABCDEF, 1)); + __m256i vacc4x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc4x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc4x0123456789ABCDEF, 1)); + + vacc0x012389AB4567CDEF = _mm256_adds_epi16(vacc0x012389AB4567CDEF, voutput_zero_point); + vacc1x012389AB4567CDEF = _mm256_adds_epi16(vacc1x012389AB4567CDEF, voutput_zero_point); + vacc2x012389AB4567CDEF = _mm256_adds_epi16(vacc2x012389AB4567CDEF, voutput_zero_point); + vacc3x012389AB4567CDEF = _mm256_adds_epi16(vacc3x012389AB4567CDEF, voutput_zero_point); + vacc4x012389AB4567CDEF = _mm256_adds_epi16(vacc4x012389AB4567CDEF, voutput_zero_point); + + const __m128i vout0x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x012389AB4567CDEF, 1)); + const __m128i vout1x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc1x012389AB4567CDEF), _mm256_extracti128_si256(vacc1x012389AB4567CDEF, 1)); + const __m128i vout2x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc2x012389AB4567CDEF), _mm256_extracti128_si256(vacc2x012389AB4567CDEF, 1)); + const __m128i vout3x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc3x012389AB4567CDEF), _mm256_extracti128_si256(vacc3x012389AB4567CDEF, 1)); + const __m128i vout4x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc4x012389AB4567CDEF), _mm256_extracti128_si256(vacc4x012389AB4567CDEF, 1)); + + __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi32(vout0x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi32(vout1x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi32(vout2x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout3x0123456789ABCDEF = _mm_shuffle_epi32(vout3x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout4x0123456789ABCDEF = _mm_shuffle_epi32(vout4x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); vout0x0123456789ABCDEF = _mm_max_epi8(vout0x0123456789ABCDEF, voutput_min); vout1x0123456789ABCDEF = _mm_max_epi8(vout1x0123456789ABCDEF, voutput_min); @@ -211,22 +210,20 @@ void xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_5x16c4__avx512vnni_prfm( if (nc >= 16) { _mm_storeu_si128((__m128i*) c0, vout0x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c1, vout1x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c2, vout2x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c3, vout3x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c4, vout4x0123456789ABCDEF); - - a0 = (const int8_t*) ((uintptr_t) a0 - kc); - a1 = (const int8_t*) ((uintptr_t) a1 - kc); - a2 = (const int8_t*) ((uintptr_t) a2 - kc); - a3 = (const int8_t*) ((uintptr_t) a3 - kc); - a4 = (const int8_t*) ((uintptr_t) a4 - kc); - c0 = (int8_t*) ((uintptr_t) c0 + cn_stride); + a0 = (const int8_t*) ((uintptr_t) a0 - kc); + _mm_storeu_si128((__m128i*) c1, vout1x0123456789ABCDEF); c1 = (int8_t*) ((uintptr_t) c1 + cn_stride); + a1 = (const int8_t*) ((uintptr_t) a1 - kc); + _mm_storeu_si128((__m128i*) c2, vout2x0123456789ABCDEF); c2 = (int8_t*) ((uintptr_t) c2 + cn_stride); + a2 = (const int8_t*) ((uintptr_t) a2 - kc); + _mm_storeu_si128((__m128i*) c3, vout3x0123456789ABCDEF); c3 = (int8_t*) ((uintptr_t) c3 + cn_stride); + a3 = (const int8_t*) ((uintptr_t) a3 - kc); + _mm_storeu_si128((__m128i*) c4, vout4x0123456789ABCDEF); c4 = (int8_t*) ((uintptr_t) c4 + cn_stride); + a4 = (const int8_t*) ((uintptr_t) a4 - kc); nc -= 16; } else { diff --git a/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-5x16c4-minmax-fp32-avx512vnni.c b/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-5x16c4-minmax-fp32-avx512vnni.c index 9f5c4308cb1..e207581dba8 100644 --- a/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-5x16c4-minmax-fp32-avx512vnni.c +++ b/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-5x16c4-minmax-fp32-avx512vnni.c @@ -70,7 +70,6 @@ void xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_5x16c4__avx512vnni( const __m512 voutput_max_less_zero_point = _mm512_load_ps(params->fp32_avx512vnni.output_max_less_zero_point); const __m256i voutput_zero_point = _mm256_load_si256((const __m256i*) params->fp32_avx512vnni.output_zero_point); const __m128i voutput_min = _mm_load_si128((const __m128i*) params->fp32_avx512vnni.output_min); - const __m128i vshuffle_control_mask = _mm_loadu_si128((const __m128i*) params->fp32_avx512vnni.shuffle_control_mask); do { __m512i vacc0x0x0123456789ABCDEF = _mm512_load_epi32(w); __m512i vacc1x0x0123456789ABCDEF = _mm512_setzero_epi32(); @@ -119,11 +118,11 @@ void xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_5x16c4__avx512vnni( w = (const int8_t*) w + 128; k -= 8 * sizeof(int8_t); } - vacc0x0x0123456789ABCDEF = _mm512_add_epi32(vacc0x0x0123456789ABCDEF, vacc1x0x0123456789ABCDEF); - vacc0x1x0123456789ABCDEF = _mm512_add_epi32(vacc0x1x0123456789ABCDEF, vacc1x1x0123456789ABCDEF); - vacc0x2x0123456789ABCDEF = _mm512_add_epi32(vacc0x2x0123456789ABCDEF, vacc1x2x0123456789ABCDEF); - vacc0x3x0123456789ABCDEF = _mm512_add_epi32(vacc0x3x0123456789ABCDEF, vacc1x3x0123456789ABCDEF); - vacc0x4x0123456789ABCDEF = _mm512_add_epi32(vacc0x4x0123456789ABCDEF, vacc1x4x0123456789ABCDEF); + __m512i vacc0x0123456789ABCDEF = _mm512_add_epi32(vacc0x0x0123456789ABCDEF, vacc1x0x0123456789ABCDEF); + __m512i vacc1x0123456789ABCDEF = _mm512_add_epi32(vacc0x1x0123456789ABCDEF, vacc1x1x0123456789ABCDEF); + __m512i vacc2x0123456789ABCDEF = _mm512_add_epi32(vacc0x2x0123456789ABCDEF, vacc1x2x0123456789ABCDEF); + __m512i vacc3x0123456789ABCDEF = _mm512_add_epi32(vacc0x3x0123456789ABCDEF, vacc1x3x0123456789ABCDEF); + __m512i vacc4x0123456789ABCDEF = _mm512_add_epi32(vacc0x4x0123456789ABCDEF, vacc1x4x0123456789ABCDEF); if (k != 0) { const __m512i va0x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a0)), vsign_mask); @@ -139,21 +138,21 @@ void xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_5x16c4__avx512vnni( const __m512i vb0123456789ABCDEF = _mm512_load_si512(w); - vacc0x0x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x0x0123456789ABCDEF, va0x0123, vb0123456789ABCDEF); - vacc0x1x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x1x0123456789ABCDEF, va1x0123, vb0123456789ABCDEF); - vacc0x2x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x2x0123456789ABCDEF, va2x0123, vb0123456789ABCDEF); - vacc0x3x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x3x0123456789ABCDEF, va3x0123, vb0123456789ABCDEF); - vacc0x4x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x4x0123456789ABCDEF, va4x0123, vb0123456789ABCDEF); + vacc0x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x0123456789ABCDEF, va0x0123, vb0123456789ABCDEF); + vacc1x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc1x0123456789ABCDEF, va1x0123, vb0123456789ABCDEF); + vacc2x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc2x0123456789ABCDEF, va2x0123, vb0123456789ABCDEF); + vacc3x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc3x0123456789ABCDEF, va3x0123, vb0123456789ABCDEF); + vacc4x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc4x0123456789ABCDEF, va4x0123, vb0123456789ABCDEF); w = (const int8_t*) w + 64; k -= 4 * sizeof(int8_t); } - __m512 vscaled0x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x0x0123456789ABCDEF); - __m512 vscaled1x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x1x0123456789ABCDEF); - __m512 vscaled2x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x2x0123456789ABCDEF); - __m512 vscaled3x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x3x0123456789ABCDEF); - __m512 vscaled4x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x4x0123456789ABCDEF); + __m512 vscaled0x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x0123456789ABCDEF); + __m512 vscaled1x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc1x0123456789ABCDEF); + __m512 vscaled2x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc2x0123456789ABCDEF); + __m512 vscaled3x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc3x0123456789ABCDEF); + __m512 vscaled4x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc4x0123456789ABCDEF); const __m512 vscale012345678ABCDEF = _mm512_load_ps(w); w = (const float*) w + 16; @@ -169,35 +168,35 @@ void xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_5x16c4__avx512vnni( vscaled3x0123456789ABCDEF = _mm512_min_ps(vscaled3x0123456789ABCDEF, voutput_max_less_zero_point); vscaled4x0123456789ABCDEF = _mm512_min_ps(vscaled4x0123456789ABCDEF, voutput_max_less_zero_point); - vacc0x0x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled0x0123456789ABCDEF); - vacc0x1x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled1x0123456789ABCDEF); - vacc0x2x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled2x0123456789ABCDEF); - vacc0x3x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled3x0123456789ABCDEF); - vacc0x4x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled4x0123456789ABCDEF); - - __m256i vacc0x0x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0x0123456789ABCDEF, 1)); - __m256i vacc0x1x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x1x0123456789ABCDEF, 1)); - __m256i vacc0x2x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x2x0123456789ABCDEF, 1)); - __m256i vacc0x3x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x3x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x3x0123456789ABCDEF, 1)); - __m256i vacc0x4x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x4x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x4x0123456789ABCDEF, 1)); - - vacc0x0x012389AB4567CDEF = _mm256_adds_epi16(vacc0x0x012389AB4567CDEF, voutput_zero_point); - vacc0x1x012389AB4567CDEF = _mm256_adds_epi16(vacc0x1x012389AB4567CDEF, voutput_zero_point); - vacc0x2x012389AB4567CDEF = _mm256_adds_epi16(vacc0x2x012389AB4567CDEF, voutput_zero_point); - vacc0x3x012389AB4567CDEF = _mm256_adds_epi16(vacc0x3x012389AB4567CDEF, voutput_zero_point); - vacc0x4x012389AB4567CDEF = _mm256_adds_epi16(vacc0x4x012389AB4567CDEF, voutput_zero_point); - - const __m128i vout0x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x0x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x0x012389AB4567CDEF, 1)); - const __m128i vout1x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x1x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x1x012389AB4567CDEF, 1)); - const __m128i vout2x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x2x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x2x012389AB4567CDEF, 1)); - const __m128i vout3x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x3x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x3x012389AB4567CDEF, 1)); - const __m128i vout4x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x4x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x4x012389AB4567CDEF, 1)); - - __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi8(vout0x012389AB4567CDEF, vshuffle_control_mask); - __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi8(vout1x012389AB4567CDEF, vshuffle_control_mask); - __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi8(vout2x012389AB4567CDEF, vshuffle_control_mask); - __m128i vout3x0123456789ABCDEF = _mm_shuffle_epi8(vout3x012389AB4567CDEF, vshuffle_control_mask); - __m128i vout4x0123456789ABCDEF = _mm_shuffle_epi8(vout4x012389AB4567CDEF, vshuffle_control_mask); + vacc0x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled0x0123456789ABCDEF); + vacc1x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled1x0123456789ABCDEF); + vacc2x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled2x0123456789ABCDEF); + vacc3x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled3x0123456789ABCDEF); + vacc4x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled4x0123456789ABCDEF); + + __m256i vacc0x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); + __m256i vacc1x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); + __m256i vacc2x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc2x0123456789ABCDEF, 1)); + __m256i vacc3x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc3x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc3x0123456789ABCDEF, 1)); + __m256i vacc4x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc4x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc4x0123456789ABCDEF, 1)); + + vacc0x012389AB4567CDEF = _mm256_adds_epi16(vacc0x012389AB4567CDEF, voutput_zero_point); + vacc1x012389AB4567CDEF = _mm256_adds_epi16(vacc1x012389AB4567CDEF, voutput_zero_point); + vacc2x012389AB4567CDEF = _mm256_adds_epi16(vacc2x012389AB4567CDEF, voutput_zero_point); + vacc3x012389AB4567CDEF = _mm256_adds_epi16(vacc3x012389AB4567CDEF, voutput_zero_point); + vacc4x012389AB4567CDEF = _mm256_adds_epi16(vacc4x012389AB4567CDEF, voutput_zero_point); + + const __m128i vout0x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x012389AB4567CDEF, 1)); + const __m128i vout1x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc1x012389AB4567CDEF), _mm256_extracti128_si256(vacc1x012389AB4567CDEF, 1)); + const __m128i vout2x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc2x012389AB4567CDEF), _mm256_extracti128_si256(vacc2x012389AB4567CDEF, 1)); + const __m128i vout3x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc3x012389AB4567CDEF), _mm256_extracti128_si256(vacc3x012389AB4567CDEF, 1)); + const __m128i vout4x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc4x012389AB4567CDEF), _mm256_extracti128_si256(vacc4x012389AB4567CDEF, 1)); + + __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi32(vout0x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi32(vout1x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi32(vout2x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout3x0123456789ABCDEF = _mm_shuffle_epi32(vout3x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout4x0123456789ABCDEF = _mm_shuffle_epi32(vout4x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); vout0x0123456789ABCDEF = _mm_max_epi8(vout0x0123456789ABCDEF, voutput_min); vout1x0123456789ABCDEF = _mm_max_epi8(vout1x0123456789ABCDEF, voutput_min); @@ -207,22 +206,20 @@ void xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_5x16c4__avx512vnni( if (nc >= 16) { _mm_storeu_si128((__m128i*) c0, vout0x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c1, vout1x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c2, vout2x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c3, vout3x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c4, vout4x0123456789ABCDEF); - - a0 = (const int8_t*) ((uintptr_t) a0 - kc); - a1 = (const int8_t*) ((uintptr_t) a1 - kc); - a2 = (const int8_t*) ((uintptr_t) a2 - kc); - a3 = (const int8_t*) ((uintptr_t) a3 - kc); - a4 = (const int8_t*) ((uintptr_t) a4 - kc); - c0 = (int8_t*) ((uintptr_t) c0 + cn_stride); + a0 = (const int8_t*) ((uintptr_t) a0 - kc); + _mm_storeu_si128((__m128i*) c1, vout1x0123456789ABCDEF); c1 = (int8_t*) ((uintptr_t) c1 + cn_stride); + a1 = (const int8_t*) ((uintptr_t) a1 - kc); + _mm_storeu_si128((__m128i*) c2, vout2x0123456789ABCDEF); c2 = (int8_t*) ((uintptr_t) c2 + cn_stride); + a2 = (const int8_t*) ((uintptr_t) a2 - kc); + _mm_storeu_si128((__m128i*) c3, vout3x0123456789ABCDEF); c3 = (int8_t*) ((uintptr_t) c3 + cn_stride); + a3 = (const int8_t*) ((uintptr_t) a3 - kc); + _mm_storeu_si128((__m128i*) c4, vout4x0123456789ABCDEF); c4 = (int8_t*) ((uintptr_t) c4 + cn_stride); + a4 = (const int8_t*) ((uintptr_t) a4 - kc); nc -= 16; } else { diff --git a/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-5x16c8-minmax-fp32-avx512skx-prfm.c b/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-5x16c8-minmax-fp32-avx512skx-prfm.c index e681b36125e..98422bba102 100644 --- a/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-5x16c8-minmax-fp32-avx512skx-prfm.c +++ b/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-5x16c8-minmax-fp32-avx512skx-prfm.c @@ -69,7 +69,6 @@ void xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_5x16c8__avx512skx_prfm( const __m512 voutput_max_less_zero_point = _mm512_load_ps(params->fp32_avx512.output_max_less_zero_point); const __m256i voutput_zero_point = _mm256_load_si256((const __m256i*) params->fp32_avx512.output_zero_point); const __m128i voutput_min = _mm_load_si128((const __m128i*) params->fp32_avx512.output_min); - const __m128i vshuffle_control_mask = _mm_set_epi8(15, 14, 13, 12, 7, 6, 5, 4, 11, 10, 9, 8, 3, 2, 1, 0); do { __m512i vacc0x0123 = _mm512_maskz_expandloadu_epi32(vbias_mask, w); @@ -194,29 +193,29 @@ void xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_5x16c8__avx512skx_prfm( vacc3x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled3x0123456789ABCDEF); vacc4x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled4x0123456789ABCDEF); - __m256i vacc0x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); - __m256i vacc1x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); - __m256i vacc2x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc2x0123456789ABCDEF, 1)); - __m256i vacc3x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc3x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc3x0123456789ABCDEF, 1)); - __m256i vacc4x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc4x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc4x0123456789ABCDEF, 1)); - - vacc0x0123456789AB4567CDEF = _mm256_adds_epi16(vacc0x0123456789AB4567CDEF, voutput_zero_point); - vacc1x0123456789AB4567CDEF = _mm256_adds_epi16(vacc1x0123456789AB4567CDEF, voutput_zero_point); - vacc2x0123456789AB4567CDEF = _mm256_adds_epi16(vacc2x0123456789AB4567CDEF, voutput_zero_point); - vacc3x0123456789AB4567CDEF = _mm256_adds_epi16(vacc3x0123456789AB4567CDEF, voutput_zero_point); - vacc4x0123456789AB4567CDEF = _mm256_adds_epi16(vacc4x0123456789AB4567CDEF, voutput_zero_point); - - const __m128i vout0x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc0x0123456789AB4567CDEF, 1)); - const __m128i vout1x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc1x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc1x0123456789AB4567CDEF, 1)); - const __m128i vout2x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc2x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc2x0123456789AB4567CDEF, 1)); - const __m128i vout3x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc3x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc3x0123456789AB4567CDEF, 1)); - const __m128i vout4x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc4x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc4x0123456789AB4567CDEF, 1)); - - __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi8(vout0x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi8(vout1x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi8(vout2x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout3x0123456789ABCDEF = _mm_shuffle_epi8(vout3x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout4x0123456789ABCDEF = _mm_shuffle_epi8(vout4x0123456789AB4567CDEF, vshuffle_control_mask); + __m256i vacc0x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); + __m256i vacc1x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); + __m256i vacc2x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc2x0123456789ABCDEF, 1)); + __m256i vacc3x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc3x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc3x0123456789ABCDEF, 1)); + __m256i vacc4x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc4x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc4x0123456789ABCDEF, 1)); + + vacc0x012389AB4567CDEF = _mm256_adds_epi16(vacc0x012389AB4567CDEF, voutput_zero_point); + vacc1x012389AB4567CDEF = _mm256_adds_epi16(vacc1x012389AB4567CDEF, voutput_zero_point); + vacc2x012389AB4567CDEF = _mm256_adds_epi16(vacc2x012389AB4567CDEF, voutput_zero_point); + vacc3x012389AB4567CDEF = _mm256_adds_epi16(vacc3x012389AB4567CDEF, voutput_zero_point); + vacc4x012389AB4567CDEF = _mm256_adds_epi16(vacc4x012389AB4567CDEF, voutput_zero_point); + + const __m128i vout0x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x012389AB4567CDEF, 1)); + const __m128i vout1x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc1x012389AB4567CDEF), _mm256_extracti128_si256(vacc1x012389AB4567CDEF, 1)); + const __m128i vout2x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc2x012389AB4567CDEF), _mm256_extracti128_si256(vacc2x012389AB4567CDEF, 1)); + const __m128i vout3x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc3x012389AB4567CDEF), _mm256_extracti128_si256(vacc3x012389AB4567CDEF, 1)); + const __m128i vout4x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc4x012389AB4567CDEF), _mm256_extracti128_si256(vacc4x012389AB4567CDEF, 1)); + + __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi32(vout0x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi32(vout1x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi32(vout2x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout3x0123456789ABCDEF = _mm_shuffle_epi32(vout3x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout4x0123456789ABCDEF = _mm_shuffle_epi32(vout4x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); vout0x0123456789ABCDEF = _mm_max_epi8(vout0x0123456789ABCDEF, voutput_min); vout1x0123456789ABCDEF = _mm_max_epi8(vout1x0123456789ABCDEF, voutput_min); @@ -226,22 +225,20 @@ void xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_5x16c8__avx512skx_prfm( if (nc >= 16) { _mm_storeu_si128((__m128i*) c0, vout0x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c1, vout1x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c2, vout2x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c3, vout3x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c4, vout4x0123456789ABCDEF); - - a0 = (const int8_t*) ((uintptr_t) a0 - kc); - a1 = (const int8_t*) ((uintptr_t) a1 - kc); - a2 = (const int8_t*) ((uintptr_t) a2 - kc); - a3 = (const int8_t*) ((uintptr_t) a3 - kc); - a4 = (const int8_t*) ((uintptr_t) a4 - kc); - c0 = (int8_t*) ((uintptr_t) c0 + cn_stride); + a0 = (const int8_t*) ((uintptr_t) a0 - kc); + _mm_storeu_si128((__m128i*) c1, vout1x0123456789ABCDEF); c1 = (int8_t*) ((uintptr_t) c1 + cn_stride); + a1 = (const int8_t*) ((uintptr_t) a1 - kc); + _mm_storeu_si128((__m128i*) c2, vout2x0123456789ABCDEF); c2 = (int8_t*) ((uintptr_t) c2 + cn_stride); + a2 = (const int8_t*) ((uintptr_t) a2 - kc); + _mm_storeu_si128((__m128i*) c3, vout3x0123456789ABCDEF); c3 = (int8_t*) ((uintptr_t) c3 + cn_stride); + a3 = (const int8_t*) ((uintptr_t) a3 - kc); + _mm_storeu_si128((__m128i*) c4, vout4x0123456789ABCDEF); c4 = (int8_t*) ((uintptr_t) c4 + cn_stride); + a4 = (const int8_t*) ((uintptr_t) a4 - kc); nc -= 16; } else { diff --git a/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-5x16c8-minmax-fp32-avx512skx.c b/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-5x16c8-minmax-fp32-avx512skx.c index 910513bdd24..cb051a4bfd9 100644 --- a/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-5x16c8-minmax-fp32-avx512skx.c +++ b/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-5x16c8-minmax-fp32-avx512skx.c @@ -68,7 +68,6 @@ void xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_5x16c8__avx512skx( const __m512 voutput_max_less_zero_point = _mm512_load_ps(params->fp32_avx512.output_max_less_zero_point); const __m256i voutput_zero_point = _mm256_load_si256((const __m256i*) params->fp32_avx512.output_zero_point); const __m128i voutput_min = _mm_load_si128((const __m128i*) params->fp32_avx512.output_min); - const __m128i vshuffle_control_mask = _mm_set_epi8(15, 14, 13, 12, 7, 6, 5, 4, 11, 10, 9, 8, 3, 2, 1, 0); do { __m512i vacc0x0123 = _mm512_maskz_expandloadu_epi32(vbias_mask, w); @@ -191,29 +190,29 @@ void xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_5x16c8__avx512skx( vacc3x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled3x0123456789ABCDEF); vacc4x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled4x0123456789ABCDEF); - __m256i vacc0x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); - __m256i vacc1x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); - __m256i vacc2x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc2x0123456789ABCDEF, 1)); - __m256i vacc3x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc3x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc3x0123456789ABCDEF, 1)); - __m256i vacc4x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc4x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc4x0123456789ABCDEF, 1)); - - vacc0x0123456789AB4567CDEF = _mm256_adds_epi16(vacc0x0123456789AB4567CDEF, voutput_zero_point); - vacc1x0123456789AB4567CDEF = _mm256_adds_epi16(vacc1x0123456789AB4567CDEF, voutput_zero_point); - vacc2x0123456789AB4567CDEF = _mm256_adds_epi16(vacc2x0123456789AB4567CDEF, voutput_zero_point); - vacc3x0123456789AB4567CDEF = _mm256_adds_epi16(vacc3x0123456789AB4567CDEF, voutput_zero_point); - vacc4x0123456789AB4567CDEF = _mm256_adds_epi16(vacc4x0123456789AB4567CDEF, voutput_zero_point); - - const __m128i vout0x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc0x0123456789AB4567CDEF, 1)); - const __m128i vout1x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc1x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc1x0123456789AB4567CDEF, 1)); - const __m128i vout2x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc2x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc2x0123456789AB4567CDEF, 1)); - const __m128i vout3x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc3x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc3x0123456789AB4567CDEF, 1)); - const __m128i vout4x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc4x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc4x0123456789AB4567CDEF, 1)); - - __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi8(vout0x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi8(vout1x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi8(vout2x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout3x0123456789ABCDEF = _mm_shuffle_epi8(vout3x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout4x0123456789ABCDEF = _mm_shuffle_epi8(vout4x0123456789AB4567CDEF, vshuffle_control_mask); + __m256i vacc0x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); + __m256i vacc1x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); + __m256i vacc2x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc2x0123456789ABCDEF, 1)); + __m256i vacc3x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc3x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc3x0123456789ABCDEF, 1)); + __m256i vacc4x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc4x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc4x0123456789ABCDEF, 1)); + + vacc0x012389AB4567CDEF = _mm256_adds_epi16(vacc0x012389AB4567CDEF, voutput_zero_point); + vacc1x012389AB4567CDEF = _mm256_adds_epi16(vacc1x012389AB4567CDEF, voutput_zero_point); + vacc2x012389AB4567CDEF = _mm256_adds_epi16(vacc2x012389AB4567CDEF, voutput_zero_point); + vacc3x012389AB4567CDEF = _mm256_adds_epi16(vacc3x012389AB4567CDEF, voutput_zero_point); + vacc4x012389AB4567CDEF = _mm256_adds_epi16(vacc4x012389AB4567CDEF, voutput_zero_point); + + const __m128i vout0x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x012389AB4567CDEF, 1)); + const __m128i vout1x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc1x012389AB4567CDEF), _mm256_extracti128_si256(vacc1x012389AB4567CDEF, 1)); + const __m128i vout2x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc2x012389AB4567CDEF), _mm256_extracti128_si256(vacc2x012389AB4567CDEF, 1)); + const __m128i vout3x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc3x012389AB4567CDEF), _mm256_extracti128_si256(vacc3x012389AB4567CDEF, 1)); + const __m128i vout4x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc4x012389AB4567CDEF), _mm256_extracti128_si256(vacc4x012389AB4567CDEF, 1)); + + __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi32(vout0x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi32(vout1x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi32(vout2x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout3x0123456789ABCDEF = _mm_shuffle_epi32(vout3x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout4x0123456789ABCDEF = _mm_shuffle_epi32(vout4x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); vout0x0123456789ABCDEF = _mm_max_epi8(vout0x0123456789ABCDEF, voutput_min); vout1x0123456789ABCDEF = _mm_max_epi8(vout1x0123456789ABCDEF, voutput_min); @@ -223,22 +222,20 @@ void xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_5x16c8__avx512skx( if (nc >= 16) { _mm_storeu_si128((__m128i*) c0, vout0x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c1, vout1x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c2, vout2x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c3, vout3x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c4, vout4x0123456789ABCDEF); - - a0 = (const int8_t*) ((uintptr_t) a0 - kc); - a1 = (const int8_t*) ((uintptr_t) a1 - kc); - a2 = (const int8_t*) ((uintptr_t) a2 - kc); - a3 = (const int8_t*) ((uintptr_t) a3 - kc); - a4 = (const int8_t*) ((uintptr_t) a4 - kc); - c0 = (int8_t*) ((uintptr_t) c0 + cn_stride); + a0 = (const int8_t*) ((uintptr_t) a0 - kc); + _mm_storeu_si128((__m128i*) c1, vout1x0123456789ABCDEF); c1 = (int8_t*) ((uintptr_t) c1 + cn_stride); + a1 = (const int8_t*) ((uintptr_t) a1 - kc); + _mm_storeu_si128((__m128i*) c2, vout2x0123456789ABCDEF); c2 = (int8_t*) ((uintptr_t) c2 + cn_stride); + a2 = (const int8_t*) ((uintptr_t) a2 - kc); + _mm_storeu_si128((__m128i*) c3, vout3x0123456789ABCDEF); c3 = (int8_t*) ((uintptr_t) c3 + cn_stride); + a3 = (const int8_t*) ((uintptr_t) a3 - kc); + _mm_storeu_si128((__m128i*) c4, vout4x0123456789ABCDEF); c4 = (int8_t*) ((uintptr_t) c4 + cn_stride); + a4 = (const int8_t*) ((uintptr_t) a4 - kc); nc -= 16; } else { diff --git a/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-5x16c8-minmax-fp32-avx512vnni-prfm.c b/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-5x16c8-minmax-fp32-avx512vnni-prfm.c index 9d0bf91c58d..e4f8a094558 100644 --- a/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-5x16c8-minmax-fp32-avx512vnni-prfm.c +++ b/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-5x16c8-minmax-fp32-avx512vnni-prfm.c @@ -71,7 +71,6 @@ void xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_5x16c8__avx512vnni_prfm( const __m512 voutput_max_less_zero_point = _mm512_load_ps(params->fp32_avx512vnni.output_max_less_zero_point); const __m256i voutput_zero_point = _mm256_load_si256((const __m256i*) params->fp32_avx512vnni.output_zero_point); const __m128i voutput_min = _mm_load_si128((const __m128i*) params->fp32_avx512vnni.output_min); - const __m128i vshuffle_control_mask = _mm_loadu_si128((const __m128i*) params->fp32_avx512vnni.shuffle_control_mask); do { __m512i vacc0x01234567 = _mm512_cvtepu32_epi64(_mm256_load_si256((const __m256i*) w)); __m512i vacc0x89ABCDEF = _mm512_cvtepu32_epi64(_mm256_load_si256((const __m256i*) ((const int32_t*) w + 8))); @@ -213,29 +212,29 @@ void xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_5x16c8__avx512vnni_prfm( vacc3x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled3x0123456789ABCDEF); vacc4x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled4x0123456789ABCDEF); - __m256i vacc0x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); - __m256i vacc1x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); - __m256i vacc2x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc2x0123456789ABCDEF, 1)); - __m256i vacc3x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc3x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc3x0123456789ABCDEF, 1)); - __m256i vacc4x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc4x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc4x0123456789ABCDEF, 1)); + __m256i vacc0x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); + __m256i vacc1x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); + __m256i vacc2x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc2x0123456789ABCDEF, 1)); + __m256i vacc3x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc3x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc3x0123456789ABCDEF, 1)); + __m256i vacc4x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc4x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc4x0123456789ABCDEF, 1)); - vacc0x0123456789AB4567CDEF = _mm256_adds_epi16(vacc0x0123456789AB4567CDEF, voutput_zero_point); - vacc1x0123456789AB4567CDEF = _mm256_adds_epi16(vacc1x0123456789AB4567CDEF, voutput_zero_point); - vacc2x0123456789AB4567CDEF = _mm256_adds_epi16(vacc2x0123456789AB4567CDEF, voutput_zero_point); - vacc3x0123456789AB4567CDEF = _mm256_adds_epi16(vacc3x0123456789AB4567CDEF, voutput_zero_point); - vacc4x0123456789AB4567CDEF = _mm256_adds_epi16(vacc4x0123456789AB4567CDEF, voutput_zero_point); + vacc0x012389AB4567CDEF = _mm256_adds_epi16(vacc0x012389AB4567CDEF, voutput_zero_point); + vacc1x012389AB4567CDEF = _mm256_adds_epi16(vacc1x012389AB4567CDEF, voutput_zero_point); + vacc2x012389AB4567CDEF = _mm256_adds_epi16(vacc2x012389AB4567CDEF, voutput_zero_point); + vacc3x012389AB4567CDEF = _mm256_adds_epi16(vacc3x012389AB4567CDEF, voutput_zero_point); + vacc4x012389AB4567CDEF = _mm256_adds_epi16(vacc4x012389AB4567CDEF, voutput_zero_point); - const __m128i vout0x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc0x0123456789AB4567CDEF, 1)); - const __m128i vout1x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc1x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc1x0123456789AB4567CDEF, 1)); - const __m128i vout2x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc2x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc2x0123456789AB4567CDEF, 1)); - const __m128i vout3x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc3x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc3x0123456789AB4567CDEF, 1)); - const __m128i vout4x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc4x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc4x0123456789AB4567CDEF, 1)); + const __m128i vout0x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x012389AB4567CDEF, 1)); + const __m128i vout1x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc1x012389AB4567CDEF), _mm256_extracti128_si256(vacc1x012389AB4567CDEF, 1)); + const __m128i vout2x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc2x012389AB4567CDEF), _mm256_extracti128_si256(vacc2x012389AB4567CDEF, 1)); + const __m128i vout3x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc3x012389AB4567CDEF), _mm256_extracti128_si256(vacc3x012389AB4567CDEF, 1)); + const __m128i vout4x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc4x012389AB4567CDEF), _mm256_extracti128_si256(vacc4x012389AB4567CDEF, 1)); - __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi8(vout0x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi8(vout1x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi8(vout2x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout3x0123456789ABCDEF = _mm_shuffle_epi8(vout3x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout4x0123456789ABCDEF = _mm_shuffle_epi8(vout4x0123456789AB4567CDEF, vshuffle_control_mask); + __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi32(vout0x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi32(vout1x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi32(vout2x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout3x0123456789ABCDEF = _mm_shuffle_epi32(vout3x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout4x0123456789ABCDEF = _mm_shuffle_epi32(vout4x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); vout0x0123456789ABCDEF = _mm_max_epi8(vout0x0123456789ABCDEF, voutput_min); vout1x0123456789ABCDEF = _mm_max_epi8(vout1x0123456789ABCDEF, voutput_min); @@ -245,22 +244,20 @@ void xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_5x16c8__avx512vnni_prfm( if (nc >= 16) { _mm_storeu_si128((__m128i*) c0, vout0x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c1, vout1x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c2, vout2x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c3, vout3x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c4, vout4x0123456789ABCDEF); - - a0 = (const int8_t*) ((uintptr_t) a0 - kc); - a1 = (const int8_t*) ((uintptr_t) a1 - kc); - a2 = (const int8_t*) ((uintptr_t) a2 - kc); - a3 = (const int8_t*) ((uintptr_t) a3 - kc); - a4 = (const int8_t*) ((uintptr_t) a4 - kc); - c0 = (int8_t*) ((uintptr_t) c0 + cn_stride); + a0 = (const int8_t*) ((uintptr_t) a0 - kc); + _mm_storeu_si128((__m128i*) c1, vout1x0123456789ABCDEF); c1 = (int8_t*) ((uintptr_t) c1 + cn_stride); + a1 = (const int8_t*) ((uintptr_t) a1 - kc); + _mm_storeu_si128((__m128i*) c2, vout2x0123456789ABCDEF); c2 = (int8_t*) ((uintptr_t) c2 + cn_stride); + a2 = (const int8_t*) ((uintptr_t) a2 - kc); + _mm_storeu_si128((__m128i*) c3, vout3x0123456789ABCDEF); c3 = (int8_t*) ((uintptr_t) c3 + cn_stride); + a3 = (const int8_t*) ((uintptr_t) a3 - kc); + _mm_storeu_si128((__m128i*) c4, vout4x0123456789ABCDEF); c4 = (int8_t*) ((uintptr_t) c4 + cn_stride); + a4 = (const int8_t*) ((uintptr_t) a4 - kc); nc -= 16; } else { diff --git a/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-5x16c8-minmax-fp32-avx512vnni.c b/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-5x16c8-minmax-fp32-avx512vnni.c index a7377a54c20..e40c92c5a1d 100644 --- a/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-5x16c8-minmax-fp32-avx512vnni.c +++ b/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-5x16c8-minmax-fp32-avx512vnni.c @@ -70,7 +70,6 @@ void xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_5x16c8__avx512vnni( const __m512 voutput_max_less_zero_point = _mm512_load_ps(params->fp32_avx512vnni.output_max_less_zero_point); const __m256i voutput_zero_point = _mm256_load_si256((const __m256i*) params->fp32_avx512vnni.output_zero_point); const __m128i voutput_min = _mm_load_si128((const __m128i*) params->fp32_avx512vnni.output_min); - const __m128i vshuffle_control_mask = _mm_loadu_si128((const __m128i*) params->fp32_avx512vnni.shuffle_control_mask); do { __m512i vacc0x01234567 = _mm512_cvtepu32_epi64(_mm256_load_si256((const __m256i*) w)); __m512i vacc0x89ABCDEF = _mm512_cvtepu32_epi64(_mm256_load_si256((const __m256i*) ((const int32_t*) w + 8))); @@ -206,29 +205,29 @@ void xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_5x16c8__avx512vnni( vacc3x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled3x0123456789ABCDEF); vacc4x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled4x0123456789ABCDEF); - __m256i vacc0x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); - __m256i vacc1x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); - __m256i vacc2x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc2x0123456789ABCDEF, 1)); - __m256i vacc3x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc3x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc3x0123456789ABCDEF, 1)); - __m256i vacc4x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc4x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc4x0123456789ABCDEF, 1)); + __m256i vacc0x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); + __m256i vacc1x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); + __m256i vacc2x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc2x0123456789ABCDEF, 1)); + __m256i vacc3x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc3x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc3x0123456789ABCDEF, 1)); + __m256i vacc4x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc4x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc4x0123456789ABCDEF, 1)); - vacc0x0123456789AB4567CDEF = _mm256_adds_epi16(vacc0x0123456789AB4567CDEF, voutput_zero_point); - vacc1x0123456789AB4567CDEF = _mm256_adds_epi16(vacc1x0123456789AB4567CDEF, voutput_zero_point); - vacc2x0123456789AB4567CDEF = _mm256_adds_epi16(vacc2x0123456789AB4567CDEF, voutput_zero_point); - vacc3x0123456789AB4567CDEF = _mm256_adds_epi16(vacc3x0123456789AB4567CDEF, voutput_zero_point); - vacc4x0123456789AB4567CDEF = _mm256_adds_epi16(vacc4x0123456789AB4567CDEF, voutput_zero_point); + vacc0x012389AB4567CDEF = _mm256_adds_epi16(vacc0x012389AB4567CDEF, voutput_zero_point); + vacc1x012389AB4567CDEF = _mm256_adds_epi16(vacc1x012389AB4567CDEF, voutput_zero_point); + vacc2x012389AB4567CDEF = _mm256_adds_epi16(vacc2x012389AB4567CDEF, voutput_zero_point); + vacc3x012389AB4567CDEF = _mm256_adds_epi16(vacc3x012389AB4567CDEF, voutput_zero_point); + vacc4x012389AB4567CDEF = _mm256_adds_epi16(vacc4x012389AB4567CDEF, voutput_zero_point); - const __m128i vout0x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc0x0123456789AB4567CDEF, 1)); - const __m128i vout1x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc1x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc1x0123456789AB4567CDEF, 1)); - const __m128i vout2x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc2x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc2x0123456789AB4567CDEF, 1)); - const __m128i vout3x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc3x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc3x0123456789AB4567CDEF, 1)); - const __m128i vout4x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc4x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc4x0123456789AB4567CDEF, 1)); + const __m128i vout0x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x012389AB4567CDEF, 1)); + const __m128i vout1x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc1x012389AB4567CDEF), _mm256_extracti128_si256(vacc1x012389AB4567CDEF, 1)); + const __m128i vout2x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc2x012389AB4567CDEF), _mm256_extracti128_si256(vacc2x012389AB4567CDEF, 1)); + const __m128i vout3x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc3x012389AB4567CDEF), _mm256_extracti128_si256(vacc3x012389AB4567CDEF, 1)); + const __m128i vout4x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc4x012389AB4567CDEF), _mm256_extracti128_si256(vacc4x012389AB4567CDEF, 1)); - __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi8(vout0x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi8(vout1x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi8(vout2x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout3x0123456789ABCDEF = _mm_shuffle_epi8(vout3x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout4x0123456789ABCDEF = _mm_shuffle_epi8(vout4x0123456789AB4567CDEF, vshuffle_control_mask); + __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi32(vout0x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi32(vout1x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi32(vout2x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout3x0123456789ABCDEF = _mm_shuffle_epi32(vout3x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout4x0123456789ABCDEF = _mm_shuffle_epi32(vout4x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); vout0x0123456789ABCDEF = _mm_max_epi8(vout0x0123456789ABCDEF, voutput_min); vout1x0123456789ABCDEF = _mm_max_epi8(vout1x0123456789ABCDEF, voutput_min); @@ -238,22 +237,20 @@ void xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_5x16c8__avx512vnni( if (nc >= 16) { _mm_storeu_si128((__m128i*) c0, vout0x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c1, vout1x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c2, vout2x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c3, vout3x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c4, vout4x0123456789ABCDEF); - - a0 = (const int8_t*) ((uintptr_t) a0 - kc); - a1 = (const int8_t*) ((uintptr_t) a1 - kc); - a2 = (const int8_t*) ((uintptr_t) a2 - kc); - a3 = (const int8_t*) ((uintptr_t) a3 - kc); - a4 = (const int8_t*) ((uintptr_t) a4 - kc); - c0 = (int8_t*) ((uintptr_t) c0 + cn_stride); + a0 = (const int8_t*) ((uintptr_t) a0 - kc); + _mm_storeu_si128((__m128i*) c1, vout1x0123456789ABCDEF); c1 = (int8_t*) ((uintptr_t) c1 + cn_stride); + a1 = (const int8_t*) ((uintptr_t) a1 - kc); + _mm_storeu_si128((__m128i*) c2, vout2x0123456789ABCDEF); c2 = (int8_t*) ((uintptr_t) c2 + cn_stride); + a2 = (const int8_t*) ((uintptr_t) a2 - kc); + _mm_storeu_si128((__m128i*) c3, vout3x0123456789ABCDEF); c3 = (int8_t*) ((uintptr_t) c3 + cn_stride); + a3 = (const int8_t*) ((uintptr_t) a3 - kc); + _mm_storeu_si128((__m128i*) c4, vout4x0123456789ABCDEF); c4 = (int8_t*) ((uintptr_t) c4 + cn_stride); + a4 = (const int8_t*) ((uintptr_t) a4 - kc); nc -= 16; } else { diff --git a/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-6x16c4-minmax-fp32-avx512vnni-prfm.c b/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-6x16c4-minmax-fp32-avx512vnni-prfm.c index 78f6ae37285..5b9e40f4aca 100644 --- a/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-6x16c4-minmax-fp32-avx512vnni-prfm.c +++ b/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-6x16c4-minmax-fp32-avx512vnni-prfm.c @@ -77,7 +77,6 @@ void xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_6x16c4__avx512vnni_prfm( const __m512 voutput_max_less_zero_point = _mm512_load_ps(params->fp32_avx512vnni.output_max_less_zero_point); const __m256i voutput_zero_point = _mm256_load_si256((const __m256i*) params->fp32_avx512vnni.output_zero_point); const __m128i voutput_min = _mm_load_si128((const __m128i*) params->fp32_avx512vnni.output_min); - const __m128i vshuffle_control_mask = _mm_loadu_si128((const __m128i*) params->fp32_avx512vnni.shuffle_control_mask); do { __m512i vacc0x0x0123456789ABCDEF = _mm512_load_epi32(w); __m512i vacc1x0x0123456789ABCDEF = _mm512_setzero_epi32(); @@ -135,12 +134,12 @@ void xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_6x16c4__avx512vnni_prfm( w = (const int8_t*) w + 128; k -= 8 * sizeof(int8_t); } - vacc0x0x0123456789ABCDEF = _mm512_add_epi32(vacc0x0x0123456789ABCDEF, vacc1x0x0123456789ABCDEF); - vacc0x1x0123456789ABCDEF = _mm512_add_epi32(vacc0x1x0123456789ABCDEF, vacc1x1x0123456789ABCDEF); - vacc0x2x0123456789ABCDEF = _mm512_add_epi32(vacc0x2x0123456789ABCDEF, vacc1x2x0123456789ABCDEF); - vacc0x3x0123456789ABCDEF = _mm512_add_epi32(vacc0x3x0123456789ABCDEF, vacc1x3x0123456789ABCDEF); - vacc0x4x0123456789ABCDEF = _mm512_add_epi32(vacc0x4x0123456789ABCDEF, vacc1x4x0123456789ABCDEF); - vacc0x5x0123456789ABCDEF = _mm512_add_epi32(vacc0x5x0123456789ABCDEF, vacc1x5x0123456789ABCDEF); + __m512i vacc0x0123456789ABCDEF = _mm512_add_epi32(vacc0x0x0123456789ABCDEF, vacc1x0x0123456789ABCDEF); + __m512i vacc1x0123456789ABCDEF = _mm512_add_epi32(vacc0x1x0123456789ABCDEF, vacc1x1x0123456789ABCDEF); + __m512i vacc2x0123456789ABCDEF = _mm512_add_epi32(vacc0x2x0123456789ABCDEF, vacc1x2x0123456789ABCDEF); + __m512i vacc3x0123456789ABCDEF = _mm512_add_epi32(vacc0x3x0123456789ABCDEF, vacc1x3x0123456789ABCDEF); + __m512i vacc4x0123456789ABCDEF = _mm512_add_epi32(vacc0x4x0123456789ABCDEF, vacc1x4x0123456789ABCDEF); + __m512i vacc5x0123456789ABCDEF = _mm512_add_epi32(vacc0x5x0123456789ABCDEF, vacc1x5x0123456789ABCDEF); if (k != 0) { const __m512i va0x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a0)), vsign_mask); @@ -159,23 +158,23 @@ void xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_6x16c4__avx512vnni_prfm( const __m512i vb0123456789ABCDEF = _mm512_load_si512(w); xnn_prefetch_to_l1((const int8_t*) w + 960); - vacc0x0x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x0x0123456789ABCDEF, va0x0123, vb0123456789ABCDEF); - vacc0x1x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x1x0123456789ABCDEF, va1x0123, vb0123456789ABCDEF); - vacc0x2x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x2x0123456789ABCDEF, va2x0123, vb0123456789ABCDEF); - vacc0x3x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x3x0123456789ABCDEF, va3x0123, vb0123456789ABCDEF); - vacc0x4x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x4x0123456789ABCDEF, va4x0123, vb0123456789ABCDEF); - vacc0x5x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x5x0123456789ABCDEF, va5x0123, vb0123456789ABCDEF); + vacc0x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x0123456789ABCDEF, va0x0123, vb0123456789ABCDEF); + vacc1x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc1x0123456789ABCDEF, va1x0123, vb0123456789ABCDEF); + vacc2x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc2x0123456789ABCDEF, va2x0123, vb0123456789ABCDEF); + vacc3x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc3x0123456789ABCDEF, va3x0123, vb0123456789ABCDEF); + vacc4x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc4x0123456789ABCDEF, va4x0123, vb0123456789ABCDEF); + vacc5x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc5x0123456789ABCDEF, va5x0123, vb0123456789ABCDEF); w = (const int8_t*) w + 64; k -= 4 * sizeof(int8_t); } - __m512 vscaled0x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x0x0123456789ABCDEF); - __m512 vscaled1x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x1x0123456789ABCDEF); - __m512 vscaled2x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x2x0123456789ABCDEF); - __m512 vscaled3x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x3x0123456789ABCDEF); - __m512 vscaled4x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x4x0123456789ABCDEF); - __m512 vscaled5x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x5x0123456789ABCDEF); + __m512 vscaled0x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x0123456789ABCDEF); + __m512 vscaled1x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc1x0123456789ABCDEF); + __m512 vscaled2x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc2x0123456789ABCDEF); + __m512 vscaled3x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc3x0123456789ABCDEF); + __m512 vscaled4x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc4x0123456789ABCDEF); + __m512 vscaled5x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc5x0123456789ABCDEF); const __m512 vscale012345678ABCDEF = _mm512_load_ps(w); w = (const float*) w + 16; @@ -193,40 +192,40 @@ void xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_6x16c4__avx512vnni_prfm( vscaled4x0123456789ABCDEF = _mm512_min_ps(vscaled4x0123456789ABCDEF, voutput_max_less_zero_point); vscaled5x0123456789ABCDEF = _mm512_min_ps(vscaled5x0123456789ABCDEF, voutput_max_less_zero_point); - vacc0x0x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled0x0123456789ABCDEF); - vacc0x1x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled1x0123456789ABCDEF); - vacc0x2x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled2x0123456789ABCDEF); - vacc0x3x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled3x0123456789ABCDEF); - vacc0x4x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled4x0123456789ABCDEF); - vacc0x5x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled5x0123456789ABCDEF); + vacc0x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled0x0123456789ABCDEF); + vacc1x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled1x0123456789ABCDEF); + vacc2x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled2x0123456789ABCDEF); + vacc3x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled3x0123456789ABCDEF); + vacc4x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled4x0123456789ABCDEF); + vacc5x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled5x0123456789ABCDEF); - __m256i vacc0x0x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0x0123456789ABCDEF, 1)); - __m256i vacc0x1x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x1x0123456789ABCDEF, 1)); - __m256i vacc0x2x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x2x0123456789ABCDEF, 1)); - __m256i vacc0x3x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x3x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x3x0123456789ABCDEF, 1)); - __m256i vacc0x4x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x4x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x4x0123456789ABCDEF, 1)); - __m256i vacc0x5x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x5x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x5x0123456789ABCDEF, 1)); + __m256i vacc0x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); + __m256i vacc1x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); + __m256i vacc2x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc2x0123456789ABCDEF, 1)); + __m256i vacc3x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc3x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc3x0123456789ABCDEF, 1)); + __m256i vacc4x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc4x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc4x0123456789ABCDEF, 1)); + __m256i vacc5x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc5x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc5x0123456789ABCDEF, 1)); - vacc0x0x012389AB4567CDEF = _mm256_adds_epi16(vacc0x0x012389AB4567CDEF, voutput_zero_point); - vacc0x1x012389AB4567CDEF = _mm256_adds_epi16(vacc0x1x012389AB4567CDEF, voutput_zero_point); - vacc0x2x012389AB4567CDEF = _mm256_adds_epi16(vacc0x2x012389AB4567CDEF, voutput_zero_point); - vacc0x3x012389AB4567CDEF = _mm256_adds_epi16(vacc0x3x012389AB4567CDEF, voutput_zero_point); - vacc0x4x012389AB4567CDEF = _mm256_adds_epi16(vacc0x4x012389AB4567CDEF, voutput_zero_point); - vacc0x5x012389AB4567CDEF = _mm256_adds_epi16(vacc0x5x012389AB4567CDEF, voutput_zero_point); + vacc0x012389AB4567CDEF = _mm256_adds_epi16(vacc0x012389AB4567CDEF, voutput_zero_point); + vacc1x012389AB4567CDEF = _mm256_adds_epi16(vacc1x012389AB4567CDEF, voutput_zero_point); + vacc2x012389AB4567CDEF = _mm256_adds_epi16(vacc2x012389AB4567CDEF, voutput_zero_point); + vacc3x012389AB4567CDEF = _mm256_adds_epi16(vacc3x012389AB4567CDEF, voutput_zero_point); + vacc4x012389AB4567CDEF = _mm256_adds_epi16(vacc4x012389AB4567CDEF, voutput_zero_point); + vacc5x012389AB4567CDEF = _mm256_adds_epi16(vacc5x012389AB4567CDEF, voutput_zero_point); - const __m128i vout0x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x0x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x0x012389AB4567CDEF, 1)); - const __m128i vout1x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x1x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x1x012389AB4567CDEF, 1)); - const __m128i vout2x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x2x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x2x012389AB4567CDEF, 1)); - const __m128i vout3x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x3x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x3x012389AB4567CDEF, 1)); - const __m128i vout4x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x4x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x4x012389AB4567CDEF, 1)); - const __m128i vout5x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x5x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x5x012389AB4567CDEF, 1)); + const __m128i vout0x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x012389AB4567CDEF, 1)); + const __m128i vout1x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc1x012389AB4567CDEF), _mm256_extracti128_si256(vacc1x012389AB4567CDEF, 1)); + const __m128i vout2x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc2x012389AB4567CDEF), _mm256_extracti128_si256(vacc2x012389AB4567CDEF, 1)); + const __m128i vout3x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc3x012389AB4567CDEF), _mm256_extracti128_si256(vacc3x012389AB4567CDEF, 1)); + const __m128i vout4x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc4x012389AB4567CDEF), _mm256_extracti128_si256(vacc4x012389AB4567CDEF, 1)); + const __m128i vout5x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc5x012389AB4567CDEF), _mm256_extracti128_si256(vacc5x012389AB4567CDEF, 1)); - __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi8(vout0x012389AB4567CDEF, vshuffle_control_mask); - __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi8(vout1x012389AB4567CDEF, vshuffle_control_mask); - __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi8(vout2x012389AB4567CDEF, vshuffle_control_mask); - __m128i vout3x0123456789ABCDEF = _mm_shuffle_epi8(vout3x012389AB4567CDEF, vshuffle_control_mask); - __m128i vout4x0123456789ABCDEF = _mm_shuffle_epi8(vout4x012389AB4567CDEF, vshuffle_control_mask); - __m128i vout5x0123456789ABCDEF = _mm_shuffle_epi8(vout5x012389AB4567CDEF, vshuffle_control_mask); + __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi32(vout0x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi32(vout1x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi32(vout2x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout3x0123456789ABCDEF = _mm_shuffle_epi32(vout3x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout4x0123456789ABCDEF = _mm_shuffle_epi32(vout4x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout5x0123456789ABCDEF = _mm_shuffle_epi32(vout5x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); vout0x0123456789ABCDEF = _mm_max_epi8(vout0x0123456789ABCDEF, voutput_min); vout1x0123456789ABCDEF = _mm_max_epi8(vout1x0123456789ABCDEF, voutput_min); @@ -237,25 +236,23 @@ void xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_6x16c4__avx512vnni_prfm( if (nc >= 16) { _mm_storeu_si128((__m128i*) c0, vout0x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c1, vout1x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c2, vout2x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c3, vout3x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c4, vout4x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c5, vout5x0123456789ABCDEF); - - a0 = (const int8_t*) ((uintptr_t) a0 - kc); - a1 = (const int8_t*) ((uintptr_t) a1 - kc); - a2 = (const int8_t*) ((uintptr_t) a2 - kc); - a3 = (const int8_t*) ((uintptr_t) a3 - kc); - a4 = (const int8_t*) ((uintptr_t) a4 - kc); - a5 = (const int8_t*) ((uintptr_t) a5 - kc); - c0 = (int8_t*) ((uintptr_t) c0 + cn_stride); + a0 = (const int8_t*) ((uintptr_t) a0 - kc); + _mm_storeu_si128((__m128i*) c1, vout1x0123456789ABCDEF); c1 = (int8_t*) ((uintptr_t) c1 + cn_stride); + a1 = (const int8_t*) ((uintptr_t) a1 - kc); + _mm_storeu_si128((__m128i*) c2, vout2x0123456789ABCDEF); c2 = (int8_t*) ((uintptr_t) c2 + cn_stride); + a2 = (const int8_t*) ((uintptr_t) a2 - kc); + _mm_storeu_si128((__m128i*) c3, vout3x0123456789ABCDEF); c3 = (int8_t*) ((uintptr_t) c3 + cn_stride); + a3 = (const int8_t*) ((uintptr_t) a3 - kc); + _mm_storeu_si128((__m128i*) c4, vout4x0123456789ABCDEF); c4 = (int8_t*) ((uintptr_t) c4 + cn_stride); + a4 = (const int8_t*) ((uintptr_t) a4 - kc); + _mm_storeu_si128((__m128i*) c5, vout5x0123456789ABCDEF); c5 = (int8_t*) ((uintptr_t) c5 + cn_stride); + a5 = (const int8_t*) ((uintptr_t) a5 - kc); nc -= 16; } else { diff --git a/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-6x16c4-minmax-fp32-avx512vnni.c b/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-6x16c4-minmax-fp32-avx512vnni.c index c9bb47c3265..be34f4f7a53 100644 --- a/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-6x16c4-minmax-fp32-avx512vnni.c +++ b/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-6x16c4-minmax-fp32-avx512vnni.c @@ -76,7 +76,6 @@ void xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_6x16c4__avx512vnni( const __m512 voutput_max_less_zero_point = _mm512_load_ps(params->fp32_avx512vnni.output_max_less_zero_point); const __m256i voutput_zero_point = _mm256_load_si256((const __m256i*) params->fp32_avx512vnni.output_zero_point); const __m128i voutput_min = _mm_load_si128((const __m128i*) params->fp32_avx512vnni.output_min); - const __m128i vshuffle_control_mask = _mm_loadu_si128((const __m128i*) params->fp32_avx512vnni.shuffle_control_mask); do { __m512i vacc0x0x0123456789ABCDEF = _mm512_load_epi32(w); __m512i vacc1x0x0123456789ABCDEF = _mm512_setzero_epi32(); @@ -132,12 +131,12 @@ void xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_6x16c4__avx512vnni( w = (const int8_t*) w + 128; k -= 8 * sizeof(int8_t); } - vacc0x0x0123456789ABCDEF = _mm512_add_epi32(vacc0x0x0123456789ABCDEF, vacc1x0x0123456789ABCDEF); - vacc0x1x0123456789ABCDEF = _mm512_add_epi32(vacc0x1x0123456789ABCDEF, vacc1x1x0123456789ABCDEF); - vacc0x2x0123456789ABCDEF = _mm512_add_epi32(vacc0x2x0123456789ABCDEF, vacc1x2x0123456789ABCDEF); - vacc0x3x0123456789ABCDEF = _mm512_add_epi32(vacc0x3x0123456789ABCDEF, vacc1x3x0123456789ABCDEF); - vacc0x4x0123456789ABCDEF = _mm512_add_epi32(vacc0x4x0123456789ABCDEF, vacc1x4x0123456789ABCDEF); - vacc0x5x0123456789ABCDEF = _mm512_add_epi32(vacc0x5x0123456789ABCDEF, vacc1x5x0123456789ABCDEF); + __m512i vacc0x0123456789ABCDEF = _mm512_add_epi32(vacc0x0x0123456789ABCDEF, vacc1x0x0123456789ABCDEF); + __m512i vacc1x0123456789ABCDEF = _mm512_add_epi32(vacc0x1x0123456789ABCDEF, vacc1x1x0123456789ABCDEF); + __m512i vacc2x0123456789ABCDEF = _mm512_add_epi32(vacc0x2x0123456789ABCDEF, vacc1x2x0123456789ABCDEF); + __m512i vacc3x0123456789ABCDEF = _mm512_add_epi32(vacc0x3x0123456789ABCDEF, vacc1x3x0123456789ABCDEF); + __m512i vacc4x0123456789ABCDEF = _mm512_add_epi32(vacc0x4x0123456789ABCDEF, vacc1x4x0123456789ABCDEF); + __m512i vacc5x0123456789ABCDEF = _mm512_add_epi32(vacc0x5x0123456789ABCDEF, vacc1x5x0123456789ABCDEF); if (k != 0) { const __m512i va0x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a0)), vsign_mask); @@ -155,23 +154,23 @@ void xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_6x16c4__avx512vnni( const __m512i vb0123456789ABCDEF = _mm512_load_si512(w); - vacc0x0x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x0x0123456789ABCDEF, va0x0123, vb0123456789ABCDEF); - vacc0x1x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x1x0123456789ABCDEF, va1x0123, vb0123456789ABCDEF); - vacc0x2x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x2x0123456789ABCDEF, va2x0123, vb0123456789ABCDEF); - vacc0x3x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x3x0123456789ABCDEF, va3x0123, vb0123456789ABCDEF); - vacc0x4x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x4x0123456789ABCDEF, va4x0123, vb0123456789ABCDEF); - vacc0x5x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x5x0123456789ABCDEF, va5x0123, vb0123456789ABCDEF); + vacc0x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x0123456789ABCDEF, va0x0123, vb0123456789ABCDEF); + vacc1x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc1x0123456789ABCDEF, va1x0123, vb0123456789ABCDEF); + vacc2x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc2x0123456789ABCDEF, va2x0123, vb0123456789ABCDEF); + vacc3x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc3x0123456789ABCDEF, va3x0123, vb0123456789ABCDEF); + vacc4x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc4x0123456789ABCDEF, va4x0123, vb0123456789ABCDEF); + vacc5x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc5x0123456789ABCDEF, va5x0123, vb0123456789ABCDEF); w = (const int8_t*) w + 64; k -= 4 * sizeof(int8_t); } - __m512 vscaled0x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x0x0123456789ABCDEF); - __m512 vscaled1x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x1x0123456789ABCDEF); - __m512 vscaled2x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x2x0123456789ABCDEF); - __m512 vscaled3x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x3x0123456789ABCDEF); - __m512 vscaled4x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x4x0123456789ABCDEF); - __m512 vscaled5x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x5x0123456789ABCDEF); + __m512 vscaled0x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x0123456789ABCDEF); + __m512 vscaled1x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc1x0123456789ABCDEF); + __m512 vscaled2x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc2x0123456789ABCDEF); + __m512 vscaled3x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc3x0123456789ABCDEF); + __m512 vscaled4x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc4x0123456789ABCDEF); + __m512 vscaled5x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc5x0123456789ABCDEF); const __m512 vscale012345678ABCDEF = _mm512_load_ps(w); w = (const float*) w + 16; @@ -189,40 +188,40 @@ void xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_6x16c4__avx512vnni( vscaled4x0123456789ABCDEF = _mm512_min_ps(vscaled4x0123456789ABCDEF, voutput_max_less_zero_point); vscaled5x0123456789ABCDEF = _mm512_min_ps(vscaled5x0123456789ABCDEF, voutput_max_less_zero_point); - vacc0x0x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled0x0123456789ABCDEF); - vacc0x1x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled1x0123456789ABCDEF); - vacc0x2x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled2x0123456789ABCDEF); - vacc0x3x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled3x0123456789ABCDEF); - vacc0x4x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled4x0123456789ABCDEF); - vacc0x5x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled5x0123456789ABCDEF); + vacc0x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled0x0123456789ABCDEF); + vacc1x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled1x0123456789ABCDEF); + vacc2x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled2x0123456789ABCDEF); + vacc3x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled3x0123456789ABCDEF); + vacc4x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled4x0123456789ABCDEF); + vacc5x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled5x0123456789ABCDEF); - __m256i vacc0x0x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0x0123456789ABCDEF, 1)); - __m256i vacc0x1x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x1x0123456789ABCDEF, 1)); - __m256i vacc0x2x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x2x0123456789ABCDEF, 1)); - __m256i vacc0x3x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x3x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x3x0123456789ABCDEF, 1)); - __m256i vacc0x4x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x4x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x4x0123456789ABCDEF, 1)); - __m256i vacc0x5x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x5x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x5x0123456789ABCDEF, 1)); + __m256i vacc0x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); + __m256i vacc1x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); + __m256i vacc2x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc2x0123456789ABCDEF, 1)); + __m256i vacc3x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc3x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc3x0123456789ABCDEF, 1)); + __m256i vacc4x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc4x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc4x0123456789ABCDEF, 1)); + __m256i vacc5x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc5x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc5x0123456789ABCDEF, 1)); - vacc0x0x012389AB4567CDEF = _mm256_adds_epi16(vacc0x0x012389AB4567CDEF, voutput_zero_point); - vacc0x1x012389AB4567CDEF = _mm256_adds_epi16(vacc0x1x012389AB4567CDEF, voutput_zero_point); - vacc0x2x012389AB4567CDEF = _mm256_adds_epi16(vacc0x2x012389AB4567CDEF, voutput_zero_point); - vacc0x3x012389AB4567CDEF = _mm256_adds_epi16(vacc0x3x012389AB4567CDEF, voutput_zero_point); - vacc0x4x012389AB4567CDEF = _mm256_adds_epi16(vacc0x4x012389AB4567CDEF, voutput_zero_point); - vacc0x5x012389AB4567CDEF = _mm256_adds_epi16(vacc0x5x012389AB4567CDEF, voutput_zero_point); + vacc0x012389AB4567CDEF = _mm256_adds_epi16(vacc0x012389AB4567CDEF, voutput_zero_point); + vacc1x012389AB4567CDEF = _mm256_adds_epi16(vacc1x012389AB4567CDEF, voutput_zero_point); + vacc2x012389AB4567CDEF = _mm256_adds_epi16(vacc2x012389AB4567CDEF, voutput_zero_point); + vacc3x012389AB4567CDEF = _mm256_adds_epi16(vacc3x012389AB4567CDEF, voutput_zero_point); + vacc4x012389AB4567CDEF = _mm256_adds_epi16(vacc4x012389AB4567CDEF, voutput_zero_point); + vacc5x012389AB4567CDEF = _mm256_adds_epi16(vacc5x012389AB4567CDEF, voutput_zero_point); - const __m128i vout0x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x0x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x0x012389AB4567CDEF, 1)); - const __m128i vout1x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x1x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x1x012389AB4567CDEF, 1)); - const __m128i vout2x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x2x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x2x012389AB4567CDEF, 1)); - const __m128i vout3x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x3x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x3x012389AB4567CDEF, 1)); - const __m128i vout4x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x4x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x4x012389AB4567CDEF, 1)); - const __m128i vout5x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x5x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x5x012389AB4567CDEF, 1)); + const __m128i vout0x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x012389AB4567CDEF, 1)); + const __m128i vout1x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc1x012389AB4567CDEF), _mm256_extracti128_si256(vacc1x012389AB4567CDEF, 1)); + const __m128i vout2x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc2x012389AB4567CDEF), _mm256_extracti128_si256(vacc2x012389AB4567CDEF, 1)); + const __m128i vout3x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc3x012389AB4567CDEF), _mm256_extracti128_si256(vacc3x012389AB4567CDEF, 1)); + const __m128i vout4x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc4x012389AB4567CDEF), _mm256_extracti128_si256(vacc4x012389AB4567CDEF, 1)); + const __m128i vout5x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc5x012389AB4567CDEF), _mm256_extracti128_si256(vacc5x012389AB4567CDEF, 1)); - __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi8(vout0x012389AB4567CDEF, vshuffle_control_mask); - __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi8(vout1x012389AB4567CDEF, vshuffle_control_mask); - __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi8(vout2x012389AB4567CDEF, vshuffle_control_mask); - __m128i vout3x0123456789ABCDEF = _mm_shuffle_epi8(vout3x012389AB4567CDEF, vshuffle_control_mask); - __m128i vout4x0123456789ABCDEF = _mm_shuffle_epi8(vout4x012389AB4567CDEF, vshuffle_control_mask); - __m128i vout5x0123456789ABCDEF = _mm_shuffle_epi8(vout5x012389AB4567CDEF, vshuffle_control_mask); + __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi32(vout0x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi32(vout1x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi32(vout2x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout3x0123456789ABCDEF = _mm_shuffle_epi32(vout3x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout4x0123456789ABCDEF = _mm_shuffle_epi32(vout4x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout5x0123456789ABCDEF = _mm_shuffle_epi32(vout5x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); vout0x0123456789ABCDEF = _mm_max_epi8(vout0x0123456789ABCDEF, voutput_min); vout1x0123456789ABCDEF = _mm_max_epi8(vout1x0123456789ABCDEF, voutput_min); @@ -233,25 +232,23 @@ void xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_6x16c4__avx512vnni( if (nc >= 16) { _mm_storeu_si128((__m128i*) c0, vout0x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c1, vout1x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c2, vout2x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c3, vout3x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c4, vout4x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c5, vout5x0123456789ABCDEF); - - a0 = (const int8_t*) ((uintptr_t) a0 - kc); - a1 = (const int8_t*) ((uintptr_t) a1 - kc); - a2 = (const int8_t*) ((uintptr_t) a2 - kc); - a3 = (const int8_t*) ((uintptr_t) a3 - kc); - a4 = (const int8_t*) ((uintptr_t) a4 - kc); - a5 = (const int8_t*) ((uintptr_t) a5 - kc); - c0 = (int8_t*) ((uintptr_t) c0 + cn_stride); + a0 = (const int8_t*) ((uintptr_t) a0 - kc); + _mm_storeu_si128((__m128i*) c1, vout1x0123456789ABCDEF); c1 = (int8_t*) ((uintptr_t) c1 + cn_stride); + a1 = (const int8_t*) ((uintptr_t) a1 - kc); + _mm_storeu_si128((__m128i*) c2, vout2x0123456789ABCDEF); c2 = (int8_t*) ((uintptr_t) c2 + cn_stride); + a2 = (const int8_t*) ((uintptr_t) a2 - kc); + _mm_storeu_si128((__m128i*) c3, vout3x0123456789ABCDEF); c3 = (int8_t*) ((uintptr_t) c3 + cn_stride); + a3 = (const int8_t*) ((uintptr_t) a3 - kc); + _mm_storeu_si128((__m128i*) c4, vout4x0123456789ABCDEF); c4 = (int8_t*) ((uintptr_t) c4 + cn_stride); + a4 = (const int8_t*) ((uintptr_t) a4 - kc); + _mm_storeu_si128((__m128i*) c5, vout5x0123456789ABCDEF); c5 = (int8_t*) ((uintptr_t) c5 + cn_stride); + a5 = (const int8_t*) ((uintptr_t) a5 - kc); nc -= 16; } else { diff --git a/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-6x16c8-minmax-fp32-avx512skx-prfm.c b/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-6x16c8-minmax-fp32-avx512skx-prfm.c index 06c980966db..a77cbcbba5b 100644 --- a/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-6x16c8-minmax-fp32-avx512skx-prfm.c +++ b/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-6x16c8-minmax-fp32-avx512skx-prfm.c @@ -75,7 +75,6 @@ void xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_6x16c8__avx512skx_prfm( const __m512 voutput_max_less_zero_point = _mm512_load_ps(params->fp32_avx512.output_max_less_zero_point); const __m256i voutput_zero_point = _mm256_load_si256((const __m256i*) params->fp32_avx512.output_zero_point); const __m128i voutput_min = _mm_load_si128((const __m128i*) params->fp32_avx512.output_min); - const __m128i vshuffle_control_mask = _mm_set_epi8(15, 14, 13, 12, 7, 6, 5, 4, 11, 10, 9, 8, 3, 2, 1, 0); do { __m512i vacc0x0123 = _mm512_maskz_expandloadu_epi32(vbias_mask, w); @@ -218,33 +217,33 @@ void xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_6x16c8__avx512skx_prfm( vacc4x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled4x0123456789ABCDEF); vacc5x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled5x0123456789ABCDEF); - __m256i vacc0x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); - __m256i vacc1x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); - __m256i vacc2x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc2x0123456789ABCDEF, 1)); - __m256i vacc3x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc3x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc3x0123456789ABCDEF, 1)); - __m256i vacc4x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc4x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc4x0123456789ABCDEF, 1)); - __m256i vacc5x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc5x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc5x0123456789ABCDEF, 1)); - - vacc0x0123456789AB4567CDEF = _mm256_adds_epi16(vacc0x0123456789AB4567CDEF, voutput_zero_point); - vacc1x0123456789AB4567CDEF = _mm256_adds_epi16(vacc1x0123456789AB4567CDEF, voutput_zero_point); - vacc2x0123456789AB4567CDEF = _mm256_adds_epi16(vacc2x0123456789AB4567CDEF, voutput_zero_point); - vacc3x0123456789AB4567CDEF = _mm256_adds_epi16(vacc3x0123456789AB4567CDEF, voutput_zero_point); - vacc4x0123456789AB4567CDEF = _mm256_adds_epi16(vacc4x0123456789AB4567CDEF, voutput_zero_point); - vacc5x0123456789AB4567CDEF = _mm256_adds_epi16(vacc5x0123456789AB4567CDEF, voutput_zero_point); - - const __m128i vout0x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc0x0123456789AB4567CDEF, 1)); - const __m128i vout1x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc1x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc1x0123456789AB4567CDEF, 1)); - const __m128i vout2x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc2x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc2x0123456789AB4567CDEF, 1)); - const __m128i vout3x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc3x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc3x0123456789AB4567CDEF, 1)); - const __m128i vout4x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc4x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc4x0123456789AB4567CDEF, 1)); - const __m128i vout5x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc5x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc5x0123456789AB4567CDEF, 1)); - - __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi8(vout0x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi8(vout1x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi8(vout2x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout3x0123456789ABCDEF = _mm_shuffle_epi8(vout3x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout4x0123456789ABCDEF = _mm_shuffle_epi8(vout4x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout5x0123456789ABCDEF = _mm_shuffle_epi8(vout5x0123456789AB4567CDEF, vshuffle_control_mask); + __m256i vacc0x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); + __m256i vacc1x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); + __m256i vacc2x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc2x0123456789ABCDEF, 1)); + __m256i vacc3x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc3x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc3x0123456789ABCDEF, 1)); + __m256i vacc4x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc4x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc4x0123456789ABCDEF, 1)); + __m256i vacc5x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc5x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc5x0123456789ABCDEF, 1)); + + vacc0x012389AB4567CDEF = _mm256_adds_epi16(vacc0x012389AB4567CDEF, voutput_zero_point); + vacc1x012389AB4567CDEF = _mm256_adds_epi16(vacc1x012389AB4567CDEF, voutput_zero_point); + vacc2x012389AB4567CDEF = _mm256_adds_epi16(vacc2x012389AB4567CDEF, voutput_zero_point); + vacc3x012389AB4567CDEF = _mm256_adds_epi16(vacc3x012389AB4567CDEF, voutput_zero_point); + vacc4x012389AB4567CDEF = _mm256_adds_epi16(vacc4x012389AB4567CDEF, voutput_zero_point); + vacc5x012389AB4567CDEF = _mm256_adds_epi16(vacc5x012389AB4567CDEF, voutput_zero_point); + + const __m128i vout0x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x012389AB4567CDEF, 1)); + const __m128i vout1x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc1x012389AB4567CDEF), _mm256_extracti128_si256(vacc1x012389AB4567CDEF, 1)); + const __m128i vout2x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc2x012389AB4567CDEF), _mm256_extracti128_si256(vacc2x012389AB4567CDEF, 1)); + const __m128i vout3x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc3x012389AB4567CDEF), _mm256_extracti128_si256(vacc3x012389AB4567CDEF, 1)); + const __m128i vout4x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc4x012389AB4567CDEF), _mm256_extracti128_si256(vacc4x012389AB4567CDEF, 1)); + const __m128i vout5x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc5x012389AB4567CDEF), _mm256_extracti128_si256(vacc5x012389AB4567CDEF, 1)); + + __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi32(vout0x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi32(vout1x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi32(vout2x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout3x0123456789ABCDEF = _mm_shuffle_epi32(vout3x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout4x0123456789ABCDEF = _mm_shuffle_epi32(vout4x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout5x0123456789ABCDEF = _mm_shuffle_epi32(vout5x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); vout0x0123456789ABCDEF = _mm_max_epi8(vout0x0123456789ABCDEF, voutput_min); vout1x0123456789ABCDEF = _mm_max_epi8(vout1x0123456789ABCDEF, voutput_min); @@ -255,25 +254,23 @@ void xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_6x16c8__avx512skx_prfm( if (nc >= 16) { _mm_storeu_si128((__m128i*) c0, vout0x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c1, vout1x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c2, vout2x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c3, vout3x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c4, vout4x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c5, vout5x0123456789ABCDEF); - - a0 = (const int8_t*) ((uintptr_t) a0 - kc); - a1 = (const int8_t*) ((uintptr_t) a1 - kc); - a2 = (const int8_t*) ((uintptr_t) a2 - kc); - a3 = (const int8_t*) ((uintptr_t) a3 - kc); - a4 = (const int8_t*) ((uintptr_t) a4 - kc); - a5 = (const int8_t*) ((uintptr_t) a5 - kc); - c0 = (int8_t*) ((uintptr_t) c0 + cn_stride); + a0 = (const int8_t*) ((uintptr_t) a0 - kc); + _mm_storeu_si128((__m128i*) c1, vout1x0123456789ABCDEF); c1 = (int8_t*) ((uintptr_t) c1 + cn_stride); + a1 = (const int8_t*) ((uintptr_t) a1 - kc); + _mm_storeu_si128((__m128i*) c2, vout2x0123456789ABCDEF); c2 = (int8_t*) ((uintptr_t) c2 + cn_stride); + a2 = (const int8_t*) ((uintptr_t) a2 - kc); + _mm_storeu_si128((__m128i*) c3, vout3x0123456789ABCDEF); c3 = (int8_t*) ((uintptr_t) c3 + cn_stride); + a3 = (const int8_t*) ((uintptr_t) a3 - kc); + _mm_storeu_si128((__m128i*) c4, vout4x0123456789ABCDEF); c4 = (int8_t*) ((uintptr_t) c4 + cn_stride); + a4 = (const int8_t*) ((uintptr_t) a4 - kc); + _mm_storeu_si128((__m128i*) c5, vout5x0123456789ABCDEF); c5 = (int8_t*) ((uintptr_t) c5 + cn_stride); + a5 = (const int8_t*) ((uintptr_t) a5 - kc); nc -= 16; } else { diff --git a/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-6x16c8-minmax-fp32-avx512skx.c b/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-6x16c8-minmax-fp32-avx512skx.c index 007a0f02fd2..51eb24d33c7 100644 --- a/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-6x16c8-minmax-fp32-avx512skx.c +++ b/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-6x16c8-minmax-fp32-avx512skx.c @@ -74,7 +74,6 @@ void xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_6x16c8__avx512skx( const __m512 voutput_max_less_zero_point = _mm512_load_ps(params->fp32_avx512.output_max_less_zero_point); const __m256i voutput_zero_point = _mm256_load_si256((const __m256i*) params->fp32_avx512.output_zero_point); const __m128i voutput_min = _mm_load_si128((const __m128i*) params->fp32_avx512.output_min); - const __m128i vshuffle_control_mask = _mm_set_epi8(15, 14, 13, 12, 7, 6, 5, 4, 11, 10, 9, 8, 3, 2, 1, 0); do { __m512i vacc0x0123 = _mm512_maskz_expandloadu_epi32(vbias_mask, w); @@ -215,33 +214,33 @@ void xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_6x16c8__avx512skx( vacc4x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled4x0123456789ABCDEF); vacc5x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled5x0123456789ABCDEF); - __m256i vacc0x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); - __m256i vacc1x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); - __m256i vacc2x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc2x0123456789ABCDEF, 1)); - __m256i vacc3x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc3x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc3x0123456789ABCDEF, 1)); - __m256i vacc4x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc4x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc4x0123456789ABCDEF, 1)); - __m256i vacc5x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc5x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc5x0123456789ABCDEF, 1)); - - vacc0x0123456789AB4567CDEF = _mm256_adds_epi16(vacc0x0123456789AB4567CDEF, voutput_zero_point); - vacc1x0123456789AB4567CDEF = _mm256_adds_epi16(vacc1x0123456789AB4567CDEF, voutput_zero_point); - vacc2x0123456789AB4567CDEF = _mm256_adds_epi16(vacc2x0123456789AB4567CDEF, voutput_zero_point); - vacc3x0123456789AB4567CDEF = _mm256_adds_epi16(vacc3x0123456789AB4567CDEF, voutput_zero_point); - vacc4x0123456789AB4567CDEF = _mm256_adds_epi16(vacc4x0123456789AB4567CDEF, voutput_zero_point); - vacc5x0123456789AB4567CDEF = _mm256_adds_epi16(vacc5x0123456789AB4567CDEF, voutput_zero_point); - - const __m128i vout0x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc0x0123456789AB4567CDEF, 1)); - const __m128i vout1x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc1x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc1x0123456789AB4567CDEF, 1)); - const __m128i vout2x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc2x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc2x0123456789AB4567CDEF, 1)); - const __m128i vout3x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc3x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc3x0123456789AB4567CDEF, 1)); - const __m128i vout4x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc4x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc4x0123456789AB4567CDEF, 1)); - const __m128i vout5x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc5x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc5x0123456789AB4567CDEF, 1)); - - __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi8(vout0x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi8(vout1x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi8(vout2x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout3x0123456789ABCDEF = _mm_shuffle_epi8(vout3x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout4x0123456789ABCDEF = _mm_shuffle_epi8(vout4x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout5x0123456789ABCDEF = _mm_shuffle_epi8(vout5x0123456789AB4567CDEF, vshuffle_control_mask); + __m256i vacc0x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); + __m256i vacc1x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); + __m256i vacc2x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc2x0123456789ABCDEF, 1)); + __m256i vacc3x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc3x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc3x0123456789ABCDEF, 1)); + __m256i vacc4x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc4x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc4x0123456789ABCDEF, 1)); + __m256i vacc5x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc5x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc5x0123456789ABCDEF, 1)); + + vacc0x012389AB4567CDEF = _mm256_adds_epi16(vacc0x012389AB4567CDEF, voutput_zero_point); + vacc1x012389AB4567CDEF = _mm256_adds_epi16(vacc1x012389AB4567CDEF, voutput_zero_point); + vacc2x012389AB4567CDEF = _mm256_adds_epi16(vacc2x012389AB4567CDEF, voutput_zero_point); + vacc3x012389AB4567CDEF = _mm256_adds_epi16(vacc3x012389AB4567CDEF, voutput_zero_point); + vacc4x012389AB4567CDEF = _mm256_adds_epi16(vacc4x012389AB4567CDEF, voutput_zero_point); + vacc5x012389AB4567CDEF = _mm256_adds_epi16(vacc5x012389AB4567CDEF, voutput_zero_point); + + const __m128i vout0x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x012389AB4567CDEF, 1)); + const __m128i vout1x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc1x012389AB4567CDEF), _mm256_extracti128_si256(vacc1x012389AB4567CDEF, 1)); + const __m128i vout2x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc2x012389AB4567CDEF), _mm256_extracti128_si256(vacc2x012389AB4567CDEF, 1)); + const __m128i vout3x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc3x012389AB4567CDEF), _mm256_extracti128_si256(vacc3x012389AB4567CDEF, 1)); + const __m128i vout4x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc4x012389AB4567CDEF), _mm256_extracti128_si256(vacc4x012389AB4567CDEF, 1)); + const __m128i vout5x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc5x012389AB4567CDEF), _mm256_extracti128_si256(vacc5x012389AB4567CDEF, 1)); + + __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi32(vout0x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi32(vout1x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi32(vout2x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout3x0123456789ABCDEF = _mm_shuffle_epi32(vout3x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout4x0123456789ABCDEF = _mm_shuffle_epi32(vout4x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout5x0123456789ABCDEF = _mm_shuffle_epi32(vout5x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); vout0x0123456789ABCDEF = _mm_max_epi8(vout0x0123456789ABCDEF, voutput_min); vout1x0123456789ABCDEF = _mm_max_epi8(vout1x0123456789ABCDEF, voutput_min); @@ -252,25 +251,23 @@ void xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_6x16c8__avx512skx( if (nc >= 16) { _mm_storeu_si128((__m128i*) c0, vout0x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c1, vout1x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c2, vout2x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c3, vout3x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c4, vout4x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c5, vout5x0123456789ABCDEF); - - a0 = (const int8_t*) ((uintptr_t) a0 - kc); - a1 = (const int8_t*) ((uintptr_t) a1 - kc); - a2 = (const int8_t*) ((uintptr_t) a2 - kc); - a3 = (const int8_t*) ((uintptr_t) a3 - kc); - a4 = (const int8_t*) ((uintptr_t) a4 - kc); - a5 = (const int8_t*) ((uintptr_t) a5 - kc); - c0 = (int8_t*) ((uintptr_t) c0 + cn_stride); + a0 = (const int8_t*) ((uintptr_t) a0 - kc); + _mm_storeu_si128((__m128i*) c1, vout1x0123456789ABCDEF); c1 = (int8_t*) ((uintptr_t) c1 + cn_stride); + a1 = (const int8_t*) ((uintptr_t) a1 - kc); + _mm_storeu_si128((__m128i*) c2, vout2x0123456789ABCDEF); c2 = (int8_t*) ((uintptr_t) c2 + cn_stride); + a2 = (const int8_t*) ((uintptr_t) a2 - kc); + _mm_storeu_si128((__m128i*) c3, vout3x0123456789ABCDEF); c3 = (int8_t*) ((uintptr_t) c3 + cn_stride); + a3 = (const int8_t*) ((uintptr_t) a3 - kc); + _mm_storeu_si128((__m128i*) c4, vout4x0123456789ABCDEF); c4 = (int8_t*) ((uintptr_t) c4 + cn_stride); + a4 = (const int8_t*) ((uintptr_t) a4 - kc); + _mm_storeu_si128((__m128i*) c5, vout5x0123456789ABCDEF); c5 = (int8_t*) ((uintptr_t) c5 + cn_stride); + a5 = (const int8_t*) ((uintptr_t) a5 - kc); nc -= 16; } else { diff --git a/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-6x16c8-minmax-fp32-avx512vnni-prfm.c b/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-6x16c8-minmax-fp32-avx512vnni-prfm.c index a6d16a77f6b..937d4000db8 100644 --- a/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-6x16c8-minmax-fp32-avx512vnni-prfm.c +++ b/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-6x16c8-minmax-fp32-avx512vnni-prfm.c @@ -77,7 +77,6 @@ void xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_6x16c8__avx512vnni_prfm( const __m512 voutput_max_less_zero_point = _mm512_load_ps(params->fp32_avx512vnni.output_max_less_zero_point); const __m256i voutput_zero_point = _mm256_load_si256((const __m256i*) params->fp32_avx512vnni.output_zero_point); const __m128i voutput_min = _mm_load_si128((const __m128i*) params->fp32_avx512vnni.output_min); - const __m128i vshuffle_control_mask = _mm_loadu_si128((const __m128i*) params->fp32_avx512vnni.shuffle_control_mask); do { __m512i vacc0x01234567 = _mm512_cvtepu32_epi64(_mm256_load_si256((const __m256i*) w)); __m512i vacc0x89ABCDEF = _mm512_cvtepu32_epi64(_mm256_load_si256((const __m256i*) ((const int32_t*) w + 8))); @@ -239,33 +238,33 @@ void xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_6x16c8__avx512vnni_prfm( vacc4x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled4x0123456789ABCDEF); vacc5x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled5x0123456789ABCDEF); - __m256i vacc0x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); - __m256i vacc1x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); - __m256i vacc2x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc2x0123456789ABCDEF, 1)); - __m256i vacc3x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc3x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc3x0123456789ABCDEF, 1)); - __m256i vacc4x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc4x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc4x0123456789ABCDEF, 1)); - __m256i vacc5x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc5x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc5x0123456789ABCDEF, 1)); + __m256i vacc0x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); + __m256i vacc1x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); + __m256i vacc2x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc2x0123456789ABCDEF, 1)); + __m256i vacc3x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc3x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc3x0123456789ABCDEF, 1)); + __m256i vacc4x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc4x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc4x0123456789ABCDEF, 1)); + __m256i vacc5x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc5x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc5x0123456789ABCDEF, 1)); - vacc0x0123456789AB4567CDEF = _mm256_adds_epi16(vacc0x0123456789AB4567CDEF, voutput_zero_point); - vacc1x0123456789AB4567CDEF = _mm256_adds_epi16(vacc1x0123456789AB4567CDEF, voutput_zero_point); - vacc2x0123456789AB4567CDEF = _mm256_adds_epi16(vacc2x0123456789AB4567CDEF, voutput_zero_point); - vacc3x0123456789AB4567CDEF = _mm256_adds_epi16(vacc3x0123456789AB4567CDEF, voutput_zero_point); - vacc4x0123456789AB4567CDEF = _mm256_adds_epi16(vacc4x0123456789AB4567CDEF, voutput_zero_point); - vacc5x0123456789AB4567CDEF = _mm256_adds_epi16(vacc5x0123456789AB4567CDEF, voutput_zero_point); + vacc0x012389AB4567CDEF = _mm256_adds_epi16(vacc0x012389AB4567CDEF, voutput_zero_point); + vacc1x012389AB4567CDEF = _mm256_adds_epi16(vacc1x012389AB4567CDEF, voutput_zero_point); + vacc2x012389AB4567CDEF = _mm256_adds_epi16(vacc2x012389AB4567CDEF, voutput_zero_point); + vacc3x012389AB4567CDEF = _mm256_adds_epi16(vacc3x012389AB4567CDEF, voutput_zero_point); + vacc4x012389AB4567CDEF = _mm256_adds_epi16(vacc4x012389AB4567CDEF, voutput_zero_point); + vacc5x012389AB4567CDEF = _mm256_adds_epi16(vacc5x012389AB4567CDEF, voutput_zero_point); - const __m128i vout0x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc0x0123456789AB4567CDEF, 1)); - const __m128i vout1x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc1x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc1x0123456789AB4567CDEF, 1)); - const __m128i vout2x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc2x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc2x0123456789AB4567CDEF, 1)); - const __m128i vout3x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc3x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc3x0123456789AB4567CDEF, 1)); - const __m128i vout4x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc4x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc4x0123456789AB4567CDEF, 1)); - const __m128i vout5x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc5x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc5x0123456789AB4567CDEF, 1)); + const __m128i vout0x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x012389AB4567CDEF, 1)); + const __m128i vout1x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc1x012389AB4567CDEF), _mm256_extracti128_si256(vacc1x012389AB4567CDEF, 1)); + const __m128i vout2x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc2x012389AB4567CDEF), _mm256_extracti128_si256(vacc2x012389AB4567CDEF, 1)); + const __m128i vout3x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc3x012389AB4567CDEF), _mm256_extracti128_si256(vacc3x012389AB4567CDEF, 1)); + const __m128i vout4x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc4x012389AB4567CDEF), _mm256_extracti128_si256(vacc4x012389AB4567CDEF, 1)); + const __m128i vout5x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc5x012389AB4567CDEF), _mm256_extracti128_si256(vacc5x012389AB4567CDEF, 1)); - __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi8(vout0x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi8(vout1x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi8(vout2x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout3x0123456789ABCDEF = _mm_shuffle_epi8(vout3x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout4x0123456789ABCDEF = _mm_shuffle_epi8(vout4x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout5x0123456789ABCDEF = _mm_shuffle_epi8(vout5x0123456789AB4567CDEF, vshuffle_control_mask); + __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi32(vout0x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi32(vout1x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi32(vout2x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout3x0123456789ABCDEF = _mm_shuffle_epi32(vout3x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout4x0123456789ABCDEF = _mm_shuffle_epi32(vout4x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout5x0123456789ABCDEF = _mm_shuffle_epi32(vout5x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); vout0x0123456789ABCDEF = _mm_max_epi8(vout0x0123456789ABCDEF, voutput_min); vout1x0123456789ABCDEF = _mm_max_epi8(vout1x0123456789ABCDEF, voutput_min); @@ -276,25 +275,23 @@ void xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_6x16c8__avx512vnni_prfm( if (nc >= 16) { _mm_storeu_si128((__m128i*) c0, vout0x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c1, vout1x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c2, vout2x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c3, vout3x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c4, vout4x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c5, vout5x0123456789ABCDEF); - - a0 = (const int8_t*) ((uintptr_t) a0 - kc); - a1 = (const int8_t*) ((uintptr_t) a1 - kc); - a2 = (const int8_t*) ((uintptr_t) a2 - kc); - a3 = (const int8_t*) ((uintptr_t) a3 - kc); - a4 = (const int8_t*) ((uintptr_t) a4 - kc); - a5 = (const int8_t*) ((uintptr_t) a5 - kc); - c0 = (int8_t*) ((uintptr_t) c0 + cn_stride); + a0 = (const int8_t*) ((uintptr_t) a0 - kc); + _mm_storeu_si128((__m128i*) c1, vout1x0123456789ABCDEF); c1 = (int8_t*) ((uintptr_t) c1 + cn_stride); + a1 = (const int8_t*) ((uintptr_t) a1 - kc); + _mm_storeu_si128((__m128i*) c2, vout2x0123456789ABCDEF); c2 = (int8_t*) ((uintptr_t) c2 + cn_stride); + a2 = (const int8_t*) ((uintptr_t) a2 - kc); + _mm_storeu_si128((__m128i*) c3, vout3x0123456789ABCDEF); c3 = (int8_t*) ((uintptr_t) c3 + cn_stride); + a3 = (const int8_t*) ((uintptr_t) a3 - kc); + _mm_storeu_si128((__m128i*) c4, vout4x0123456789ABCDEF); c4 = (int8_t*) ((uintptr_t) c4 + cn_stride); + a4 = (const int8_t*) ((uintptr_t) a4 - kc); + _mm_storeu_si128((__m128i*) c5, vout5x0123456789ABCDEF); c5 = (int8_t*) ((uintptr_t) c5 + cn_stride); + a5 = (const int8_t*) ((uintptr_t) a5 - kc); nc -= 16; } else { diff --git a/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-6x16c8-minmax-fp32-avx512vnni.c b/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-6x16c8-minmax-fp32-avx512vnni.c index 6818a4587b3..530e687736d 100644 --- a/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-6x16c8-minmax-fp32-avx512vnni.c +++ b/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-6x16c8-minmax-fp32-avx512vnni.c @@ -76,7 +76,6 @@ void xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_6x16c8__avx512vnni( const __m512 voutput_max_less_zero_point = _mm512_load_ps(params->fp32_avx512vnni.output_max_less_zero_point); const __m256i voutput_zero_point = _mm256_load_si256((const __m256i*) params->fp32_avx512vnni.output_zero_point); const __m128i voutput_min = _mm_load_si128((const __m128i*) params->fp32_avx512vnni.output_min); - const __m128i vshuffle_control_mask = _mm_loadu_si128((const __m128i*) params->fp32_avx512vnni.shuffle_control_mask); do { __m512i vacc0x01234567 = _mm512_cvtepu32_epi64(_mm256_load_si256((const __m256i*) w)); __m512i vacc0x89ABCDEF = _mm512_cvtepu32_epi64(_mm256_load_si256((const __m256i*) ((const int32_t*) w + 8))); @@ -232,33 +231,33 @@ void xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_6x16c8__avx512vnni( vacc4x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled4x0123456789ABCDEF); vacc5x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled5x0123456789ABCDEF); - __m256i vacc0x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); - __m256i vacc1x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); - __m256i vacc2x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc2x0123456789ABCDEF, 1)); - __m256i vacc3x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc3x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc3x0123456789ABCDEF, 1)); - __m256i vacc4x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc4x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc4x0123456789ABCDEF, 1)); - __m256i vacc5x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc5x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc5x0123456789ABCDEF, 1)); + __m256i vacc0x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); + __m256i vacc1x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); + __m256i vacc2x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc2x0123456789ABCDEF, 1)); + __m256i vacc3x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc3x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc3x0123456789ABCDEF, 1)); + __m256i vacc4x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc4x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc4x0123456789ABCDEF, 1)); + __m256i vacc5x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc5x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc5x0123456789ABCDEF, 1)); - vacc0x0123456789AB4567CDEF = _mm256_adds_epi16(vacc0x0123456789AB4567CDEF, voutput_zero_point); - vacc1x0123456789AB4567CDEF = _mm256_adds_epi16(vacc1x0123456789AB4567CDEF, voutput_zero_point); - vacc2x0123456789AB4567CDEF = _mm256_adds_epi16(vacc2x0123456789AB4567CDEF, voutput_zero_point); - vacc3x0123456789AB4567CDEF = _mm256_adds_epi16(vacc3x0123456789AB4567CDEF, voutput_zero_point); - vacc4x0123456789AB4567CDEF = _mm256_adds_epi16(vacc4x0123456789AB4567CDEF, voutput_zero_point); - vacc5x0123456789AB4567CDEF = _mm256_adds_epi16(vacc5x0123456789AB4567CDEF, voutput_zero_point); + vacc0x012389AB4567CDEF = _mm256_adds_epi16(vacc0x012389AB4567CDEF, voutput_zero_point); + vacc1x012389AB4567CDEF = _mm256_adds_epi16(vacc1x012389AB4567CDEF, voutput_zero_point); + vacc2x012389AB4567CDEF = _mm256_adds_epi16(vacc2x012389AB4567CDEF, voutput_zero_point); + vacc3x012389AB4567CDEF = _mm256_adds_epi16(vacc3x012389AB4567CDEF, voutput_zero_point); + vacc4x012389AB4567CDEF = _mm256_adds_epi16(vacc4x012389AB4567CDEF, voutput_zero_point); + vacc5x012389AB4567CDEF = _mm256_adds_epi16(vacc5x012389AB4567CDEF, voutput_zero_point); - const __m128i vout0x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc0x0123456789AB4567CDEF, 1)); - const __m128i vout1x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc1x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc1x0123456789AB4567CDEF, 1)); - const __m128i vout2x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc2x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc2x0123456789AB4567CDEF, 1)); - const __m128i vout3x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc3x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc3x0123456789AB4567CDEF, 1)); - const __m128i vout4x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc4x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc4x0123456789AB4567CDEF, 1)); - const __m128i vout5x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc5x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc5x0123456789AB4567CDEF, 1)); + const __m128i vout0x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x012389AB4567CDEF, 1)); + const __m128i vout1x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc1x012389AB4567CDEF), _mm256_extracti128_si256(vacc1x012389AB4567CDEF, 1)); + const __m128i vout2x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc2x012389AB4567CDEF), _mm256_extracti128_si256(vacc2x012389AB4567CDEF, 1)); + const __m128i vout3x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc3x012389AB4567CDEF), _mm256_extracti128_si256(vacc3x012389AB4567CDEF, 1)); + const __m128i vout4x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc4x012389AB4567CDEF), _mm256_extracti128_si256(vacc4x012389AB4567CDEF, 1)); + const __m128i vout5x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc5x012389AB4567CDEF), _mm256_extracti128_si256(vacc5x012389AB4567CDEF, 1)); - __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi8(vout0x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi8(vout1x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi8(vout2x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout3x0123456789ABCDEF = _mm_shuffle_epi8(vout3x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout4x0123456789ABCDEF = _mm_shuffle_epi8(vout4x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout5x0123456789ABCDEF = _mm_shuffle_epi8(vout5x0123456789AB4567CDEF, vshuffle_control_mask); + __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi32(vout0x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi32(vout1x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi32(vout2x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout3x0123456789ABCDEF = _mm_shuffle_epi32(vout3x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout4x0123456789ABCDEF = _mm_shuffle_epi32(vout4x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout5x0123456789ABCDEF = _mm_shuffle_epi32(vout5x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); vout0x0123456789ABCDEF = _mm_max_epi8(vout0x0123456789ABCDEF, voutput_min); vout1x0123456789ABCDEF = _mm_max_epi8(vout1x0123456789ABCDEF, voutput_min); @@ -269,25 +268,23 @@ void xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_6x16c8__avx512vnni( if (nc >= 16) { _mm_storeu_si128((__m128i*) c0, vout0x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c1, vout1x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c2, vout2x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c3, vout3x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c4, vout4x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c5, vout5x0123456789ABCDEF); - - a0 = (const int8_t*) ((uintptr_t) a0 - kc); - a1 = (const int8_t*) ((uintptr_t) a1 - kc); - a2 = (const int8_t*) ((uintptr_t) a2 - kc); - a3 = (const int8_t*) ((uintptr_t) a3 - kc); - a4 = (const int8_t*) ((uintptr_t) a4 - kc); - a5 = (const int8_t*) ((uintptr_t) a5 - kc); - c0 = (int8_t*) ((uintptr_t) c0 + cn_stride); + a0 = (const int8_t*) ((uintptr_t) a0 - kc); + _mm_storeu_si128((__m128i*) c1, vout1x0123456789ABCDEF); c1 = (int8_t*) ((uintptr_t) c1 + cn_stride); + a1 = (const int8_t*) ((uintptr_t) a1 - kc); + _mm_storeu_si128((__m128i*) c2, vout2x0123456789ABCDEF); c2 = (int8_t*) ((uintptr_t) c2 + cn_stride); + a2 = (const int8_t*) ((uintptr_t) a2 - kc); + _mm_storeu_si128((__m128i*) c3, vout3x0123456789ABCDEF); c3 = (int8_t*) ((uintptr_t) c3 + cn_stride); + a3 = (const int8_t*) ((uintptr_t) a3 - kc); + _mm_storeu_si128((__m128i*) c4, vout4x0123456789ABCDEF); c4 = (int8_t*) ((uintptr_t) c4 + cn_stride); + a4 = (const int8_t*) ((uintptr_t) a4 - kc); + _mm_storeu_si128((__m128i*) c5, vout5x0123456789ABCDEF); c5 = (int8_t*) ((uintptr_t) c5 + cn_stride); + a5 = (const int8_t*) ((uintptr_t) a5 - kc); nc -= 16; } else { diff --git a/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-7x16c4-minmax-fp32-avx512vnni-prfm.c b/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-7x16c4-minmax-fp32-avx512vnni-prfm.c index 7e8707cf84b..5971b5962cb 100644 --- a/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-7x16c4-minmax-fp32-avx512vnni-prfm.c +++ b/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-7x16c4-minmax-fp32-avx512vnni-prfm.c @@ -83,7 +83,6 @@ void xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_7x16c4__avx512vnni_prfm( const __m512 voutput_max_less_zero_point = _mm512_load_ps(params->fp32_avx512vnni.output_max_less_zero_point); const __m256i voutput_zero_point = _mm256_load_si256((const __m256i*) params->fp32_avx512vnni.output_zero_point); const __m128i voutput_min = _mm_load_si128((const __m128i*) params->fp32_avx512vnni.output_min); - const __m128i vshuffle_control_mask = _mm_loadu_si128((const __m128i*) params->fp32_avx512vnni.shuffle_control_mask); do { __m512i vacc0x0x0123456789ABCDEF = _mm512_load_epi32(w); __m512i vacc1x0x0123456789ABCDEF = _mm512_setzero_epi32(); @@ -148,13 +147,13 @@ void xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_7x16c4__avx512vnni_prfm( w = (const int8_t*) w + 128; k -= 8 * sizeof(int8_t); } - vacc0x0x0123456789ABCDEF = _mm512_add_epi32(vacc0x0x0123456789ABCDEF, vacc1x0x0123456789ABCDEF); - vacc0x1x0123456789ABCDEF = _mm512_add_epi32(vacc0x1x0123456789ABCDEF, vacc1x1x0123456789ABCDEF); - vacc0x2x0123456789ABCDEF = _mm512_add_epi32(vacc0x2x0123456789ABCDEF, vacc1x2x0123456789ABCDEF); - vacc0x3x0123456789ABCDEF = _mm512_add_epi32(vacc0x3x0123456789ABCDEF, vacc1x3x0123456789ABCDEF); - vacc0x4x0123456789ABCDEF = _mm512_add_epi32(vacc0x4x0123456789ABCDEF, vacc1x4x0123456789ABCDEF); - vacc0x5x0123456789ABCDEF = _mm512_add_epi32(vacc0x5x0123456789ABCDEF, vacc1x5x0123456789ABCDEF); - vacc0x6x0123456789ABCDEF = _mm512_add_epi32(vacc0x6x0123456789ABCDEF, vacc1x6x0123456789ABCDEF); + __m512i vacc0x0123456789ABCDEF = _mm512_add_epi32(vacc0x0x0123456789ABCDEF, vacc1x0x0123456789ABCDEF); + __m512i vacc1x0123456789ABCDEF = _mm512_add_epi32(vacc0x1x0123456789ABCDEF, vacc1x1x0123456789ABCDEF); + __m512i vacc2x0123456789ABCDEF = _mm512_add_epi32(vacc0x2x0123456789ABCDEF, vacc1x2x0123456789ABCDEF); + __m512i vacc3x0123456789ABCDEF = _mm512_add_epi32(vacc0x3x0123456789ABCDEF, vacc1x3x0123456789ABCDEF); + __m512i vacc4x0123456789ABCDEF = _mm512_add_epi32(vacc0x4x0123456789ABCDEF, vacc1x4x0123456789ABCDEF); + __m512i vacc5x0123456789ABCDEF = _mm512_add_epi32(vacc0x5x0123456789ABCDEF, vacc1x5x0123456789ABCDEF); + __m512i vacc6x0123456789ABCDEF = _mm512_add_epi32(vacc0x6x0123456789ABCDEF, vacc1x6x0123456789ABCDEF); if (k != 0) { const __m512i va0x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a0)), vsign_mask); @@ -175,25 +174,25 @@ void xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_7x16c4__avx512vnni_prfm( const __m512i vb0123456789ABCDEF = _mm512_load_si512(w); xnn_prefetch_to_l1((const int8_t*) w + 960); - vacc0x0x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x0x0123456789ABCDEF, va0x0123, vb0123456789ABCDEF); - vacc0x1x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x1x0123456789ABCDEF, va1x0123, vb0123456789ABCDEF); - vacc0x2x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x2x0123456789ABCDEF, va2x0123, vb0123456789ABCDEF); - vacc0x3x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x3x0123456789ABCDEF, va3x0123, vb0123456789ABCDEF); - vacc0x4x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x4x0123456789ABCDEF, va4x0123, vb0123456789ABCDEF); - vacc0x5x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x5x0123456789ABCDEF, va5x0123, vb0123456789ABCDEF); - vacc0x6x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x6x0123456789ABCDEF, va6x0123, vb0123456789ABCDEF); + vacc0x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x0123456789ABCDEF, va0x0123, vb0123456789ABCDEF); + vacc1x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc1x0123456789ABCDEF, va1x0123, vb0123456789ABCDEF); + vacc2x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc2x0123456789ABCDEF, va2x0123, vb0123456789ABCDEF); + vacc3x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc3x0123456789ABCDEF, va3x0123, vb0123456789ABCDEF); + vacc4x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc4x0123456789ABCDEF, va4x0123, vb0123456789ABCDEF); + vacc5x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc5x0123456789ABCDEF, va5x0123, vb0123456789ABCDEF); + vacc6x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc6x0123456789ABCDEF, va6x0123, vb0123456789ABCDEF); w = (const int8_t*) w + 64; k -= 4 * sizeof(int8_t); } - __m512 vscaled0x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x0x0123456789ABCDEF); - __m512 vscaled1x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x1x0123456789ABCDEF); - __m512 vscaled2x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x2x0123456789ABCDEF); - __m512 vscaled3x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x3x0123456789ABCDEF); - __m512 vscaled4x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x4x0123456789ABCDEF); - __m512 vscaled5x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x5x0123456789ABCDEF); - __m512 vscaled6x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x6x0123456789ABCDEF); + __m512 vscaled0x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x0123456789ABCDEF); + __m512 vscaled1x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc1x0123456789ABCDEF); + __m512 vscaled2x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc2x0123456789ABCDEF); + __m512 vscaled3x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc3x0123456789ABCDEF); + __m512 vscaled4x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc4x0123456789ABCDEF); + __m512 vscaled5x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc5x0123456789ABCDEF); + __m512 vscaled6x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc6x0123456789ABCDEF); const __m512 vscale012345678ABCDEF = _mm512_load_ps(w); w = (const float*) w + 16; @@ -213,45 +212,45 @@ void xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_7x16c4__avx512vnni_prfm( vscaled5x0123456789ABCDEF = _mm512_min_ps(vscaled5x0123456789ABCDEF, voutput_max_less_zero_point); vscaled6x0123456789ABCDEF = _mm512_min_ps(vscaled6x0123456789ABCDEF, voutput_max_less_zero_point); - vacc0x0x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled0x0123456789ABCDEF); - vacc0x1x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled1x0123456789ABCDEF); - vacc0x2x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled2x0123456789ABCDEF); - vacc0x3x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled3x0123456789ABCDEF); - vacc0x4x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled4x0123456789ABCDEF); - vacc0x5x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled5x0123456789ABCDEF); - vacc0x6x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled6x0123456789ABCDEF); + vacc0x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled0x0123456789ABCDEF); + vacc1x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled1x0123456789ABCDEF); + vacc2x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled2x0123456789ABCDEF); + vacc3x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled3x0123456789ABCDEF); + vacc4x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled4x0123456789ABCDEF); + vacc5x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled5x0123456789ABCDEF); + vacc6x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled6x0123456789ABCDEF); - __m256i vacc0x0x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0x0123456789ABCDEF, 1)); - __m256i vacc0x1x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x1x0123456789ABCDEF, 1)); - __m256i vacc0x2x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x2x0123456789ABCDEF, 1)); - __m256i vacc0x3x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x3x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x3x0123456789ABCDEF, 1)); - __m256i vacc0x4x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x4x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x4x0123456789ABCDEF, 1)); - __m256i vacc0x5x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x5x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x5x0123456789ABCDEF, 1)); - __m256i vacc0x6x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x6x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x6x0123456789ABCDEF, 1)); + __m256i vacc0x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); + __m256i vacc1x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); + __m256i vacc2x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc2x0123456789ABCDEF, 1)); + __m256i vacc3x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc3x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc3x0123456789ABCDEF, 1)); + __m256i vacc4x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc4x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc4x0123456789ABCDEF, 1)); + __m256i vacc5x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc5x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc5x0123456789ABCDEF, 1)); + __m256i vacc6x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc6x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc6x0123456789ABCDEF, 1)); - vacc0x0x012389AB4567CDEF = _mm256_adds_epi16(vacc0x0x012389AB4567CDEF, voutput_zero_point); - vacc0x1x012389AB4567CDEF = _mm256_adds_epi16(vacc0x1x012389AB4567CDEF, voutput_zero_point); - vacc0x2x012389AB4567CDEF = _mm256_adds_epi16(vacc0x2x012389AB4567CDEF, voutput_zero_point); - vacc0x3x012389AB4567CDEF = _mm256_adds_epi16(vacc0x3x012389AB4567CDEF, voutput_zero_point); - vacc0x4x012389AB4567CDEF = _mm256_adds_epi16(vacc0x4x012389AB4567CDEF, voutput_zero_point); - vacc0x5x012389AB4567CDEF = _mm256_adds_epi16(vacc0x5x012389AB4567CDEF, voutput_zero_point); - vacc0x6x012389AB4567CDEF = _mm256_adds_epi16(vacc0x6x012389AB4567CDEF, voutput_zero_point); + vacc0x012389AB4567CDEF = _mm256_adds_epi16(vacc0x012389AB4567CDEF, voutput_zero_point); + vacc1x012389AB4567CDEF = _mm256_adds_epi16(vacc1x012389AB4567CDEF, voutput_zero_point); + vacc2x012389AB4567CDEF = _mm256_adds_epi16(vacc2x012389AB4567CDEF, voutput_zero_point); + vacc3x012389AB4567CDEF = _mm256_adds_epi16(vacc3x012389AB4567CDEF, voutput_zero_point); + vacc4x012389AB4567CDEF = _mm256_adds_epi16(vacc4x012389AB4567CDEF, voutput_zero_point); + vacc5x012389AB4567CDEF = _mm256_adds_epi16(vacc5x012389AB4567CDEF, voutput_zero_point); + vacc6x012389AB4567CDEF = _mm256_adds_epi16(vacc6x012389AB4567CDEF, voutput_zero_point); - const __m128i vout0x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x0x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x0x012389AB4567CDEF, 1)); - const __m128i vout1x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x1x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x1x012389AB4567CDEF, 1)); - const __m128i vout2x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x2x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x2x012389AB4567CDEF, 1)); - const __m128i vout3x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x3x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x3x012389AB4567CDEF, 1)); - const __m128i vout4x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x4x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x4x012389AB4567CDEF, 1)); - const __m128i vout5x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x5x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x5x012389AB4567CDEF, 1)); - const __m128i vout6x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x6x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x6x012389AB4567CDEF, 1)); + const __m128i vout0x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x012389AB4567CDEF, 1)); + const __m128i vout1x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc1x012389AB4567CDEF), _mm256_extracti128_si256(vacc1x012389AB4567CDEF, 1)); + const __m128i vout2x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc2x012389AB4567CDEF), _mm256_extracti128_si256(vacc2x012389AB4567CDEF, 1)); + const __m128i vout3x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc3x012389AB4567CDEF), _mm256_extracti128_si256(vacc3x012389AB4567CDEF, 1)); + const __m128i vout4x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc4x012389AB4567CDEF), _mm256_extracti128_si256(vacc4x012389AB4567CDEF, 1)); + const __m128i vout5x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc5x012389AB4567CDEF), _mm256_extracti128_si256(vacc5x012389AB4567CDEF, 1)); + const __m128i vout6x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc6x012389AB4567CDEF), _mm256_extracti128_si256(vacc6x012389AB4567CDEF, 1)); - __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi8(vout0x012389AB4567CDEF, vshuffle_control_mask); - __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi8(vout1x012389AB4567CDEF, vshuffle_control_mask); - __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi8(vout2x012389AB4567CDEF, vshuffle_control_mask); - __m128i vout3x0123456789ABCDEF = _mm_shuffle_epi8(vout3x012389AB4567CDEF, vshuffle_control_mask); - __m128i vout4x0123456789ABCDEF = _mm_shuffle_epi8(vout4x012389AB4567CDEF, vshuffle_control_mask); - __m128i vout5x0123456789ABCDEF = _mm_shuffle_epi8(vout5x012389AB4567CDEF, vshuffle_control_mask); - __m128i vout6x0123456789ABCDEF = _mm_shuffle_epi8(vout6x012389AB4567CDEF, vshuffle_control_mask); + __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi32(vout0x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi32(vout1x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi32(vout2x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout3x0123456789ABCDEF = _mm_shuffle_epi32(vout3x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout4x0123456789ABCDEF = _mm_shuffle_epi32(vout4x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout5x0123456789ABCDEF = _mm_shuffle_epi32(vout5x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout6x0123456789ABCDEF = _mm_shuffle_epi32(vout6x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); vout0x0123456789ABCDEF = _mm_max_epi8(vout0x0123456789ABCDEF, voutput_min); vout1x0123456789ABCDEF = _mm_max_epi8(vout1x0123456789ABCDEF, voutput_min); @@ -263,28 +262,26 @@ void xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_7x16c4__avx512vnni_prfm( if (nc >= 16) { _mm_storeu_si128((__m128i*) c0, vout0x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c1, vout1x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c2, vout2x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c3, vout3x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c4, vout4x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c5, vout5x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c6, vout6x0123456789ABCDEF); - - a0 = (const int8_t*) ((uintptr_t) a0 - kc); - a1 = (const int8_t*) ((uintptr_t) a1 - kc); - a2 = (const int8_t*) ((uintptr_t) a2 - kc); - a3 = (const int8_t*) ((uintptr_t) a3 - kc); - a4 = (const int8_t*) ((uintptr_t) a4 - kc); - a5 = (const int8_t*) ((uintptr_t) a5 - kc); - a6 = (const int8_t*) ((uintptr_t) a6 - kc); - c0 = (int8_t*) ((uintptr_t) c0 + cn_stride); + a0 = (const int8_t*) ((uintptr_t) a0 - kc); + _mm_storeu_si128((__m128i*) c1, vout1x0123456789ABCDEF); c1 = (int8_t*) ((uintptr_t) c1 + cn_stride); + a1 = (const int8_t*) ((uintptr_t) a1 - kc); + _mm_storeu_si128((__m128i*) c2, vout2x0123456789ABCDEF); c2 = (int8_t*) ((uintptr_t) c2 + cn_stride); + a2 = (const int8_t*) ((uintptr_t) a2 - kc); + _mm_storeu_si128((__m128i*) c3, vout3x0123456789ABCDEF); c3 = (int8_t*) ((uintptr_t) c3 + cn_stride); + a3 = (const int8_t*) ((uintptr_t) a3 - kc); + _mm_storeu_si128((__m128i*) c4, vout4x0123456789ABCDEF); c4 = (int8_t*) ((uintptr_t) c4 + cn_stride); + a4 = (const int8_t*) ((uintptr_t) a4 - kc); + _mm_storeu_si128((__m128i*) c5, vout5x0123456789ABCDEF); c5 = (int8_t*) ((uintptr_t) c5 + cn_stride); + a5 = (const int8_t*) ((uintptr_t) a5 - kc); + _mm_storeu_si128((__m128i*) c6, vout6x0123456789ABCDEF); c6 = (int8_t*) ((uintptr_t) c6 + cn_stride); + a6 = (const int8_t*) ((uintptr_t) a6 - kc); nc -= 16; } else { diff --git a/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-7x16c4-minmax-fp32-avx512vnni.c b/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-7x16c4-minmax-fp32-avx512vnni.c index 316c2aa6019..2250d5c05c5 100644 --- a/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-7x16c4-minmax-fp32-avx512vnni.c +++ b/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-7x16c4-minmax-fp32-avx512vnni.c @@ -82,7 +82,6 @@ void xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_7x16c4__avx512vnni( const __m512 voutput_max_less_zero_point = _mm512_load_ps(params->fp32_avx512vnni.output_max_less_zero_point); const __m256i voutput_zero_point = _mm256_load_si256((const __m256i*) params->fp32_avx512vnni.output_zero_point); const __m128i voutput_min = _mm_load_si128((const __m128i*) params->fp32_avx512vnni.output_min); - const __m128i vshuffle_control_mask = _mm_loadu_si128((const __m128i*) params->fp32_avx512vnni.shuffle_control_mask); do { __m512i vacc0x0x0123456789ABCDEF = _mm512_load_epi32(w); __m512i vacc1x0x0123456789ABCDEF = _mm512_setzero_epi32(); @@ -145,13 +144,13 @@ void xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_7x16c4__avx512vnni( w = (const int8_t*) w + 128; k -= 8 * sizeof(int8_t); } - vacc0x0x0123456789ABCDEF = _mm512_add_epi32(vacc0x0x0123456789ABCDEF, vacc1x0x0123456789ABCDEF); - vacc0x1x0123456789ABCDEF = _mm512_add_epi32(vacc0x1x0123456789ABCDEF, vacc1x1x0123456789ABCDEF); - vacc0x2x0123456789ABCDEF = _mm512_add_epi32(vacc0x2x0123456789ABCDEF, vacc1x2x0123456789ABCDEF); - vacc0x3x0123456789ABCDEF = _mm512_add_epi32(vacc0x3x0123456789ABCDEF, vacc1x3x0123456789ABCDEF); - vacc0x4x0123456789ABCDEF = _mm512_add_epi32(vacc0x4x0123456789ABCDEF, vacc1x4x0123456789ABCDEF); - vacc0x5x0123456789ABCDEF = _mm512_add_epi32(vacc0x5x0123456789ABCDEF, vacc1x5x0123456789ABCDEF); - vacc0x6x0123456789ABCDEF = _mm512_add_epi32(vacc0x6x0123456789ABCDEF, vacc1x6x0123456789ABCDEF); + __m512i vacc0x0123456789ABCDEF = _mm512_add_epi32(vacc0x0x0123456789ABCDEF, vacc1x0x0123456789ABCDEF); + __m512i vacc1x0123456789ABCDEF = _mm512_add_epi32(vacc0x1x0123456789ABCDEF, vacc1x1x0123456789ABCDEF); + __m512i vacc2x0123456789ABCDEF = _mm512_add_epi32(vacc0x2x0123456789ABCDEF, vacc1x2x0123456789ABCDEF); + __m512i vacc3x0123456789ABCDEF = _mm512_add_epi32(vacc0x3x0123456789ABCDEF, vacc1x3x0123456789ABCDEF); + __m512i vacc4x0123456789ABCDEF = _mm512_add_epi32(vacc0x4x0123456789ABCDEF, vacc1x4x0123456789ABCDEF); + __m512i vacc5x0123456789ABCDEF = _mm512_add_epi32(vacc0x5x0123456789ABCDEF, vacc1x5x0123456789ABCDEF); + __m512i vacc6x0123456789ABCDEF = _mm512_add_epi32(vacc0x6x0123456789ABCDEF, vacc1x6x0123456789ABCDEF); if (k != 0) { const __m512i va0x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a0)), vsign_mask); @@ -171,25 +170,25 @@ void xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_7x16c4__avx512vnni( const __m512i vb0123456789ABCDEF = _mm512_load_si512(w); - vacc0x0x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x0x0123456789ABCDEF, va0x0123, vb0123456789ABCDEF); - vacc0x1x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x1x0123456789ABCDEF, va1x0123, vb0123456789ABCDEF); - vacc0x2x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x2x0123456789ABCDEF, va2x0123, vb0123456789ABCDEF); - vacc0x3x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x3x0123456789ABCDEF, va3x0123, vb0123456789ABCDEF); - vacc0x4x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x4x0123456789ABCDEF, va4x0123, vb0123456789ABCDEF); - vacc0x5x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x5x0123456789ABCDEF, va5x0123, vb0123456789ABCDEF); - vacc0x6x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x6x0123456789ABCDEF, va6x0123, vb0123456789ABCDEF); + vacc0x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x0123456789ABCDEF, va0x0123, vb0123456789ABCDEF); + vacc1x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc1x0123456789ABCDEF, va1x0123, vb0123456789ABCDEF); + vacc2x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc2x0123456789ABCDEF, va2x0123, vb0123456789ABCDEF); + vacc3x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc3x0123456789ABCDEF, va3x0123, vb0123456789ABCDEF); + vacc4x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc4x0123456789ABCDEF, va4x0123, vb0123456789ABCDEF); + vacc5x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc5x0123456789ABCDEF, va5x0123, vb0123456789ABCDEF); + vacc6x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc6x0123456789ABCDEF, va6x0123, vb0123456789ABCDEF); w = (const int8_t*) w + 64; k -= 4 * sizeof(int8_t); } - __m512 vscaled0x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x0x0123456789ABCDEF); - __m512 vscaled1x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x1x0123456789ABCDEF); - __m512 vscaled2x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x2x0123456789ABCDEF); - __m512 vscaled3x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x3x0123456789ABCDEF); - __m512 vscaled4x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x4x0123456789ABCDEF); - __m512 vscaled5x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x5x0123456789ABCDEF); - __m512 vscaled6x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x6x0123456789ABCDEF); + __m512 vscaled0x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x0123456789ABCDEF); + __m512 vscaled1x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc1x0123456789ABCDEF); + __m512 vscaled2x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc2x0123456789ABCDEF); + __m512 vscaled3x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc3x0123456789ABCDEF); + __m512 vscaled4x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc4x0123456789ABCDEF); + __m512 vscaled5x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc5x0123456789ABCDEF); + __m512 vscaled6x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc6x0123456789ABCDEF); const __m512 vscale012345678ABCDEF = _mm512_load_ps(w); w = (const float*) w + 16; @@ -209,45 +208,45 @@ void xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_7x16c4__avx512vnni( vscaled5x0123456789ABCDEF = _mm512_min_ps(vscaled5x0123456789ABCDEF, voutput_max_less_zero_point); vscaled6x0123456789ABCDEF = _mm512_min_ps(vscaled6x0123456789ABCDEF, voutput_max_less_zero_point); - vacc0x0x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled0x0123456789ABCDEF); - vacc0x1x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled1x0123456789ABCDEF); - vacc0x2x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled2x0123456789ABCDEF); - vacc0x3x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled3x0123456789ABCDEF); - vacc0x4x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled4x0123456789ABCDEF); - vacc0x5x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled5x0123456789ABCDEF); - vacc0x6x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled6x0123456789ABCDEF); + vacc0x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled0x0123456789ABCDEF); + vacc1x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled1x0123456789ABCDEF); + vacc2x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled2x0123456789ABCDEF); + vacc3x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled3x0123456789ABCDEF); + vacc4x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled4x0123456789ABCDEF); + vacc5x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled5x0123456789ABCDEF); + vacc6x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled6x0123456789ABCDEF); - __m256i vacc0x0x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0x0123456789ABCDEF, 1)); - __m256i vacc0x1x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x1x0123456789ABCDEF, 1)); - __m256i vacc0x2x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x2x0123456789ABCDEF, 1)); - __m256i vacc0x3x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x3x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x3x0123456789ABCDEF, 1)); - __m256i vacc0x4x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x4x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x4x0123456789ABCDEF, 1)); - __m256i vacc0x5x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x5x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x5x0123456789ABCDEF, 1)); - __m256i vacc0x6x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x6x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x6x0123456789ABCDEF, 1)); + __m256i vacc0x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); + __m256i vacc1x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); + __m256i vacc2x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc2x0123456789ABCDEF, 1)); + __m256i vacc3x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc3x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc3x0123456789ABCDEF, 1)); + __m256i vacc4x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc4x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc4x0123456789ABCDEF, 1)); + __m256i vacc5x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc5x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc5x0123456789ABCDEF, 1)); + __m256i vacc6x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc6x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc6x0123456789ABCDEF, 1)); - vacc0x0x012389AB4567CDEF = _mm256_adds_epi16(vacc0x0x012389AB4567CDEF, voutput_zero_point); - vacc0x1x012389AB4567CDEF = _mm256_adds_epi16(vacc0x1x012389AB4567CDEF, voutput_zero_point); - vacc0x2x012389AB4567CDEF = _mm256_adds_epi16(vacc0x2x012389AB4567CDEF, voutput_zero_point); - vacc0x3x012389AB4567CDEF = _mm256_adds_epi16(vacc0x3x012389AB4567CDEF, voutput_zero_point); - vacc0x4x012389AB4567CDEF = _mm256_adds_epi16(vacc0x4x012389AB4567CDEF, voutput_zero_point); - vacc0x5x012389AB4567CDEF = _mm256_adds_epi16(vacc0x5x012389AB4567CDEF, voutput_zero_point); - vacc0x6x012389AB4567CDEF = _mm256_adds_epi16(vacc0x6x012389AB4567CDEF, voutput_zero_point); + vacc0x012389AB4567CDEF = _mm256_adds_epi16(vacc0x012389AB4567CDEF, voutput_zero_point); + vacc1x012389AB4567CDEF = _mm256_adds_epi16(vacc1x012389AB4567CDEF, voutput_zero_point); + vacc2x012389AB4567CDEF = _mm256_adds_epi16(vacc2x012389AB4567CDEF, voutput_zero_point); + vacc3x012389AB4567CDEF = _mm256_adds_epi16(vacc3x012389AB4567CDEF, voutput_zero_point); + vacc4x012389AB4567CDEF = _mm256_adds_epi16(vacc4x012389AB4567CDEF, voutput_zero_point); + vacc5x012389AB4567CDEF = _mm256_adds_epi16(vacc5x012389AB4567CDEF, voutput_zero_point); + vacc6x012389AB4567CDEF = _mm256_adds_epi16(vacc6x012389AB4567CDEF, voutput_zero_point); - const __m128i vout0x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x0x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x0x012389AB4567CDEF, 1)); - const __m128i vout1x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x1x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x1x012389AB4567CDEF, 1)); - const __m128i vout2x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x2x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x2x012389AB4567CDEF, 1)); - const __m128i vout3x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x3x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x3x012389AB4567CDEF, 1)); - const __m128i vout4x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x4x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x4x012389AB4567CDEF, 1)); - const __m128i vout5x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x5x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x5x012389AB4567CDEF, 1)); - const __m128i vout6x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x6x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x6x012389AB4567CDEF, 1)); + const __m128i vout0x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x012389AB4567CDEF, 1)); + const __m128i vout1x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc1x012389AB4567CDEF), _mm256_extracti128_si256(vacc1x012389AB4567CDEF, 1)); + const __m128i vout2x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc2x012389AB4567CDEF), _mm256_extracti128_si256(vacc2x012389AB4567CDEF, 1)); + const __m128i vout3x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc3x012389AB4567CDEF), _mm256_extracti128_si256(vacc3x012389AB4567CDEF, 1)); + const __m128i vout4x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc4x012389AB4567CDEF), _mm256_extracti128_si256(vacc4x012389AB4567CDEF, 1)); + const __m128i vout5x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc5x012389AB4567CDEF), _mm256_extracti128_si256(vacc5x012389AB4567CDEF, 1)); + const __m128i vout6x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc6x012389AB4567CDEF), _mm256_extracti128_si256(vacc6x012389AB4567CDEF, 1)); - __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi8(vout0x012389AB4567CDEF, vshuffle_control_mask); - __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi8(vout1x012389AB4567CDEF, vshuffle_control_mask); - __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi8(vout2x012389AB4567CDEF, vshuffle_control_mask); - __m128i vout3x0123456789ABCDEF = _mm_shuffle_epi8(vout3x012389AB4567CDEF, vshuffle_control_mask); - __m128i vout4x0123456789ABCDEF = _mm_shuffle_epi8(vout4x012389AB4567CDEF, vshuffle_control_mask); - __m128i vout5x0123456789ABCDEF = _mm_shuffle_epi8(vout5x012389AB4567CDEF, vshuffle_control_mask); - __m128i vout6x0123456789ABCDEF = _mm_shuffle_epi8(vout6x012389AB4567CDEF, vshuffle_control_mask); + __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi32(vout0x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi32(vout1x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi32(vout2x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout3x0123456789ABCDEF = _mm_shuffle_epi32(vout3x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout4x0123456789ABCDEF = _mm_shuffle_epi32(vout4x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout5x0123456789ABCDEF = _mm_shuffle_epi32(vout5x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout6x0123456789ABCDEF = _mm_shuffle_epi32(vout6x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); vout0x0123456789ABCDEF = _mm_max_epi8(vout0x0123456789ABCDEF, voutput_min); vout1x0123456789ABCDEF = _mm_max_epi8(vout1x0123456789ABCDEF, voutput_min); @@ -259,28 +258,26 @@ void xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_7x16c4__avx512vnni( if (nc >= 16) { _mm_storeu_si128((__m128i*) c0, vout0x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c1, vout1x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c2, vout2x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c3, vout3x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c4, vout4x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c5, vout5x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c6, vout6x0123456789ABCDEF); - - a0 = (const int8_t*) ((uintptr_t) a0 - kc); - a1 = (const int8_t*) ((uintptr_t) a1 - kc); - a2 = (const int8_t*) ((uintptr_t) a2 - kc); - a3 = (const int8_t*) ((uintptr_t) a3 - kc); - a4 = (const int8_t*) ((uintptr_t) a4 - kc); - a5 = (const int8_t*) ((uintptr_t) a5 - kc); - a6 = (const int8_t*) ((uintptr_t) a6 - kc); - c0 = (int8_t*) ((uintptr_t) c0 + cn_stride); + a0 = (const int8_t*) ((uintptr_t) a0 - kc); + _mm_storeu_si128((__m128i*) c1, vout1x0123456789ABCDEF); c1 = (int8_t*) ((uintptr_t) c1 + cn_stride); + a1 = (const int8_t*) ((uintptr_t) a1 - kc); + _mm_storeu_si128((__m128i*) c2, vout2x0123456789ABCDEF); c2 = (int8_t*) ((uintptr_t) c2 + cn_stride); + a2 = (const int8_t*) ((uintptr_t) a2 - kc); + _mm_storeu_si128((__m128i*) c3, vout3x0123456789ABCDEF); c3 = (int8_t*) ((uintptr_t) c3 + cn_stride); + a3 = (const int8_t*) ((uintptr_t) a3 - kc); + _mm_storeu_si128((__m128i*) c4, vout4x0123456789ABCDEF); c4 = (int8_t*) ((uintptr_t) c4 + cn_stride); + a4 = (const int8_t*) ((uintptr_t) a4 - kc); + _mm_storeu_si128((__m128i*) c5, vout5x0123456789ABCDEF); c5 = (int8_t*) ((uintptr_t) c5 + cn_stride); + a5 = (const int8_t*) ((uintptr_t) a5 - kc); + _mm_storeu_si128((__m128i*) c6, vout6x0123456789ABCDEF); c6 = (int8_t*) ((uintptr_t) c6 + cn_stride); + a6 = (const int8_t*) ((uintptr_t) a6 - kc); nc -= 16; } else { diff --git a/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-7x16c8-minmax-fp32-avx512skx-prfm.c b/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-7x16c8-minmax-fp32-avx512skx-prfm.c index 3d4c9793c52..aa3ccbde45e 100644 --- a/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-7x16c8-minmax-fp32-avx512skx-prfm.c +++ b/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-7x16c8-minmax-fp32-avx512skx-prfm.c @@ -81,7 +81,6 @@ void xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_7x16c8__avx512skx_prfm( const __m512 voutput_max_less_zero_point = _mm512_load_ps(params->fp32_avx512.output_max_less_zero_point); const __m256i voutput_zero_point = _mm256_load_si256((const __m256i*) params->fp32_avx512.output_zero_point); const __m128i voutput_min = _mm_load_si128((const __m128i*) params->fp32_avx512.output_min); - const __m128i vshuffle_control_mask = _mm_set_epi8(15, 14, 13, 12, 7, 6, 5, 4, 11, 10, 9, 8, 3, 2, 1, 0); do { __m512i vacc0x0123 = _mm512_maskz_expandloadu_epi32(vbias_mask, w); @@ -242,37 +241,37 @@ void xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_7x16c8__avx512skx_prfm( vacc5x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled5x0123456789ABCDEF); vacc6x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled6x0123456789ABCDEF); - __m256i vacc0x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); - __m256i vacc1x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); - __m256i vacc2x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc2x0123456789ABCDEF, 1)); - __m256i vacc3x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc3x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc3x0123456789ABCDEF, 1)); - __m256i vacc4x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc4x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc4x0123456789ABCDEF, 1)); - __m256i vacc5x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc5x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc5x0123456789ABCDEF, 1)); - __m256i vacc6x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc6x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc6x0123456789ABCDEF, 1)); - - vacc0x0123456789AB4567CDEF = _mm256_adds_epi16(vacc0x0123456789AB4567CDEF, voutput_zero_point); - vacc1x0123456789AB4567CDEF = _mm256_adds_epi16(vacc1x0123456789AB4567CDEF, voutput_zero_point); - vacc2x0123456789AB4567CDEF = _mm256_adds_epi16(vacc2x0123456789AB4567CDEF, voutput_zero_point); - vacc3x0123456789AB4567CDEF = _mm256_adds_epi16(vacc3x0123456789AB4567CDEF, voutput_zero_point); - vacc4x0123456789AB4567CDEF = _mm256_adds_epi16(vacc4x0123456789AB4567CDEF, voutput_zero_point); - vacc5x0123456789AB4567CDEF = _mm256_adds_epi16(vacc5x0123456789AB4567CDEF, voutput_zero_point); - vacc6x0123456789AB4567CDEF = _mm256_adds_epi16(vacc6x0123456789AB4567CDEF, voutput_zero_point); - - const __m128i vout0x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc0x0123456789AB4567CDEF, 1)); - const __m128i vout1x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc1x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc1x0123456789AB4567CDEF, 1)); - const __m128i vout2x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc2x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc2x0123456789AB4567CDEF, 1)); - const __m128i vout3x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc3x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc3x0123456789AB4567CDEF, 1)); - const __m128i vout4x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc4x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc4x0123456789AB4567CDEF, 1)); - const __m128i vout5x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc5x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc5x0123456789AB4567CDEF, 1)); - const __m128i vout6x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc6x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc6x0123456789AB4567CDEF, 1)); - - __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi8(vout0x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi8(vout1x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi8(vout2x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout3x0123456789ABCDEF = _mm_shuffle_epi8(vout3x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout4x0123456789ABCDEF = _mm_shuffle_epi8(vout4x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout5x0123456789ABCDEF = _mm_shuffle_epi8(vout5x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout6x0123456789ABCDEF = _mm_shuffle_epi8(vout6x0123456789AB4567CDEF, vshuffle_control_mask); + __m256i vacc0x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); + __m256i vacc1x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); + __m256i vacc2x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc2x0123456789ABCDEF, 1)); + __m256i vacc3x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc3x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc3x0123456789ABCDEF, 1)); + __m256i vacc4x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc4x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc4x0123456789ABCDEF, 1)); + __m256i vacc5x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc5x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc5x0123456789ABCDEF, 1)); + __m256i vacc6x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc6x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc6x0123456789ABCDEF, 1)); + + vacc0x012389AB4567CDEF = _mm256_adds_epi16(vacc0x012389AB4567CDEF, voutput_zero_point); + vacc1x012389AB4567CDEF = _mm256_adds_epi16(vacc1x012389AB4567CDEF, voutput_zero_point); + vacc2x012389AB4567CDEF = _mm256_adds_epi16(vacc2x012389AB4567CDEF, voutput_zero_point); + vacc3x012389AB4567CDEF = _mm256_adds_epi16(vacc3x012389AB4567CDEF, voutput_zero_point); + vacc4x012389AB4567CDEF = _mm256_adds_epi16(vacc4x012389AB4567CDEF, voutput_zero_point); + vacc5x012389AB4567CDEF = _mm256_adds_epi16(vacc5x012389AB4567CDEF, voutput_zero_point); + vacc6x012389AB4567CDEF = _mm256_adds_epi16(vacc6x012389AB4567CDEF, voutput_zero_point); + + const __m128i vout0x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x012389AB4567CDEF, 1)); + const __m128i vout1x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc1x012389AB4567CDEF), _mm256_extracti128_si256(vacc1x012389AB4567CDEF, 1)); + const __m128i vout2x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc2x012389AB4567CDEF), _mm256_extracti128_si256(vacc2x012389AB4567CDEF, 1)); + const __m128i vout3x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc3x012389AB4567CDEF), _mm256_extracti128_si256(vacc3x012389AB4567CDEF, 1)); + const __m128i vout4x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc4x012389AB4567CDEF), _mm256_extracti128_si256(vacc4x012389AB4567CDEF, 1)); + const __m128i vout5x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc5x012389AB4567CDEF), _mm256_extracti128_si256(vacc5x012389AB4567CDEF, 1)); + const __m128i vout6x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc6x012389AB4567CDEF), _mm256_extracti128_si256(vacc6x012389AB4567CDEF, 1)); + + __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi32(vout0x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi32(vout1x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi32(vout2x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout3x0123456789ABCDEF = _mm_shuffle_epi32(vout3x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout4x0123456789ABCDEF = _mm_shuffle_epi32(vout4x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout5x0123456789ABCDEF = _mm_shuffle_epi32(vout5x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout6x0123456789ABCDEF = _mm_shuffle_epi32(vout6x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); vout0x0123456789ABCDEF = _mm_max_epi8(vout0x0123456789ABCDEF, voutput_min); vout1x0123456789ABCDEF = _mm_max_epi8(vout1x0123456789ABCDEF, voutput_min); @@ -284,28 +283,26 @@ void xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_7x16c8__avx512skx_prfm( if (nc >= 16) { _mm_storeu_si128((__m128i*) c0, vout0x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c1, vout1x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c2, vout2x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c3, vout3x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c4, vout4x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c5, vout5x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c6, vout6x0123456789ABCDEF); - - a0 = (const int8_t*) ((uintptr_t) a0 - kc); - a1 = (const int8_t*) ((uintptr_t) a1 - kc); - a2 = (const int8_t*) ((uintptr_t) a2 - kc); - a3 = (const int8_t*) ((uintptr_t) a3 - kc); - a4 = (const int8_t*) ((uintptr_t) a4 - kc); - a5 = (const int8_t*) ((uintptr_t) a5 - kc); - a6 = (const int8_t*) ((uintptr_t) a6 - kc); - c0 = (int8_t*) ((uintptr_t) c0 + cn_stride); + a0 = (const int8_t*) ((uintptr_t) a0 - kc); + _mm_storeu_si128((__m128i*) c1, vout1x0123456789ABCDEF); c1 = (int8_t*) ((uintptr_t) c1 + cn_stride); + a1 = (const int8_t*) ((uintptr_t) a1 - kc); + _mm_storeu_si128((__m128i*) c2, vout2x0123456789ABCDEF); c2 = (int8_t*) ((uintptr_t) c2 + cn_stride); + a2 = (const int8_t*) ((uintptr_t) a2 - kc); + _mm_storeu_si128((__m128i*) c3, vout3x0123456789ABCDEF); c3 = (int8_t*) ((uintptr_t) c3 + cn_stride); + a3 = (const int8_t*) ((uintptr_t) a3 - kc); + _mm_storeu_si128((__m128i*) c4, vout4x0123456789ABCDEF); c4 = (int8_t*) ((uintptr_t) c4 + cn_stride); + a4 = (const int8_t*) ((uintptr_t) a4 - kc); + _mm_storeu_si128((__m128i*) c5, vout5x0123456789ABCDEF); c5 = (int8_t*) ((uintptr_t) c5 + cn_stride); + a5 = (const int8_t*) ((uintptr_t) a5 - kc); + _mm_storeu_si128((__m128i*) c6, vout6x0123456789ABCDEF); c6 = (int8_t*) ((uintptr_t) c6 + cn_stride); + a6 = (const int8_t*) ((uintptr_t) a6 - kc); nc -= 16; } else { diff --git a/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-7x16c8-minmax-fp32-avx512skx.c b/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-7x16c8-minmax-fp32-avx512skx.c index 555b5e813f1..194316a63d5 100644 --- a/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-7x16c8-minmax-fp32-avx512skx.c +++ b/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-7x16c8-minmax-fp32-avx512skx.c @@ -80,7 +80,6 @@ void xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_7x16c8__avx512skx( const __m512 voutput_max_less_zero_point = _mm512_load_ps(params->fp32_avx512.output_max_less_zero_point); const __m256i voutput_zero_point = _mm256_load_si256((const __m256i*) params->fp32_avx512.output_zero_point); const __m128i voutput_min = _mm_load_si128((const __m128i*) params->fp32_avx512.output_min); - const __m128i vshuffle_control_mask = _mm_set_epi8(15, 14, 13, 12, 7, 6, 5, 4, 11, 10, 9, 8, 3, 2, 1, 0); do { __m512i vacc0x0123 = _mm512_maskz_expandloadu_epi32(vbias_mask, w); @@ -239,37 +238,37 @@ void xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_7x16c8__avx512skx( vacc5x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled5x0123456789ABCDEF); vacc6x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled6x0123456789ABCDEF); - __m256i vacc0x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); - __m256i vacc1x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); - __m256i vacc2x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc2x0123456789ABCDEF, 1)); - __m256i vacc3x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc3x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc3x0123456789ABCDEF, 1)); - __m256i vacc4x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc4x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc4x0123456789ABCDEF, 1)); - __m256i vacc5x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc5x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc5x0123456789ABCDEF, 1)); - __m256i vacc6x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc6x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc6x0123456789ABCDEF, 1)); - - vacc0x0123456789AB4567CDEF = _mm256_adds_epi16(vacc0x0123456789AB4567CDEF, voutput_zero_point); - vacc1x0123456789AB4567CDEF = _mm256_adds_epi16(vacc1x0123456789AB4567CDEF, voutput_zero_point); - vacc2x0123456789AB4567CDEF = _mm256_adds_epi16(vacc2x0123456789AB4567CDEF, voutput_zero_point); - vacc3x0123456789AB4567CDEF = _mm256_adds_epi16(vacc3x0123456789AB4567CDEF, voutput_zero_point); - vacc4x0123456789AB4567CDEF = _mm256_adds_epi16(vacc4x0123456789AB4567CDEF, voutput_zero_point); - vacc5x0123456789AB4567CDEF = _mm256_adds_epi16(vacc5x0123456789AB4567CDEF, voutput_zero_point); - vacc6x0123456789AB4567CDEF = _mm256_adds_epi16(vacc6x0123456789AB4567CDEF, voutput_zero_point); - - const __m128i vout0x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc0x0123456789AB4567CDEF, 1)); - const __m128i vout1x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc1x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc1x0123456789AB4567CDEF, 1)); - const __m128i vout2x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc2x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc2x0123456789AB4567CDEF, 1)); - const __m128i vout3x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc3x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc3x0123456789AB4567CDEF, 1)); - const __m128i vout4x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc4x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc4x0123456789AB4567CDEF, 1)); - const __m128i vout5x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc5x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc5x0123456789AB4567CDEF, 1)); - const __m128i vout6x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc6x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc6x0123456789AB4567CDEF, 1)); - - __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi8(vout0x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi8(vout1x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi8(vout2x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout3x0123456789ABCDEF = _mm_shuffle_epi8(vout3x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout4x0123456789ABCDEF = _mm_shuffle_epi8(vout4x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout5x0123456789ABCDEF = _mm_shuffle_epi8(vout5x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout6x0123456789ABCDEF = _mm_shuffle_epi8(vout6x0123456789AB4567CDEF, vshuffle_control_mask); + __m256i vacc0x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); + __m256i vacc1x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); + __m256i vacc2x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc2x0123456789ABCDEF, 1)); + __m256i vacc3x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc3x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc3x0123456789ABCDEF, 1)); + __m256i vacc4x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc4x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc4x0123456789ABCDEF, 1)); + __m256i vacc5x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc5x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc5x0123456789ABCDEF, 1)); + __m256i vacc6x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc6x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc6x0123456789ABCDEF, 1)); + + vacc0x012389AB4567CDEF = _mm256_adds_epi16(vacc0x012389AB4567CDEF, voutput_zero_point); + vacc1x012389AB4567CDEF = _mm256_adds_epi16(vacc1x012389AB4567CDEF, voutput_zero_point); + vacc2x012389AB4567CDEF = _mm256_adds_epi16(vacc2x012389AB4567CDEF, voutput_zero_point); + vacc3x012389AB4567CDEF = _mm256_adds_epi16(vacc3x012389AB4567CDEF, voutput_zero_point); + vacc4x012389AB4567CDEF = _mm256_adds_epi16(vacc4x012389AB4567CDEF, voutput_zero_point); + vacc5x012389AB4567CDEF = _mm256_adds_epi16(vacc5x012389AB4567CDEF, voutput_zero_point); + vacc6x012389AB4567CDEF = _mm256_adds_epi16(vacc6x012389AB4567CDEF, voutput_zero_point); + + const __m128i vout0x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x012389AB4567CDEF, 1)); + const __m128i vout1x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc1x012389AB4567CDEF), _mm256_extracti128_si256(vacc1x012389AB4567CDEF, 1)); + const __m128i vout2x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc2x012389AB4567CDEF), _mm256_extracti128_si256(vacc2x012389AB4567CDEF, 1)); + const __m128i vout3x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc3x012389AB4567CDEF), _mm256_extracti128_si256(vacc3x012389AB4567CDEF, 1)); + const __m128i vout4x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc4x012389AB4567CDEF), _mm256_extracti128_si256(vacc4x012389AB4567CDEF, 1)); + const __m128i vout5x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc5x012389AB4567CDEF), _mm256_extracti128_si256(vacc5x012389AB4567CDEF, 1)); + const __m128i vout6x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc6x012389AB4567CDEF), _mm256_extracti128_si256(vacc6x012389AB4567CDEF, 1)); + + __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi32(vout0x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi32(vout1x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi32(vout2x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout3x0123456789ABCDEF = _mm_shuffle_epi32(vout3x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout4x0123456789ABCDEF = _mm_shuffle_epi32(vout4x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout5x0123456789ABCDEF = _mm_shuffle_epi32(vout5x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout6x0123456789ABCDEF = _mm_shuffle_epi32(vout6x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); vout0x0123456789ABCDEF = _mm_max_epi8(vout0x0123456789ABCDEF, voutput_min); vout1x0123456789ABCDEF = _mm_max_epi8(vout1x0123456789ABCDEF, voutput_min); @@ -281,28 +280,26 @@ void xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_7x16c8__avx512skx( if (nc >= 16) { _mm_storeu_si128((__m128i*) c0, vout0x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c1, vout1x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c2, vout2x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c3, vout3x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c4, vout4x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c5, vout5x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c6, vout6x0123456789ABCDEF); - - a0 = (const int8_t*) ((uintptr_t) a0 - kc); - a1 = (const int8_t*) ((uintptr_t) a1 - kc); - a2 = (const int8_t*) ((uintptr_t) a2 - kc); - a3 = (const int8_t*) ((uintptr_t) a3 - kc); - a4 = (const int8_t*) ((uintptr_t) a4 - kc); - a5 = (const int8_t*) ((uintptr_t) a5 - kc); - a6 = (const int8_t*) ((uintptr_t) a6 - kc); - c0 = (int8_t*) ((uintptr_t) c0 + cn_stride); + a0 = (const int8_t*) ((uintptr_t) a0 - kc); + _mm_storeu_si128((__m128i*) c1, vout1x0123456789ABCDEF); c1 = (int8_t*) ((uintptr_t) c1 + cn_stride); + a1 = (const int8_t*) ((uintptr_t) a1 - kc); + _mm_storeu_si128((__m128i*) c2, vout2x0123456789ABCDEF); c2 = (int8_t*) ((uintptr_t) c2 + cn_stride); + a2 = (const int8_t*) ((uintptr_t) a2 - kc); + _mm_storeu_si128((__m128i*) c3, vout3x0123456789ABCDEF); c3 = (int8_t*) ((uintptr_t) c3 + cn_stride); + a3 = (const int8_t*) ((uintptr_t) a3 - kc); + _mm_storeu_si128((__m128i*) c4, vout4x0123456789ABCDEF); c4 = (int8_t*) ((uintptr_t) c4 + cn_stride); + a4 = (const int8_t*) ((uintptr_t) a4 - kc); + _mm_storeu_si128((__m128i*) c5, vout5x0123456789ABCDEF); c5 = (int8_t*) ((uintptr_t) c5 + cn_stride); + a5 = (const int8_t*) ((uintptr_t) a5 - kc); + _mm_storeu_si128((__m128i*) c6, vout6x0123456789ABCDEF); c6 = (int8_t*) ((uintptr_t) c6 + cn_stride); + a6 = (const int8_t*) ((uintptr_t) a6 - kc); nc -= 16; } else { diff --git a/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-7x16c8-minmax-fp32-avx512vnni-prfm.c b/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-7x16c8-minmax-fp32-avx512vnni-prfm.c index 37a48129ad2..8bbddaf5c66 100644 --- a/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-7x16c8-minmax-fp32-avx512vnni-prfm.c +++ b/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-7x16c8-minmax-fp32-avx512vnni-prfm.c @@ -83,7 +83,6 @@ void xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_7x16c8__avx512vnni_prfm( const __m512 voutput_max_less_zero_point = _mm512_load_ps(params->fp32_avx512vnni.output_max_less_zero_point); const __m256i voutput_zero_point = _mm256_load_si256((const __m256i*) params->fp32_avx512vnni.output_zero_point); const __m128i voutput_min = _mm_load_si128((const __m128i*) params->fp32_avx512vnni.output_min); - const __m128i vshuffle_control_mask = _mm_loadu_si128((const __m128i*) params->fp32_avx512vnni.shuffle_control_mask); do { __m512i vacc0x01234567 = _mm512_cvtepu32_epi64(_mm256_load_si256((const __m256i*) w)); __m512i vacc0x89ABCDEF = _mm512_cvtepu32_epi64(_mm256_load_si256((const __m256i*) ((const int32_t*) w + 8))); @@ -265,37 +264,37 @@ void xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_7x16c8__avx512vnni_prfm( vacc5x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled5x0123456789ABCDEF); vacc6x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled6x0123456789ABCDEF); - __m256i vacc0x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); - __m256i vacc1x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); - __m256i vacc2x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc2x0123456789ABCDEF, 1)); - __m256i vacc3x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc3x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc3x0123456789ABCDEF, 1)); - __m256i vacc4x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc4x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc4x0123456789ABCDEF, 1)); - __m256i vacc5x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc5x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc5x0123456789ABCDEF, 1)); - __m256i vacc6x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc6x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc6x0123456789ABCDEF, 1)); + __m256i vacc0x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); + __m256i vacc1x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); + __m256i vacc2x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc2x0123456789ABCDEF, 1)); + __m256i vacc3x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc3x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc3x0123456789ABCDEF, 1)); + __m256i vacc4x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc4x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc4x0123456789ABCDEF, 1)); + __m256i vacc5x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc5x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc5x0123456789ABCDEF, 1)); + __m256i vacc6x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc6x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc6x0123456789ABCDEF, 1)); - vacc0x0123456789AB4567CDEF = _mm256_adds_epi16(vacc0x0123456789AB4567CDEF, voutput_zero_point); - vacc1x0123456789AB4567CDEF = _mm256_adds_epi16(vacc1x0123456789AB4567CDEF, voutput_zero_point); - vacc2x0123456789AB4567CDEF = _mm256_adds_epi16(vacc2x0123456789AB4567CDEF, voutput_zero_point); - vacc3x0123456789AB4567CDEF = _mm256_adds_epi16(vacc3x0123456789AB4567CDEF, voutput_zero_point); - vacc4x0123456789AB4567CDEF = _mm256_adds_epi16(vacc4x0123456789AB4567CDEF, voutput_zero_point); - vacc5x0123456789AB4567CDEF = _mm256_adds_epi16(vacc5x0123456789AB4567CDEF, voutput_zero_point); - vacc6x0123456789AB4567CDEF = _mm256_adds_epi16(vacc6x0123456789AB4567CDEF, voutput_zero_point); + vacc0x012389AB4567CDEF = _mm256_adds_epi16(vacc0x012389AB4567CDEF, voutput_zero_point); + vacc1x012389AB4567CDEF = _mm256_adds_epi16(vacc1x012389AB4567CDEF, voutput_zero_point); + vacc2x012389AB4567CDEF = _mm256_adds_epi16(vacc2x012389AB4567CDEF, voutput_zero_point); + vacc3x012389AB4567CDEF = _mm256_adds_epi16(vacc3x012389AB4567CDEF, voutput_zero_point); + vacc4x012389AB4567CDEF = _mm256_adds_epi16(vacc4x012389AB4567CDEF, voutput_zero_point); + vacc5x012389AB4567CDEF = _mm256_adds_epi16(vacc5x012389AB4567CDEF, voutput_zero_point); + vacc6x012389AB4567CDEF = _mm256_adds_epi16(vacc6x012389AB4567CDEF, voutput_zero_point); - const __m128i vout0x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc0x0123456789AB4567CDEF, 1)); - const __m128i vout1x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc1x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc1x0123456789AB4567CDEF, 1)); - const __m128i vout2x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc2x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc2x0123456789AB4567CDEF, 1)); - const __m128i vout3x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc3x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc3x0123456789AB4567CDEF, 1)); - const __m128i vout4x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc4x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc4x0123456789AB4567CDEF, 1)); - const __m128i vout5x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc5x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc5x0123456789AB4567CDEF, 1)); - const __m128i vout6x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc6x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc6x0123456789AB4567CDEF, 1)); + const __m128i vout0x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x012389AB4567CDEF, 1)); + const __m128i vout1x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc1x012389AB4567CDEF), _mm256_extracti128_si256(vacc1x012389AB4567CDEF, 1)); + const __m128i vout2x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc2x012389AB4567CDEF), _mm256_extracti128_si256(vacc2x012389AB4567CDEF, 1)); + const __m128i vout3x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc3x012389AB4567CDEF), _mm256_extracti128_si256(vacc3x012389AB4567CDEF, 1)); + const __m128i vout4x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc4x012389AB4567CDEF), _mm256_extracti128_si256(vacc4x012389AB4567CDEF, 1)); + const __m128i vout5x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc5x012389AB4567CDEF), _mm256_extracti128_si256(vacc5x012389AB4567CDEF, 1)); + const __m128i vout6x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc6x012389AB4567CDEF), _mm256_extracti128_si256(vacc6x012389AB4567CDEF, 1)); - __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi8(vout0x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi8(vout1x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi8(vout2x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout3x0123456789ABCDEF = _mm_shuffle_epi8(vout3x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout4x0123456789ABCDEF = _mm_shuffle_epi8(vout4x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout5x0123456789ABCDEF = _mm_shuffle_epi8(vout5x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout6x0123456789ABCDEF = _mm_shuffle_epi8(vout6x0123456789AB4567CDEF, vshuffle_control_mask); + __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi32(vout0x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi32(vout1x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi32(vout2x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout3x0123456789ABCDEF = _mm_shuffle_epi32(vout3x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout4x0123456789ABCDEF = _mm_shuffle_epi32(vout4x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout5x0123456789ABCDEF = _mm_shuffle_epi32(vout5x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout6x0123456789ABCDEF = _mm_shuffle_epi32(vout6x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); vout0x0123456789ABCDEF = _mm_max_epi8(vout0x0123456789ABCDEF, voutput_min); vout1x0123456789ABCDEF = _mm_max_epi8(vout1x0123456789ABCDEF, voutput_min); @@ -307,28 +306,26 @@ void xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_7x16c8__avx512vnni_prfm( if (nc >= 16) { _mm_storeu_si128((__m128i*) c0, vout0x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c1, vout1x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c2, vout2x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c3, vout3x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c4, vout4x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c5, vout5x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c6, vout6x0123456789ABCDEF); - - a0 = (const int8_t*) ((uintptr_t) a0 - kc); - a1 = (const int8_t*) ((uintptr_t) a1 - kc); - a2 = (const int8_t*) ((uintptr_t) a2 - kc); - a3 = (const int8_t*) ((uintptr_t) a3 - kc); - a4 = (const int8_t*) ((uintptr_t) a4 - kc); - a5 = (const int8_t*) ((uintptr_t) a5 - kc); - a6 = (const int8_t*) ((uintptr_t) a6 - kc); - c0 = (int8_t*) ((uintptr_t) c0 + cn_stride); + a0 = (const int8_t*) ((uintptr_t) a0 - kc); + _mm_storeu_si128((__m128i*) c1, vout1x0123456789ABCDEF); c1 = (int8_t*) ((uintptr_t) c1 + cn_stride); + a1 = (const int8_t*) ((uintptr_t) a1 - kc); + _mm_storeu_si128((__m128i*) c2, vout2x0123456789ABCDEF); c2 = (int8_t*) ((uintptr_t) c2 + cn_stride); + a2 = (const int8_t*) ((uintptr_t) a2 - kc); + _mm_storeu_si128((__m128i*) c3, vout3x0123456789ABCDEF); c3 = (int8_t*) ((uintptr_t) c3 + cn_stride); + a3 = (const int8_t*) ((uintptr_t) a3 - kc); + _mm_storeu_si128((__m128i*) c4, vout4x0123456789ABCDEF); c4 = (int8_t*) ((uintptr_t) c4 + cn_stride); + a4 = (const int8_t*) ((uintptr_t) a4 - kc); + _mm_storeu_si128((__m128i*) c5, vout5x0123456789ABCDEF); c5 = (int8_t*) ((uintptr_t) c5 + cn_stride); + a5 = (const int8_t*) ((uintptr_t) a5 - kc); + _mm_storeu_si128((__m128i*) c6, vout6x0123456789ABCDEF); c6 = (int8_t*) ((uintptr_t) c6 + cn_stride); + a6 = (const int8_t*) ((uintptr_t) a6 - kc); nc -= 16; } else { diff --git a/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-7x16c8-minmax-fp32-avx512vnni.c b/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-7x16c8-minmax-fp32-avx512vnni.c index 2ab75f72cdb..799c3ec0ce3 100644 --- a/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-7x16c8-minmax-fp32-avx512vnni.c +++ b/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-7x16c8-minmax-fp32-avx512vnni.c @@ -82,7 +82,6 @@ void xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_7x16c8__avx512vnni( const __m512 voutput_max_less_zero_point = _mm512_load_ps(params->fp32_avx512vnni.output_max_less_zero_point); const __m256i voutput_zero_point = _mm256_load_si256((const __m256i*) params->fp32_avx512vnni.output_zero_point); const __m128i voutput_min = _mm_load_si128((const __m128i*) params->fp32_avx512vnni.output_min); - const __m128i vshuffle_control_mask = _mm_loadu_si128((const __m128i*) params->fp32_avx512vnni.shuffle_control_mask); do { __m512i vacc0x01234567 = _mm512_cvtepu32_epi64(_mm256_load_si256((const __m256i*) w)); __m512i vacc0x89ABCDEF = _mm512_cvtepu32_epi64(_mm256_load_si256((const __m256i*) ((const int32_t*) w + 8))); @@ -258,37 +257,37 @@ void xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_7x16c8__avx512vnni( vacc5x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled5x0123456789ABCDEF); vacc6x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled6x0123456789ABCDEF); - __m256i vacc0x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); - __m256i vacc1x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); - __m256i vacc2x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc2x0123456789ABCDEF, 1)); - __m256i vacc3x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc3x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc3x0123456789ABCDEF, 1)); - __m256i vacc4x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc4x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc4x0123456789ABCDEF, 1)); - __m256i vacc5x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc5x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc5x0123456789ABCDEF, 1)); - __m256i vacc6x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc6x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc6x0123456789ABCDEF, 1)); + __m256i vacc0x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); + __m256i vacc1x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); + __m256i vacc2x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc2x0123456789ABCDEF, 1)); + __m256i vacc3x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc3x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc3x0123456789ABCDEF, 1)); + __m256i vacc4x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc4x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc4x0123456789ABCDEF, 1)); + __m256i vacc5x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc5x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc5x0123456789ABCDEF, 1)); + __m256i vacc6x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc6x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc6x0123456789ABCDEF, 1)); - vacc0x0123456789AB4567CDEF = _mm256_adds_epi16(vacc0x0123456789AB4567CDEF, voutput_zero_point); - vacc1x0123456789AB4567CDEF = _mm256_adds_epi16(vacc1x0123456789AB4567CDEF, voutput_zero_point); - vacc2x0123456789AB4567CDEF = _mm256_adds_epi16(vacc2x0123456789AB4567CDEF, voutput_zero_point); - vacc3x0123456789AB4567CDEF = _mm256_adds_epi16(vacc3x0123456789AB4567CDEF, voutput_zero_point); - vacc4x0123456789AB4567CDEF = _mm256_adds_epi16(vacc4x0123456789AB4567CDEF, voutput_zero_point); - vacc5x0123456789AB4567CDEF = _mm256_adds_epi16(vacc5x0123456789AB4567CDEF, voutput_zero_point); - vacc6x0123456789AB4567CDEF = _mm256_adds_epi16(vacc6x0123456789AB4567CDEF, voutput_zero_point); + vacc0x012389AB4567CDEF = _mm256_adds_epi16(vacc0x012389AB4567CDEF, voutput_zero_point); + vacc1x012389AB4567CDEF = _mm256_adds_epi16(vacc1x012389AB4567CDEF, voutput_zero_point); + vacc2x012389AB4567CDEF = _mm256_adds_epi16(vacc2x012389AB4567CDEF, voutput_zero_point); + vacc3x012389AB4567CDEF = _mm256_adds_epi16(vacc3x012389AB4567CDEF, voutput_zero_point); + vacc4x012389AB4567CDEF = _mm256_adds_epi16(vacc4x012389AB4567CDEF, voutput_zero_point); + vacc5x012389AB4567CDEF = _mm256_adds_epi16(vacc5x012389AB4567CDEF, voutput_zero_point); + vacc6x012389AB4567CDEF = _mm256_adds_epi16(vacc6x012389AB4567CDEF, voutput_zero_point); - const __m128i vout0x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc0x0123456789AB4567CDEF, 1)); - const __m128i vout1x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc1x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc1x0123456789AB4567CDEF, 1)); - const __m128i vout2x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc2x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc2x0123456789AB4567CDEF, 1)); - const __m128i vout3x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc3x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc3x0123456789AB4567CDEF, 1)); - const __m128i vout4x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc4x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc4x0123456789AB4567CDEF, 1)); - const __m128i vout5x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc5x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc5x0123456789AB4567CDEF, 1)); - const __m128i vout6x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc6x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc6x0123456789AB4567CDEF, 1)); + const __m128i vout0x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x012389AB4567CDEF, 1)); + const __m128i vout1x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc1x012389AB4567CDEF), _mm256_extracti128_si256(vacc1x012389AB4567CDEF, 1)); + const __m128i vout2x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc2x012389AB4567CDEF), _mm256_extracti128_si256(vacc2x012389AB4567CDEF, 1)); + const __m128i vout3x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc3x012389AB4567CDEF), _mm256_extracti128_si256(vacc3x012389AB4567CDEF, 1)); + const __m128i vout4x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc4x012389AB4567CDEF), _mm256_extracti128_si256(vacc4x012389AB4567CDEF, 1)); + const __m128i vout5x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc5x012389AB4567CDEF), _mm256_extracti128_si256(vacc5x012389AB4567CDEF, 1)); + const __m128i vout6x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc6x012389AB4567CDEF), _mm256_extracti128_si256(vacc6x012389AB4567CDEF, 1)); - __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi8(vout0x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi8(vout1x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi8(vout2x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout3x0123456789ABCDEF = _mm_shuffle_epi8(vout3x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout4x0123456789ABCDEF = _mm_shuffle_epi8(vout4x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout5x0123456789ABCDEF = _mm_shuffle_epi8(vout5x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout6x0123456789ABCDEF = _mm_shuffle_epi8(vout6x0123456789AB4567CDEF, vshuffle_control_mask); + __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi32(vout0x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi32(vout1x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi32(vout2x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout3x0123456789ABCDEF = _mm_shuffle_epi32(vout3x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout4x0123456789ABCDEF = _mm_shuffle_epi32(vout4x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout5x0123456789ABCDEF = _mm_shuffle_epi32(vout5x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout6x0123456789ABCDEF = _mm_shuffle_epi32(vout6x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); vout0x0123456789ABCDEF = _mm_max_epi8(vout0x0123456789ABCDEF, voutput_min); vout1x0123456789ABCDEF = _mm_max_epi8(vout1x0123456789ABCDEF, voutput_min); @@ -300,28 +299,26 @@ void xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_7x16c8__avx512vnni( if (nc >= 16) { _mm_storeu_si128((__m128i*) c0, vout0x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c1, vout1x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c2, vout2x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c3, vout3x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c4, vout4x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c5, vout5x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c6, vout6x0123456789ABCDEF); - - a0 = (const int8_t*) ((uintptr_t) a0 - kc); - a1 = (const int8_t*) ((uintptr_t) a1 - kc); - a2 = (const int8_t*) ((uintptr_t) a2 - kc); - a3 = (const int8_t*) ((uintptr_t) a3 - kc); - a4 = (const int8_t*) ((uintptr_t) a4 - kc); - a5 = (const int8_t*) ((uintptr_t) a5 - kc); - a6 = (const int8_t*) ((uintptr_t) a6 - kc); - c0 = (int8_t*) ((uintptr_t) c0 + cn_stride); + a0 = (const int8_t*) ((uintptr_t) a0 - kc); + _mm_storeu_si128((__m128i*) c1, vout1x0123456789ABCDEF); c1 = (int8_t*) ((uintptr_t) c1 + cn_stride); + a1 = (const int8_t*) ((uintptr_t) a1 - kc); + _mm_storeu_si128((__m128i*) c2, vout2x0123456789ABCDEF); c2 = (int8_t*) ((uintptr_t) c2 + cn_stride); + a2 = (const int8_t*) ((uintptr_t) a2 - kc); + _mm_storeu_si128((__m128i*) c3, vout3x0123456789ABCDEF); c3 = (int8_t*) ((uintptr_t) c3 + cn_stride); + a3 = (const int8_t*) ((uintptr_t) a3 - kc); + _mm_storeu_si128((__m128i*) c4, vout4x0123456789ABCDEF); c4 = (int8_t*) ((uintptr_t) c4 + cn_stride); + a4 = (const int8_t*) ((uintptr_t) a4 - kc); + _mm_storeu_si128((__m128i*) c5, vout5x0123456789ABCDEF); c5 = (int8_t*) ((uintptr_t) c5 + cn_stride); + a5 = (const int8_t*) ((uintptr_t) a5 - kc); + _mm_storeu_si128((__m128i*) c6, vout6x0123456789ABCDEF); c6 = (int8_t*) ((uintptr_t) c6 + cn_stride); + a6 = (const int8_t*) ((uintptr_t) a6 - kc); nc -= 16; } else { diff --git a/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-8x16c4-minmax-fp32-avx512vnni-prfm.c b/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-8x16c4-minmax-fp32-avx512vnni-prfm.c index bc5600b685d..508fa7edba7 100644 --- a/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-8x16c4-minmax-fp32-avx512vnni-prfm.c +++ b/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-8x16c4-minmax-fp32-avx512vnni-prfm.c @@ -89,7 +89,6 @@ void xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_8x16c4__avx512vnni_prfm( const __m512 voutput_max_less_zero_point = _mm512_load_ps(params->fp32_avx512vnni.output_max_less_zero_point); const __m256i voutput_zero_point = _mm256_load_si256((const __m256i*) params->fp32_avx512vnni.output_zero_point); const __m128i voutput_min = _mm_load_si128((const __m128i*) params->fp32_avx512vnni.output_min); - const __m128i vshuffle_control_mask = _mm_loadu_si128((const __m128i*) params->fp32_avx512vnni.shuffle_control_mask); do { __m512i vacc0x0x0123456789ABCDEF = _mm512_load_epi32(w); __m512i vacc1x0x0123456789ABCDEF = _mm512_setzero_epi32(); @@ -161,14 +160,14 @@ void xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_8x16c4__avx512vnni_prfm( w = (const int8_t*) w + 128; k -= 8 * sizeof(int8_t); } - vacc0x0x0123456789ABCDEF = _mm512_add_epi32(vacc0x0x0123456789ABCDEF, vacc1x0x0123456789ABCDEF); - vacc0x1x0123456789ABCDEF = _mm512_add_epi32(vacc0x1x0123456789ABCDEF, vacc1x1x0123456789ABCDEF); - vacc0x2x0123456789ABCDEF = _mm512_add_epi32(vacc0x2x0123456789ABCDEF, vacc1x2x0123456789ABCDEF); - vacc0x3x0123456789ABCDEF = _mm512_add_epi32(vacc0x3x0123456789ABCDEF, vacc1x3x0123456789ABCDEF); - vacc0x4x0123456789ABCDEF = _mm512_add_epi32(vacc0x4x0123456789ABCDEF, vacc1x4x0123456789ABCDEF); - vacc0x5x0123456789ABCDEF = _mm512_add_epi32(vacc0x5x0123456789ABCDEF, vacc1x5x0123456789ABCDEF); - vacc0x6x0123456789ABCDEF = _mm512_add_epi32(vacc0x6x0123456789ABCDEF, vacc1x6x0123456789ABCDEF); - vacc0x7x0123456789ABCDEF = _mm512_add_epi32(vacc0x7x0123456789ABCDEF, vacc1x7x0123456789ABCDEF); + __m512i vacc0x0123456789ABCDEF = _mm512_add_epi32(vacc0x0x0123456789ABCDEF, vacc1x0x0123456789ABCDEF); + __m512i vacc1x0123456789ABCDEF = _mm512_add_epi32(vacc0x1x0123456789ABCDEF, vacc1x1x0123456789ABCDEF); + __m512i vacc2x0123456789ABCDEF = _mm512_add_epi32(vacc0x2x0123456789ABCDEF, vacc1x2x0123456789ABCDEF); + __m512i vacc3x0123456789ABCDEF = _mm512_add_epi32(vacc0x3x0123456789ABCDEF, vacc1x3x0123456789ABCDEF); + __m512i vacc4x0123456789ABCDEF = _mm512_add_epi32(vacc0x4x0123456789ABCDEF, vacc1x4x0123456789ABCDEF); + __m512i vacc5x0123456789ABCDEF = _mm512_add_epi32(vacc0x5x0123456789ABCDEF, vacc1x5x0123456789ABCDEF); + __m512i vacc6x0123456789ABCDEF = _mm512_add_epi32(vacc0x6x0123456789ABCDEF, vacc1x6x0123456789ABCDEF); + __m512i vacc7x0123456789ABCDEF = _mm512_add_epi32(vacc0x7x0123456789ABCDEF, vacc1x7x0123456789ABCDEF); if (k != 0) { const __m512i va0x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a0)), vsign_mask); @@ -191,27 +190,27 @@ void xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_8x16c4__avx512vnni_prfm( const __m512i vb0123456789ABCDEF = _mm512_load_si512(w); xnn_prefetch_to_l1((const int8_t*) w + 960); - vacc0x0x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x0x0123456789ABCDEF, va0x0123, vb0123456789ABCDEF); - vacc0x1x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x1x0123456789ABCDEF, va1x0123, vb0123456789ABCDEF); - vacc0x2x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x2x0123456789ABCDEF, va2x0123, vb0123456789ABCDEF); - vacc0x3x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x3x0123456789ABCDEF, va3x0123, vb0123456789ABCDEF); - vacc0x4x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x4x0123456789ABCDEF, va4x0123, vb0123456789ABCDEF); - vacc0x5x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x5x0123456789ABCDEF, va5x0123, vb0123456789ABCDEF); - vacc0x6x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x6x0123456789ABCDEF, va6x0123, vb0123456789ABCDEF); - vacc0x7x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x7x0123456789ABCDEF, va7x0123, vb0123456789ABCDEF); + vacc0x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x0123456789ABCDEF, va0x0123, vb0123456789ABCDEF); + vacc1x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc1x0123456789ABCDEF, va1x0123, vb0123456789ABCDEF); + vacc2x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc2x0123456789ABCDEF, va2x0123, vb0123456789ABCDEF); + vacc3x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc3x0123456789ABCDEF, va3x0123, vb0123456789ABCDEF); + vacc4x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc4x0123456789ABCDEF, va4x0123, vb0123456789ABCDEF); + vacc5x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc5x0123456789ABCDEF, va5x0123, vb0123456789ABCDEF); + vacc6x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc6x0123456789ABCDEF, va6x0123, vb0123456789ABCDEF); + vacc7x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc7x0123456789ABCDEF, va7x0123, vb0123456789ABCDEF); w = (const int8_t*) w + 64; k -= 4 * sizeof(int8_t); } - __m512 vscaled0x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x0x0123456789ABCDEF); - __m512 vscaled1x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x1x0123456789ABCDEF); - __m512 vscaled2x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x2x0123456789ABCDEF); - __m512 vscaled3x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x3x0123456789ABCDEF); - __m512 vscaled4x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x4x0123456789ABCDEF); - __m512 vscaled5x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x5x0123456789ABCDEF); - __m512 vscaled6x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x6x0123456789ABCDEF); - __m512 vscaled7x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x7x0123456789ABCDEF); + __m512 vscaled0x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x0123456789ABCDEF); + __m512 vscaled1x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc1x0123456789ABCDEF); + __m512 vscaled2x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc2x0123456789ABCDEF); + __m512 vscaled3x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc3x0123456789ABCDEF); + __m512 vscaled4x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc4x0123456789ABCDEF); + __m512 vscaled5x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc5x0123456789ABCDEF); + __m512 vscaled6x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc6x0123456789ABCDEF); + __m512 vscaled7x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc7x0123456789ABCDEF); const __m512 vscale012345678ABCDEF = _mm512_load_ps(w); w = (const float*) w + 16; @@ -233,50 +232,50 @@ void xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_8x16c4__avx512vnni_prfm( vscaled6x0123456789ABCDEF = _mm512_min_ps(vscaled6x0123456789ABCDEF, voutput_max_less_zero_point); vscaled7x0123456789ABCDEF = _mm512_min_ps(vscaled7x0123456789ABCDEF, voutput_max_less_zero_point); - vacc0x0x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled0x0123456789ABCDEF); - vacc0x1x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled1x0123456789ABCDEF); - vacc0x2x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled2x0123456789ABCDEF); - vacc0x3x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled3x0123456789ABCDEF); - vacc0x4x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled4x0123456789ABCDEF); - vacc0x5x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled5x0123456789ABCDEF); - vacc0x6x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled6x0123456789ABCDEF); - vacc0x7x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled7x0123456789ABCDEF); + vacc0x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled0x0123456789ABCDEF); + vacc1x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled1x0123456789ABCDEF); + vacc2x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled2x0123456789ABCDEF); + vacc3x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled3x0123456789ABCDEF); + vacc4x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled4x0123456789ABCDEF); + vacc5x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled5x0123456789ABCDEF); + vacc6x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled6x0123456789ABCDEF); + vacc7x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled7x0123456789ABCDEF); - __m256i vacc0x0x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0x0123456789ABCDEF, 1)); - __m256i vacc0x1x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x1x0123456789ABCDEF, 1)); - __m256i vacc0x2x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x2x0123456789ABCDEF, 1)); - __m256i vacc0x3x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x3x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x3x0123456789ABCDEF, 1)); - __m256i vacc0x4x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x4x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x4x0123456789ABCDEF, 1)); - __m256i vacc0x5x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x5x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x5x0123456789ABCDEF, 1)); - __m256i vacc0x6x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x6x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x6x0123456789ABCDEF, 1)); - __m256i vacc0x7x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x7x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x7x0123456789ABCDEF, 1)); + __m256i vacc0x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); + __m256i vacc1x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); + __m256i vacc2x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc2x0123456789ABCDEF, 1)); + __m256i vacc3x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc3x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc3x0123456789ABCDEF, 1)); + __m256i vacc4x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc4x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc4x0123456789ABCDEF, 1)); + __m256i vacc5x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc5x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc5x0123456789ABCDEF, 1)); + __m256i vacc6x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc6x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc6x0123456789ABCDEF, 1)); + __m256i vacc7x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc7x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc7x0123456789ABCDEF, 1)); - vacc0x0x012389AB4567CDEF = _mm256_adds_epi16(vacc0x0x012389AB4567CDEF, voutput_zero_point); - vacc0x1x012389AB4567CDEF = _mm256_adds_epi16(vacc0x1x012389AB4567CDEF, voutput_zero_point); - vacc0x2x012389AB4567CDEF = _mm256_adds_epi16(vacc0x2x012389AB4567CDEF, voutput_zero_point); - vacc0x3x012389AB4567CDEF = _mm256_adds_epi16(vacc0x3x012389AB4567CDEF, voutput_zero_point); - vacc0x4x012389AB4567CDEF = _mm256_adds_epi16(vacc0x4x012389AB4567CDEF, voutput_zero_point); - vacc0x5x012389AB4567CDEF = _mm256_adds_epi16(vacc0x5x012389AB4567CDEF, voutput_zero_point); - vacc0x6x012389AB4567CDEF = _mm256_adds_epi16(vacc0x6x012389AB4567CDEF, voutput_zero_point); - vacc0x7x012389AB4567CDEF = _mm256_adds_epi16(vacc0x7x012389AB4567CDEF, voutput_zero_point); + vacc0x012389AB4567CDEF = _mm256_adds_epi16(vacc0x012389AB4567CDEF, voutput_zero_point); + vacc1x012389AB4567CDEF = _mm256_adds_epi16(vacc1x012389AB4567CDEF, voutput_zero_point); + vacc2x012389AB4567CDEF = _mm256_adds_epi16(vacc2x012389AB4567CDEF, voutput_zero_point); + vacc3x012389AB4567CDEF = _mm256_adds_epi16(vacc3x012389AB4567CDEF, voutput_zero_point); + vacc4x012389AB4567CDEF = _mm256_adds_epi16(vacc4x012389AB4567CDEF, voutput_zero_point); + vacc5x012389AB4567CDEF = _mm256_adds_epi16(vacc5x012389AB4567CDEF, voutput_zero_point); + vacc6x012389AB4567CDEF = _mm256_adds_epi16(vacc6x012389AB4567CDEF, voutput_zero_point); + vacc7x012389AB4567CDEF = _mm256_adds_epi16(vacc7x012389AB4567CDEF, voutput_zero_point); - const __m128i vout0x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x0x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x0x012389AB4567CDEF, 1)); - const __m128i vout1x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x1x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x1x012389AB4567CDEF, 1)); - const __m128i vout2x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x2x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x2x012389AB4567CDEF, 1)); - const __m128i vout3x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x3x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x3x012389AB4567CDEF, 1)); - const __m128i vout4x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x4x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x4x012389AB4567CDEF, 1)); - const __m128i vout5x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x5x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x5x012389AB4567CDEF, 1)); - const __m128i vout6x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x6x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x6x012389AB4567CDEF, 1)); - const __m128i vout7x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x7x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x7x012389AB4567CDEF, 1)); + const __m128i vout0x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x012389AB4567CDEF, 1)); + const __m128i vout1x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc1x012389AB4567CDEF), _mm256_extracti128_si256(vacc1x012389AB4567CDEF, 1)); + const __m128i vout2x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc2x012389AB4567CDEF), _mm256_extracti128_si256(vacc2x012389AB4567CDEF, 1)); + const __m128i vout3x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc3x012389AB4567CDEF), _mm256_extracti128_si256(vacc3x012389AB4567CDEF, 1)); + const __m128i vout4x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc4x012389AB4567CDEF), _mm256_extracti128_si256(vacc4x012389AB4567CDEF, 1)); + const __m128i vout5x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc5x012389AB4567CDEF), _mm256_extracti128_si256(vacc5x012389AB4567CDEF, 1)); + const __m128i vout6x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc6x012389AB4567CDEF), _mm256_extracti128_si256(vacc6x012389AB4567CDEF, 1)); + const __m128i vout7x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc7x012389AB4567CDEF), _mm256_extracti128_si256(vacc7x012389AB4567CDEF, 1)); - __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi8(vout0x012389AB4567CDEF, vshuffle_control_mask); - __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi8(vout1x012389AB4567CDEF, vshuffle_control_mask); - __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi8(vout2x012389AB4567CDEF, vshuffle_control_mask); - __m128i vout3x0123456789ABCDEF = _mm_shuffle_epi8(vout3x012389AB4567CDEF, vshuffle_control_mask); - __m128i vout4x0123456789ABCDEF = _mm_shuffle_epi8(vout4x012389AB4567CDEF, vshuffle_control_mask); - __m128i vout5x0123456789ABCDEF = _mm_shuffle_epi8(vout5x012389AB4567CDEF, vshuffle_control_mask); - __m128i vout6x0123456789ABCDEF = _mm_shuffle_epi8(vout6x012389AB4567CDEF, vshuffle_control_mask); - __m128i vout7x0123456789ABCDEF = _mm_shuffle_epi8(vout7x012389AB4567CDEF, vshuffle_control_mask); + __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi32(vout0x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi32(vout1x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi32(vout2x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout3x0123456789ABCDEF = _mm_shuffle_epi32(vout3x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout4x0123456789ABCDEF = _mm_shuffle_epi32(vout4x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout5x0123456789ABCDEF = _mm_shuffle_epi32(vout5x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout6x0123456789ABCDEF = _mm_shuffle_epi32(vout6x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout7x0123456789ABCDEF = _mm_shuffle_epi32(vout7x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); vout0x0123456789ABCDEF = _mm_max_epi8(vout0x0123456789ABCDEF, voutput_min); vout1x0123456789ABCDEF = _mm_max_epi8(vout1x0123456789ABCDEF, voutput_min); @@ -289,31 +288,29 @@ void xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_8x16c4__avx512vnni_prfm( if (nc >= 16) { _mm_storeu_si128((__m128i*) c0, vout0x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c1, vout1x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c2, vout2x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c3, vout3x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c4, vout4x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c5, vout5x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c6, vout6x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c7, vout7x0123456789ABCDEF); - - a0 = (const int8_t*) ((uintptr_t) a0 - kc); - a1 = (const int8_t*) ((uintptr_t) a1 - kc); - a2 = (const int8_t*) ((uintptr_t) a2 - kc); - a3 = (const int8_t*) ((uintptr_t) a3 - kc); - a4 = (const int8_t*) ((uintptr_t) a4 - kc); - a5 = (const int8_t*) ((uintptr_t) a5 - kc); - a6 = (const int8_t*) ((uintptr_t) a6 - kc); - a7 = (const int8_t*) ((uintptr_t) a7 - kc); - c0 = (int8_t*) ((uintptr_t) c0 + cn_stride); + a0 = (const int8_t*) ((uintptr_t) a0 - kc); + _mm_storeu_si128((__m128i*) c1, vout1x0123456789ABCDEF); c1 = (int8_t*) ((uintptr_t) c1 + cn_stride); + a1 = (const int8_t*) ((uintptr_t) a1 - kc); + _mm_storeu_si128((__m128i*) c2, vout2x0123456789ABCDEF); c2 = (int8_t*) ((uintptr_t) c2 + cn_stride); + a2 = (const int8_t*) ((uintptr_t) a2 - kc); + _mm_storeu_si128((__m128i*) c3, vout3x0123456789ABCDEF); c3 = (int8_t*) ((uintptr_t) c3 + cn_stride); + a3 = (const int8_t*) ((uintptr_t) a3 - kc); + _mm_storeu_si128((__m128i*) c4, vout4x0123456789ABCDEF); c4 = (int8_t*) ((uintptr_t) c4 + cn_stride); + a4 = (const int8_t*) ((uintptr_t) a4 - kc); + _mm_storeu_si128((__m128i*) c5, vout5x0123456789ABCDEF); c5 = (int8_t*) ((uintptr_t) c5 + cn_stride); + a5 = (const int8_t*) ((uintptr_t) a5 - kc); + _mm_storeu_si128((__m128i*) c6, vout6x0123456789ABCDEF); c6 = (int8_t*) ((uintptr_t) c6 + cn_stride); + a6 = (const int8_t*) ((uintptr_t) a6 - kc); + _mm_storeu_si128((__m128i*) c7, vout7x0123456789ABCDEF); c7 = (int8_t*) ((uintptr_t) c7 + cn_stride); + a7 = (const int8_t*) ((uintptr_t) a7 - kc); nc -= 16; } else { diff --git a/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-8x16c4-minmax-fp32-avx512vnni.c b/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-8x16c4-minmax-fp32-avx512vnni.c index dcb3c713542..fe84f746b92 100644 --- a/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-8x16c4-minmax-fp32-avx512vnni.c +++ b/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-8x16c4-minmax-fp32-avx512vnni.c @@ -88,7 +88,6 @@ void xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_8x16c4__avx512vnni( const __m512 voutput_max_less_zero_point = _mm512_load_ps(params->fp32_avx512vnni.output_max_less_zero_point); const __m256i voutput_zero_point = _mm256_load_si256((const __m256i*) params->fp32_avx512vnni.output_zero_point); const __m128i voutput_min = _mm_load_si128((const __m128i*) params->fp32_avx512vnni.output_min); - const __m128i vshuffle_control_mask = _mm_loadu_si128((const __m128i*) params->fp32_avx512vnni.shuffle_control_mask); do { __m512i vacc0x0x0123456789ABCDEF = _mm512_load_epi32(w); __m512i vacc1x0x0123456789ABCDEF = _mm512_setzero_epi32(); @@ -158,14 +157,14 @@ void xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_8x16c4__avx512vnni( w = (const int8_t*) w + 128; k -= 8 * sizeof(int8_t); } - vacc0x0x0123456789ABCDEF = _mm512_add_epi32(vacc0x0x0123456789ABCDEF, vacc1x0x0123456789ABCDEF); - vacc0x1x0123456789ABCDEF = _mm512_add_epi32(vacc0x1x0123456789ABCDEF, vacc1x1x0123456789ABCDEF); - vacc0x2x0123456789ABCDEF = _mm512_add_epi32(vacc0x2x0123456789ABCDEF, vacc1x2x0123456789ABCDEF); - vacc0x3x0123456789ABCDEF = _mm512_add_epi32(vacc0x3x0123456789ABCDEF, vacc1x3x0123456789ABCDEF); - vacc0x4x0123456789ABCDEF = _mm512_add_epi32(vacc0x4x0123456789ABCDEF, vacc1x4x0123456789ABCDEF); - vacc0x5x0123456789ABCDEF = _mm512_add_epi32(vacc0x5x0123456789ABCDEF, vacc1x5x0123456789ABCDEF); - vacc0x6x0123456789ABCDEF = _mm512_add_epi32(vacc0x6x0123456789ABCDEF, vacc1x6x0123456789ABCDEF); - vacc0x7x0123456789ABCDEF = _mm512_add_epi32(vacc0x7x0123456789ABCDEF, vacc1x7x0123456789ABCDEF); + __m512i vacc0x0123456789ABCDEF = _mm512_add_epi32(vacc0x0x0123456789ABCDEF, vacc1x0x0123456789ABCDEF); + __m512i vacc1x0123456789ABCDEF = _mm512_add_epi32(vacc0x1x0123456789ABCDEF, vacc1x1x0123456789ABCDEF); + __m512i vacc2x0123456789ABCDEF = _mm512_add_epi32(vacc0x2x0123456789ABCDEF, vacc1x2x0123456789ABCDEF); + __m512i vacc3x0123456789ABCDEF = _mm512_add_epi32(vacc0x3x0123456789ABCDEF, vacc1x3x0123456789ABCDEF); + __m512i vacc4x0123456789ABCDEF = _mm512_add_epi32(vacc0x4x0123456789ABCDEF, vacc1x4x0123456789ABCDEF); + __m512i vacc5x0123456789ABCDEF = _mm512_add_epi32(vacc0x5x0123456789ABCDEF, vacc1x5x0123456789ABCDEF); + __m512i vacc6x0123456789ABCDEF = _mm512_add_epi32(vacc0x6x0123456789ABCDEF, vacc1x6x0123456789ABCDEF); + __m512i vacc7x0123456789ABCDEF = _mm512_add_epi32(vacc0x7x0123456789ABCDEF, vacc1x7x0123456789ABCDEF); if (k != 0) { const __m512i va0x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a0)), vsign_mask); @@ -187,27 +186,27 @@ void xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_8x16c4__avx512vnni( const __m512i vb0123456789ABCDEF = _mm512_load_si512(w); - vacc0x0x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x0x0123456789ABCDEF, va0x0123, vb0123456789ABCDEF); - vacc0x1x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x1x0123456789ABCDEF, va1x0123, vb0123456789ABCDEF); - vacc0x2x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x2x0123456789ABCDEF, va2x0123, vb0123456789ABCDEF); - vacc0x3x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x3x0123456789ABCDEF, va3x0123, vb0123456789ABCDEF); - vacc0x4x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x4x0123456789ABCDEF, va4x0123, vb0123456789ABCDEF); - vacc0x5x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x5x0123456789ABCDEF, va5x0123, vb0123456789ABCDEF); - vacc0x6x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x6x0123456789ABCDEF, va6x0123, vb0123456789ABCDEF); - vacc0x7x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x7x0123456789ABCDEF, va7x0123, vb0123456789ABCDEF); + vacc0x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x0123456789ABCDEF, va0x0123, vb0123456789ABCDEF); + vacc1x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc1x0123456789ABCDEF, va1x0123, vb0123456789ABCDEF); + vacc2x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc2x0123456789ABCDEF, va2x0123, vb0123456789ABCDEF); + vacc3x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc3x0123456789ABCDEF, va3x0123, vb0123456789ABCDEF); + vacc4x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc4x0123456789ABCDEF, va4x0123, vb0123456789ABCDEF); + vacc5x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc5x0123456789ABCDEF, va5x0123, vb0123456789ABCDEF); + vacc6x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc6x0123456789ABCDEF, va6x0123, vb0123456789ABCDEF); + vacc7x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc7x0123456789ABCDEF, va7x0123, vb0123456789ABCDEF); w = (const int8_t*) w + 64; k -= 4 * sizeof(int8_t); } - __m512 vscaled0x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x0x0123456789ABCDEF); - __m512 vscaled1x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x1x0123456789ABCDEF); - __m512 vscaled2x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x2x0123456789ABCDEF); - __m512 vscaled3x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x3x0123456789ABCDEF); - __m512 vscaled4x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x4x0123456789ABCDEF); - __m512 vscaled5x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x5x0123456789ABCDEF); - __m512 vscaled6x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x6x0123456789ABCDEF); - __m512 vscaled7x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x7x0123456789ABCDEF); + __m512 vscaled0x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x0123456789ABCDEF); + __m512 vscaled1x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc1x0123456789ABCDEF); + __m512 vscaled2x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc2x0123456789ABCDEF); + __m512 vscaled3x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc3x0123456789ABCDEF); + __m512 vscaled4x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc4x0123456789ABCDEF); + __m512 vscaled5x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc5x0123456789ABCDEF); + __m512 vscaled6x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc6x0123456789ABCDEF); + __m512 vscaled7x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc7x0123456789ABCDEF); const __m512 vscale012345678ABCDEF = _mm512_load_ps(w); w = (const float*) w + 16; @@ -229,50 +228,50 @@ void xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_8x16c4__avx512vnni( vscaled6x0123456789ABCDEF = _mm512_min_ps(vscaled6x0123456789ABCDEF, voutput_max_less_zero_point); vscaled7x0123456789ABCDEF = _mm512_min_ps(vscaled7x0123456789ABCDEF, voutput_max_less_zero_point); - vacc0x0x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled0x0123456789ABCDEF); - vacc0x1x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled1x0123456789ABCDEF); - vacc0x2x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled2x0123456789ABCDEF); - vacc0x3x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled3x0123456789ABCDEF); - vacc0x4x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled4x0123456789ABCDEF); - vacc0x5x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled5x0123456789ABCDEF); - vacc0x6x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled6x0123456789ABCDEF); - vacc0x7x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled7x0123456789ABCDEF); + vacc0x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled0x0123456789ABCDEF); + vacc1x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled1x0123456789ABCDEF); + vacc2x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled2x0123456789ABCDEF); + vacc3x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled3x0123456789ABCDEF); + vacc4x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled4x0123456789ABCDEF); + vacc5x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled5x0123456789ABCDEF); + vacc6x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled6x0123456789ABCDEF); + vacc7x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled7x0123456789ABCDEF); - __m256i vacc0x0x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0x0123456789ABCDEF, 1)); - __m256i vacc0x1x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x1x0123456789ABCDEF, 1)); - __m256i vacc0x2x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x2x0123456789ABCDEF, 1)); - __m256i vacc0x3x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x3x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x3x0123456789ABCDEF, 1)); - __m256i vacc0x4x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x4x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x4x0123456789ABCDEF, 1)); - __m256i vacc0x5x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x5x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x5x0123456789ABCDEF, 1)); - __m256i vacc0x6x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x6x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x6x0123456789ABCDEF, 1)); - __m256i vacc0x7x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x7x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x7x0123456789ABCDEF, 1)); + __m256i vacc0x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); + __m256i vacc1x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); + __m256i vacc2x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc2x0123456789ABCDEF, 1)); + __m256i vacc3x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc3x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc3x0123456789ABCDEF, 1)); + __m256i vacc4x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc4x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc4x0123456789ABCDEF, 1)); + __m256i vacc5x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc5x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc5x0123456789ABCDEF, 1)); + __m256i vacc6x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc6x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc6x0123456789ABCDEF, 1)); + __m256i vacc7x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc7x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc7x0123456789ABCDEF, 1)); - vacc0x0x012389AB4567CDEF = _mm256_adds_epi16(vacc0x0x012389AB4567CDEF, voutput_zero_point); - vacc0x1x012389AB4567CDEF = _mm256_adds_epi16(vacc0x1x012389AB4567CDEF, voutput_zero_point); - vacc0x2x012389AB4567CDEF = _mm256_adds_epi16(vacc0x2x012389AB4567CDEF, voutput_zero_point); - vacc0x3x012389AB4567CDEF = _mm256_adds_epi16(vacc0x3x012389AB4567CDEF, voutput_zero_point); - vacc0x4x012389AB4567CDEF = _mm256_adds_epi16(vacc0x4x012389AB4567CDEF, voutput_zero_point); - vacc0x5x012389AB4567CDEF = _mm256_adds_epi16(vacc0x5x012389AB4567CDEF, voutput_zero_point); - vacc0x6x012389AB4567CDEF = _mm256_adds_epi16(vacc0x6x012389AB4567CDEF, voutput_zero_point); - vacc0x7x012389AB4567CDEF = _mm256_adds_epi16(vacc0x7x012389AB4567CDEF, voutput_zero_point); + vacc0x012389AB4567CDEF = _mm256_adds_epi16(vacc0x012389AB4567CDEF, voutput_zero_point); + vacc1x012389AB4567CDEF = _mm256_adds_epi16(vacc1x012389AB4567CDEF, voutput_zero_point); + vacc2x012389AB4567CDEF = _mm256_adds_epi16(vacc2x012389AB4567CDEF, voutput_zero_point); + vacc3x012389AB4567CDEF = _mm256_adds_epi16(vacc3x012389AB4567CDEF, voutput_zero_point); + vacc4x012389AB4567CDEF = _mm256_adds_epi16(vacc4x012389AB4567CDEF, voutput_zero_point); + vacc5x012389AB4567CDEF = _mm256_adds_epi16(vacc5x012389AB4567CDEF, voutput_zero_point); + vacc6x012389AB4567CDEF = _mm256_adds_epi16(vacc6x012389AB4567CDEF, voutput_zero_point); + vacc7x012389AB4567CDEF = _mm256_adds_epi16(vacc7x012389AB4567CDEF, voutput_zero_point); - const __m128i vout0x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x0x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x0x012389AB4567CDEF, 1)); - const __m128i vout1x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x1x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x1x012389AB4567CDEF, 1)); - const __m128i vout2x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x2x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x2x012389AB4567CDEF, 1)); - const __m128i vout3x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x3x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x3x012389AB4567CDEF, 1)); - const __m128i vout4x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x4x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x4x012389AB4567CDEF, 1)); - const __m128i vout5x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x5x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x5x012389AB4567CDEF, 1)); - const __m128i vout6x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x6x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x6x012389AB4567CDEF, 1)); - const __m128i vout7x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x7x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x7x012389AB4567CDEF, 1)); + const __m128i vout0x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x012389AB4567CDEF, 1)); + const __m128i vout1x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc1x012389AB4567CDEF), _mm256_extracti128_si256(vacc1x012389AB4567CDEF, 1)); + const __m128i vout2x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc2x012389AB4567CDEF), _mm256_extracti128_si256(vacc2x012389AB4567CDEF, 1)); + const __m128i vout3x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc3x012389AB4567CDEF), _mm256_extracti128_si256(vacc3x012389AB4567CDEF, 1)); + const __m128i vout4x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc4x012389AB4567CDEF), _mm256_extracti128_si256(vacc4x012389AB4567CDEF, 1)); + const __m128i vout5x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc5x012389AB4567CDEF), _mm256_extracti128_si256(vacc5x012389AB4567CDEF, 1)); + const __m128i vout6x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc6x012389AB4567CDEF), _mm256_extracti128_si256(vacc6x012389AB4567CDEF, 1)); + const __m128i vout7x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc7x012389AB4567CDEF), _mm256_extracti128_si256(vacc7x012389AB4567CDEF, 1)); - __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi8(vout0x012389AB4567CDEF, vshuffle_control_mask); - __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi8(vout1x012389AB4567CDEF, vshuffle_control_mask); - __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi8(vout2x012389AB4567CDEF, vshuffle_control_mask); - __m128i vout3x0123456789ABCDEF = _mm_shuffle_epi8(vout3x012389AB4567CDEF, vshuffle_control_mask); - __m128i vout4x0123456789ABCDEF = _mm_shuffle_epi8(vout4x012389AB4567CDEF, vshuffle_control_mask); - __m128i vout5x0123456789ABCDEF = _mm_shuffle_epi8(vout5x012389AB4567CDEF, vshuffle_control_mask); - __m128i vout6x0123456789ABCDEF = _mm_shuffle_epi8(vout6x012389AB4567CDEF, vshuffle_control_mask); - __m128i vout7x0123456789ABCDEF = _mm_shuffle_epi8(vout7x012389AB4567CDEF, vshuffle_control_mask); + __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi32(vout0x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi32(vout1x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi32(vout2x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout3x0123456789ABCDEF = _mm_shuffle_epi32(vout3x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout4x0123456789ABCDEF = _mm_shuffle_epi32(vout4x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout5x0123456789ABCDEF = _mm_shuffle_epi32(vout5x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout6x0123456789ABCDEF = _mm_shuffle_epi32(vout6x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout7x0123456789ABCDEF = _mm_shuffle_epi32(vout7x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); vout0x0123456789ABCDEF = _mm_max_epi8(vout0x0123456789ABCDEF, voutput_min); vout1x0123456789ABCDEF = _mm_max_epi8(vout1x0123456789ABCDEF, voutput_min); @@ -285,31 +284,29 @@ void xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_8x16c4__avx512vnni( if (nc >= 16) { _mm_storeu_si128((__m128i*) c0, vout0x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c1, vout1x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c2, vout2x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c3, vout3x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c4, vout4x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c5, vout5x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c6, vout6x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c7, vout7x0123456789ABCDEF); - - a0 = (const int8_t*) ((uintptr_t) a0 - kc); - a1 = (const int8_t*) ((uintptr_t) a1 - kc); - a2 = (const int8_t*) ((uintptr_t) a2 - kc); - a3 = (const int8_t*) ((uintptr_t) a3 - kc); - a4 = (const int8_t*) ((uintptr_t) a4 - kc); - a5 = (const int8_t*) ((uintptr_t) a5 - kc); - a6 = (const int8_t*) ((uintptr_t) a6 - kc); - a7 = (const int8_t*) ((uintptr_t) a7 - kc); - c0 = (int8_t*) ((uintptr_t) c0 + cn_stride); + a0 = (const int8_t*) ((uintptr_t) a0 - kc); + _mm_storeu_si128((__m128i*) c1, vout1x0123456789ABCDEF); c1 = (int8_t*) ((uintptr_t) c1 + cn_stride); + a1 = (const int8_t*) ((uintptr_t) a1 - kc); + _mm_storeu_si128((__m128i*) c2, vout2x0123456789ABCDEF); c2 = (int8_t*) ((uintptr_t) c2 + cn_stride); + a2 = (const int8_t*) ((uintptr_t) a2 - kc); + _mm_storeu_si128((__m128i*) c3, vout3x0123456789ABCDEF); c3 = (int8_t*) ((uintptr_t) c3 + cn_stride); + a3 = (const int8_t*) ((uintptr_t) a3 - kc); + _mm_storeu_si128((__m128i*) c4, vout4x0123456789ABCDEF); c4 = (int8_t*) ((uintptr_t) c4 + cn_stride); + a4 = (const int8_t*) ((uintptr_t) a4 - kc); + _mm_storeu_si128((__m128i*) c5, vout5x0123456789ABCDEF); c5 = (int8_t*) ((uintptr_t) c5 + cn_stride); + a5 = (const int8_t*) ((uintptr_t) a5 - kc); + _mm_storeu_si128((__m128i*) c6, vout6x0123456789ABCDEF); c6 = (int8_t*) ((uintptr_t) c6 + cn_stride); + a6 = (const int8_t*) ((uintptr_t) a6 - kc); + _mm_storeu_si128((__m128i*) c7, vout7x0123456789ABCDEF); c7 = (int8_t*) ((uintptr_t) c7 + cn_stride); + a7 = (const int8_t*) ((uintptr_t) a7 - kc); nc -= 16; } else { diff --git a/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-8x16c8-minmax-fp32-avx512skx-prfm.c b/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-8x16c8-minmax-fp32-avx512skx-prfm.c index 51da52b2a41..2a6540bec2f 100644 --- a/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-8x16c8-minmax-fp32-avx512skx-prfm.c +++ b/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-8x16c8-minmax-fp32-avx512skx-prfm.c @@ -87,7 +87,6 @@ void xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_8x16c8__avx512skx_prfm( const __m512 voutput_max_less_zero_point = _mm512_load_ps(params->fp32_avx512.output_max_less_zero_point); const __m256i voutput_zero_point = _mm256_load_si256((const __m256i*) params->fp32_avx512.output_zero_point); const __m128i voutput_min = _mm_load_si128((const __m128i*) params->fp32_avx512.output_min); - const __m128i vshuffle_control_mask = _mm_set_epi8(15, 14, 13, 12, 7, 6, 5, 4, 11, 10, 9, 8, 3, 2, 1, 0); do { __m512i vacc0x0123 = _mm512_maskz_expandloadu_epi32(vbias_mask, w); @@ -266,41 +265,41 @@ void xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_8x16c8__avx512skx_prfm( vacc6x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled6x0123456789ABCDEF); vacc7x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled7x0123456789ABCDEF); - __m256i vacc0x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); - __m256i vacc1x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); - __m256i vacc2x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc2x0123456789ABCDEF, 1)); - __m256i vacc3x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc3x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc3x0123456789ABCDEF, 1)); - __m256i vacc4x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc4x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc4x0123456789ABCDEF, 1)); - __m256i vacc5x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc5x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc5x0123456789ABCDEF, 1)); - __m256i vacc6x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc6x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc6x0123456789ABCDEF, 1)); - __m256i vacc7x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc7x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc7x0123456789ABCDEF, 1)); - - vacc0x0123456789AB4567CDEF = _mm256_adds_epi16(vacc0x0123456789AB4567CDEF, voutput_zero_point); - vacc1x0123456789AB4567CDEF = _mm256_adds_epi16(vacc1x0123456789AB4567CDEF, voutput_zero_point); - vacc2x0123456789AB4567CDEF = _mm256_adds_epi16(vacc2x0123456789AB4567CDEF, voutput_zero_point); - vacc3x0123456789AB4567CDEF = _mm256_adds_epi16(vacc3x0123456789AB4567CDEF, voutput_zero_point); - vacc4x0123456789AB4567CDEF = _mm256_adds_epi16(vacc4x0123456789AB4567CDEF, voutput_zero_point); - vacc5x0123456789AB4567CDEF = _mm256_adds_epi16(vacc5x0123456789AB4567CDEF, voutput_zero_point); - vacc6x0123456789AB4567CDEF = _mm256_adds_epi16(vacc6x0123456789AB4567CDEF, voutput_zero_point); - vacc7x0123456789AB4567CDEF = _mm256_adds_epi16(vacc7x0123456789AB4567CDEF, voutput_zero_point); - - const __m128i vout0x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc0x0123456789AB4567CDEF, 1)); - const __m128i vout1x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc1x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc1x0123456789AB4567CDEF, 1)); - const __m128i vout2x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc2x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc2x0123456789AB4567CDEF, 1)); - const __m128i vout3x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc3x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc3x0123456789AB4567CDEF, 1)); - const __m128i vout4x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc4x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc4x0123456789AB4567CDEF, 1)); - const __m128i vout5x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc5x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc5x0123456789AB4567CDEF, 1)); - const __m128i vout6x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc6x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc6x0123456789AB4567CDEF, 1)); - const __m128i vout7x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc7x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc7x0123456789AB4567CDEF, 1)); - - __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi8(vout0x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi8(vout1x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi8(vout2x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout3x0123456789ABCDEF = _mm_shuffle_epi8(vout3x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout4x0123456789ABCDEF = _mm_shuffle_epi8(vout4x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout5x0123456789ABCDEF = _mm_shuffle_epi8(vout5x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout6x0123456789ABCDEF = _mm_shuffle_epi8(vout6x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout7x0123456789ABCDEF = _mm_shuffle_epi8(vout7x0123456789AB4567CDEF, vshuffle_control_mask); + __m256i vacc0x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); + __m256i vacc1x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); + __m256i vacc2x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc2x0123456789ABCDEF, 1)); + __m256i vacc3x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc3x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc3x0123456789ABCDEF, 1)); + __m256i vacc4x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc4x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc4x0123456789ABCDEF, 1)); + __m256i vacc5x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc5x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc5x0123456789ABCDEF, 1)); + __m256i vacc6x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc6x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc6x0123456789ABCDEF, 1)); + __m256i vacc7x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc7x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc7x0123456789ABCDEF, 1)); + + vacc0x012389AB4567CDEF = _mm256_adds_epi16(vacc0x012389AB4567CDEF, voutput_zero_point); + vacc1x012389AB4567CDEF = _mm256_adds_epi16(vacc1x012389AB4567CDEF, voutput_zero_point); + vacc2x012389AB4567CDEF = _mm256_adds_epi16(vacc2x012389AB4567CDEF, voutput_zero_point); + vacc3x012389AB4567CDEF = _mm256_adds_epi16(vacc3x012389AB4567CDEF, voutput_zero_point); + vacc4x012389AB4567CDEF = _mm256_adds_epi16(vacc4x012389AB4567CDEF, voutput_zero_point); + vacc5x012389AB4567CDEF = _mm256_adds_epi16(vacc5x012389AB4567CDEF, voutput_zero_point); + vacc6x012389AB4567CDEF = _mm256_adds_epi16(vacc6x012389AB4567CDEF, voutput_zero_point); + vacc7x012389AB4567CDEF = _mm256_adds_epi16(vacc7x012389AB4567CDEF, voutput_zero_point); + + const __m128i vout0x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x012389AB4567CDEF, 1)); + const __m128i vout1x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc1x012389AB4567CDEF), _mm256_extracti128_si256(vacc1x012389AB4567CDEF, 1)); + const __m128i vout2x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc2x012389AB4567CDEF), _mm256_extracti128_si256(vacc2x012389AB4567CDEF, 1)); + const __m128i vout3x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc3x012389AB4567CDEF), _mm256_extracti128_si256(vacc3x012389AB4567CDEF, 1)); + const __m128i vout4x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc4x012389AB4567CDEF), _mm256_extracti128_si256(vacc4x012389AB4567CDEF, 1)); + const __m128i vout5x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc5x012389AB4567CDEF), _mm256_extracti128_si256(vacc5x012389AB4567CDEF, 1)); + const __m128i vout6x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc6x012389AB4567CDEF), _mm256_extracti128_si256(vacc6x012389AB4567CDEF, 1)); + const __m128i vout7x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc7x012389AB4567CDEF), _mm256_extracti128_si256(vacc7x012389AB4567CDEF, 1)); + + __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi32(vout0x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi32(vout1x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi32(vout2x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout3x0123456789ABCDEF = _mm_shuffle_epi32(vout3x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout4x0123456789ABCDEF = _mm_shuffle_epi32(vout4x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout5x0123456789ABCDEF = _mm_shuffle_epi32(vout5x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout6x0123456789ABCDEF = _mm_shuffle_epi32(vout6x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout7x0123456789ABCDEF = _mm_shuffle_epi32(vout7x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); vout0x0123456789ABCDEF = _mm_max_epi8(vout0x0123456789ABCDEF, voutput_min); vout1x0123456789ABCDEF = _mm_max_epi8(vout1x0123456789ABCDEF, voutput_min); @@ -313,31 +312,29 @@ void xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_8x16c8__avx512skx_prfm( if (nc >= 16) { _mm_storeu_si128((__m128i*) c0, vout0x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c1, vout1x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c2, vout2x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c3, vout3x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c4, vout4x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c5, vout5x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c6, vout6x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c7, vout7x0123456789ABCDEF); - - a0 = (const int8_t*) ((uintptr_t) a0 - kc); - a1 = (const int8_t*) ((uintptr_t) a1 - kc); - a2 = (const int8_t*) ((uintptr_t) a2 - kc); - a3 = (const int8_t*) ((uintptr_t) a3 - kc); - a4 = (const int8_t*) ((uintptr_t) a4 - kc); - a5 = (const int8_t*) ((uintptr_t) a5 - kc); - a6 = (const int8_t*) ((uintptr_t) a6 - kc); - a7 = (const int8_t*) ((uintptr_t) a7 - kc); - c0 = (int8_t*) ((uintptr_t) c0 + cn_stride); + a0 = (const int8_t*) ((uintptr_t) a0 - kc); + _mm_storeu_si128((__m128i*) c1, vout1x0123456789ABCDEF); c1 = (int8_t*) ((uintptr_t) c1 + cn_stride); + a1 = (const int8_t*) ((uintptr_t) a1 - kc); + _mm_storeu_si128((__m128i*) c2, vout2x0123456789ABCDEF); c2 = (int8_t*) ((uintptr_t) c2 + cn_stride); + a2 = (const int8_t*) ((uintptr_t) a2 - kc); + _mm_storeu_si128((__m128i*) c3, vout3x0123456789ABCDEF); c3 = (int8_t*) ((uintptr_t) c3 + cn_stride); + a3 = (const int8_t*) ((uintptr_t) a3 - kc); + _mm_storeu_si128((__m128i*) c4, vout4x0123456789ABCDEF); c4 = (int8_t*) ((uintptr_t) c4 + cn_stride); + a4 = (const int8_t*) ((uintptr_t) a4 - kc); + _mm_storeu_si128((__m128i*) c5, vout5x0123456789ABCDEF); c5 = (int8_t*) ((uintptr_t) c5 + cn_stride); + a5 = (const int8_t*) ((uintptr_t) a5 - kc); + _mm_storeu_si128((__m128i*) c6, vout6x0123456789ABCDEF); c6 = (int8_t*) ((uintptr_t) c6 + cn_stride); + a6 = (const int8_t*) ((uintptr_t) a6 - kc); + _mm_storeu_si128((__m128i*) c7, vout7x0123456789ABCDEF); c7 = (int8_t*) ((uintptr_t) c7 + cn_stride); + a7 = (const int8_t*) ((uintptr_t) a7 - kc); nc -= 16; } else { diff --git a/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-8x16c8-minmax-fp32-avx512skx.c b/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-8x16c8-minmax-fp32-avx512skx.c index ff4ee738faf..15dcd799caf 100644 --- a/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-8x16c8-minmax-fp32-avx512skx.c +++ b/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-8x16c8-minmax-fp32-avx512skx.c @@ -86,7 +86,6 @@ void xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_8x16c8__avx512skx( const __m512 voutput_max_less_zero_point = _mm512_load_ps(params->fp32_avx512.output_max_less_zero_point); const __m256i voutput_zero_point = _mm256_load_si256((const __m256i*) params->fp32_avx512.output_zero_point); const __m128i voutput_min = _mm_load_si128((const __m128i*) params->fp32_avx512.output_min); - const __m128i vshuffle_control_mask = _mm_set_epi8(15, 14, 13, 12, 7, 6, 5, 4, 11, 10, 9, 8, 3, 2, 1, 0); do { __m512i vacc0x0123 = _mm512_maskz_expandloadu_epi32(vbias_mask, w); @@ -263,41 +262,41 @@ void xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_8x16c8__avx512skx( vacc6x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled6x0123456789ABCDEF); vacc7x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled7x0123456789ABCDEF); - __m256i vacc0x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); - __m256i vacc1x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); - __m256i vacc2x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc2x0123456789ABCDEF, 1)); - __m256i vacc3x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc3x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc3x0123456789ABCDEF, 1)); - __m256i vacc4x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc4x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc4x0123456789ABCDEF, 1)); - __m256i vacc5x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc5x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc5x0123456789ABCDEF, 1)); - __m256i vacc6x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc6x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc6x0123456789ABCDEF, 1)); - __m256i vacc7x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc7x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc7x0123456789ABCDEF, 1)); - - vacc0x0123456789AB4567CDEF = _mm256_adds_epi16(vacc0x0123456789AB4567CDEF, voutput_zero_point); - vacc1x0123456789AB4567CDEF = _mm256_adds_epi16(vacc1x0123456789AB4567CDEF, voutput_zero_point); - vacc2x0123456789AB4567CDEF = _mm256_adds_epi16(vacc2x0123456789AB4567CDEF, voutput_zero_point); - vacc3x0123456789AB4567CDEF = _mm256_adds_epi16(vacc3x0123456789AB4567CDEF, voutput_zero_point); - vacc4x0123456789AB4567CDEF = _mm256_adds_epi16(vacc4x0123456789AB4567CDEF, voutput_zero_point); - vacc5x0123456789AB4567CDEF = _mm256_adds_epi16(vacc5x0123456789AB4567CDEF, voutput_zero_point); - vacc6x0123456789AB4567CDEF = _mm256_adds_epi16(vacc6x0123456789AB4567CDEF, voutput_zero_point); - vacc7x0123456789AB4567CDEF = _mm256_adds_epi16(vacc7x0123456789AB4567CDEF, voutput_zero_point); - - const __m128i vout0x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc0x0123456789AB4567CDEF, 1)); - const __m128i vout1x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc1x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc1x0123456789AB4567CDEF, 1)); - const __m128i vout2x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc2x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc2x0123456789AB4567CDEF, 1)); - const __m128i vout3x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc3x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc3x0123456789AB4567CDEF, 1)); - const __m128i vout4x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc4x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc4x0123456789AB4567CDEF, 1)); - const __m128i vout5x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc5x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc5x0123456789AB4567CDEF, 1)); - const __m128i vout6x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc6x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc6x0123456789AB4567CDEF, 1)); - const __m128i vout7x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc7x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc7x0123456789AB4567CDEF, 1)); - - __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi8(vout0x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi8(vout1x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi8(vout2x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout3x0123456789ABCDEF = _mm_shuffle_epi8(vout3x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout4x0123456789ABCDEF = _mm_shuffle_epi8(vout4x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout5x0123456789ABCDEF = _mm_shuffle_epi8(vout5x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout6x0123456789ABCDEF = _mm_shuffle_epi8(vout6x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout7x0123456789ABCDEF = _mm_shuffle_epi8(vout7x0123456789AB4567CDEF, vshuffle_control_mask); + __m256i vacc0x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); + __m256i vacc1x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); + __m256i vacc2x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc2x0123456789ABCDEF, 1)); + __m256i vacc3x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc3x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc3x0123456789ABCDEF, 1)); + __m256i vacc4x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc4x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc4x0123456789ABCDEF, 1)); + __m256i vacc5x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc5x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc5x0123456789ABCDEF, 1)); + __m256i vacc6x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc6x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc6x0123456789ABCDEF, 1)); + __m256i vacc7x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc7x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc7x0123456789ABCDEF, 1)); + + vacc0x012389AB4567CDEF = _mm256_adds_epi16(vacc0x012389AB4567CDEF, voutput_zero_point); + vacc1x012389AB4567CDEF = _mm256_adds_epi16(vacc1x012389AB4567CDEF, voutput_zero_point); + vacc2x012389AB4567CDEF = _mm256_adds_epi16(vacc2x012389AB4567CDEF, voutput_zero_point); + vacc3x012389AB4567CDEF = _mm256_adds_epi16(vacc3x012389AB4567CDEF, voutput_zero_point); + vacc4x012389AB4567CDEF = _mm256_adds_epi16(vacc4x012389AB4567CDEF, voutput_zero_point); + vacc5x012389AB4567CDEF = _mm256_adds_epi16(vacc5x012389AB4567CDEF, voutput_zero_point); + vacc6x012389AB4567CDEF = _mm256_adds_epi16(vacc6x012389AB4567CDEF, voutput_zero_point); + vacc7x012389AB4567CDEF = _mm256_adds_epi16(vacc7x012389AB4567CDEF, voutput_zero_point); + + const __m128i vout0x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x012389AB4567CDEF, 1)); + const __m128i vout1x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc1x012389AB4567CDEF), _mm256_extracti128_si256(vacc1x012389AB4567CDEF, 1)); + const __m128i vout2x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc2x012389AB4567CDEF), _mm256_extracti128_si256(vacc2x012389AB4567CDEF, 1)); + const __m128i vout3x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc3x012389AB4567CDEF), _mm256_extracti128_si256(vacc3x012389AB4567CDEF, 1)); + const __m128i vout4x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc4x012389AB4567CDEF), _mm256_extracti128_si256(vacc4x012389AB4567CDEF, 1)); + const __m128i vout5x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc5x012389AB4567CDEF), _mm256_extracti128_si256(vacc5x012389AB4567CDEF, 1)); + const __m128i vout6x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc6x012389AB4567CDEF), _mm256_extracti128_si256(vacc6x012389AB4567CDEF, 1)); + const __m128i vout7x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc7x012389AB4567CDEF), _mm256_extracti128_si256(vacc7x012389AB4567CDEF, 1)); + + __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi32(vout0x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi32(vout1x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi32(vout2x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout3x0123456789ABCDEF = _mm_shuffle_epi32(vout3x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout4x0123456789ABCDEF = _mm_shuffle_epi32(vout4x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout5x0123456789ABCDEF = _mm_shuffle_epi32(vout5x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout6x0123456789ABCDEF = _mm_shuffle_epi32(vout6x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout7x0123456789ABCDEF = _mm_shuffle_epi32(vout7x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); vout0x0123456789ABCDEF = _mm_max_epi8(vout0x0123456789ABCDEF, voutput_min); vout1x0123456789ABCDEF = _mm_max_epi8(vout1x0123456789ABCDEF, voutput_min); @@ -310,31 +309,29 @@ void xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_8x16c8__avx512skx( if (nc >= 16) { _mm_storeu_si128((__m128i*) c0, vout0x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c1, vout1x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c2, vout2x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c3, vout3x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c4, vout4x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c5, vout5x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c6, vout6x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c7, vout7x0123456789ABCDEF); - - a0 = (const int8_t*) ((uintptr_t) a0 - kc); - a1 = (const int8_t*) ((uintptr_t) a1 - kc); - a2 = (const int8_t*) ((uintptr_t) a2 - kc); - a3 = (const int8_t*) ((uintptr_t) a3 - kc); - a4 = (const int8_t*) ((uintptr_t) a4 - kc); - a5 = (const int8_t*) ((uintptr_t) a5 - kc); - a6 = (const int8_t*) ((uintptr_t) a6 - kc); - a7 = (const int8_t*) ((uintptr_t) a7 - kc); - c0 = (int8_t*) ((uintptr_t) c0 + cn_stride); + a0 = (const int8_t*) ((uintptr_t) a0 - kc); + _mm_storeu_si128((__m128i*) c1, vout1x0123456789ABCDEF); c1 = (int8_t*) ((uintptr_t) c1 + cn_stride); + a1 = (const int8_t*) ((uintptr_t) a1 - kc); + _mm_storeu_si128((__m128i*) c2, vout2x0123456789ABCDEF); c2 = (int8_t*) ((uintptr_t) c2 + cn_stride); + a2 = (const int8_t*) ((uintptr_t) a2 - kc); + _mm_storeu_si128((__m128i*) c3, vout3x0123456789ABCDEF); c3 = (int8_t*) ((uintptr_t) c3 + cn_stride); + a3 = (const int8_t*) ((uintptr_t) a3 - kc); + _mm_storeu_si128((__m128i*) c4, vout4x0123456789ABCDEF); c4 = (int8_t*) ((uintptr_t) c4 + cn_stride); + a4 = (const int8_t*) ((uintptr_t) a4 - kc); + _mm_storeu_si128((__m128i*) c5, vout5x0123456789ABCDEF); c5 = (int8_t*) ((uintptr_t) c5 + cn_stride); + a5 = (const int8_t*) ((uintptr_t) a5 - kc); + _mm_storeu_si128((__m128i*) c6, vout6x0123456789ABCDEF); c6 = (int8_t*) ((uintptr_t) c6 + cn_stride); + a6 = (const int8_t*) ((uintptr_t) a6 - kc); + _mm_storeu_si128((__m128i*) c7, vout7x0123456789ABCDEF); c7 = (int8_t*) ((uintptr_t) c7 + cn_stride); + a7 = (const int8_t*) ((uintptr_t) a7 - kc); nc -= 16; } else { diff --git a/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-8x16c8-minmax-fp32-avx512vnni-prfm.c b/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-8x16c8-minmax-fp32-avx512vnni-prfm.c index 192a776795c..e608771c9e4 100644 --- a/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-8x16c8-minmax-fp32-avx512vnni-prfm.c +++ b/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-8x16c8-minmax-fp32-avx512vnni-prfm.c @@ -89,7 +89,6 @@ void xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_8x16c8__avx512vnni_prfm( const __m512 voutput_max_less_zero_point = _mm512_load_ps(params->fp32_avx512vnni.output_max_less_zero_point); const __m256i voutput_zero_point = _mm256_load_si256((const __m256i*) params->fp32_avx512vnni.output_zero_point); const __m128i voutput_min = _mm_load_si128((const __m128i*) params->fp32_avx512vnni.output_min); - const __m128i vshuffle_control_mask = _mm_loadu_si128((const __m128i*) params->fp32_avx512vnni.shuffle_control_mask); do { __m512i vacc0x01234567 = _mm512_cvtepu32_epi64(_mm256_load_si256((const __m256i*) w)); __m512i vacc0x89ABCDEF = _mm512_cvtepu32_epi64(_mm256_load_si256((const __m256i*) ((const int32_t*) w + 8))); @@ -291,41 +290,41 @@ void xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_8x16c8__avx512vnni_prfm( vacc6x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled6x0123456789ABCDEF); vacc7x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled7x0123456789ABCDEF); - __m256i vacc0x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); - __m256i vacc1x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); - __m256i vacc2x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc2x0123456789ABCDEF, 1)); - __m256i vacc3x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc3x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc3x0123456789ABCDEF, 1)); - __m256i vacc4x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc4x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc4x0123456789ABCDEF, 1)); - __m256i vacc5x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc5x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc5x0123456789ABCDEF, 1)); - __m256i vacc6x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc6x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc6x0123456789ABCDEF, 1)); - __m256i vacc7x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc7x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc7x0123456789ABCDEF, 1)); + __m256i vacc0x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); + __m256i vacc1x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); + __m256i vacc2x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc2x0123456789ABCDEF, 1)); + __m256i vacc3x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc3x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc3x0123456789ABCDEF, 1)); + __m256i vacc4x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc4x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc4x0123456789ABCDEF, 1)); + __m256i vacc5x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc5x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc5x0123456789ABCDEF, 1)); + __m256i vacc6x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc6x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc6x0123456789ABCDEF, 1)); + __m256i vacc7x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc7x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc7x0123456789ABCDEF, 1)); - vacc0x0123456789AB4567CDEF = _mm256_adds_epi16(vacc0x0123456789AB4567CDEF, voutput_zero_point); - vacc1x0123456789AB4567CDEF = _mm256_adds_epi16(vacc1x0123456789AB4567CDEF, voutput_zero_point); - vacc2x0123456789AB4567CDEF = _mm256_adds_epi16(vacc2x0123456789AB4567CDEF, voutput_zero_point); - vacc3x0123456789AB4567CDEF = _mm256_adds_epi16(vacc3x0123456789AB4567CDEF, voutput_zero_point); - vacc4x0123456789AB4567CDEF = _mm256_adds_epi16(vacc4x0123456789AB4567CDEF, voutput_zero_point); - vacc5x0123456789AB4567CDEF = _mm256_adds_epi16(vacc5x0123456789AB4567CDEF, voutput_zero_point); - vacc6x0123456789AB4567CDEF = _mm256_adds_epi16(vacc6x0123456789AB4567CDEF, voutput_zero_point); - vacc7x0123456789AB4567CDEF = _mm256_adds_epi16(vacc7x0123456789AB4567CDEF, voutput_zero_point); + vacc0x012389AB4567CDEF = _mm256_adds_epi16(vacc0x012389AB4567CDEF, voutput_zero_point); + vacc1x012389AB4567CDEF = _mm256_adds_epi16(vacc1x012389AB4567CDEF, voutput_zero_point); + vacc2x012389AB4567CDEF = _mm256_adds_epi16(vacc2x012389AB4567CDEF, voutput_zero_point); + vacc3x012389AB4567CDEF = _mm256_adds_epi16(vacc3x012389AB4567CDEF, voutput_zero_point); + vacc4x012389AB4567CDEF = _mm256_adds_epi16(vacc4x012389AB4567CDEF, voutput_zero_point); + vacc5x012389AB4567CDEF = _mm256_adds_epi16(vacc5x012389AB4567CDEF, voutput_zero_point); + vacc6x012389AB4567CDEF = _mm256_adds_epi16(vacc6x012389AB4567CDEF, voutput_zero_point); + vacc7x012389AB4567CDEF = _mm256_adds_epi16(vacc7x012389AB4567CDEF, voutput_zero_point); - const __m128i vout0x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc0x0123456789AB4567CDEF, 1)); - const __m128i vout1x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc1x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc1x0123456789AB4567CDEF, 1)); - const __m128i vout2x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc2x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc2x0123456789AB4567CDEF, 1)); - const __m128i vout3x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc3x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc3x0123456789AB4567CDEF, 1)); - const __m128i vout4x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc4x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc4x0123456789AB4567CDEF, 1)); - const __m128i vout5x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc5x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc5x0123456789AB4567CDEF, 1)); - const __m128i vout6x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc6x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc6x0123456789AB4567CDEF, 1)); - const __m128i vout7x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc7x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc7x0123456789AB4567CDEF, 1)); + const __m128i vout0x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x012389AB4567CDEF, 1)); + const __m128i vout1x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc1x012389AB4567CDEF), _mm256_extracti128_si256(vacc1x012389AB4567CDEF, 1)); + const __m128i vout2x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc2x012389AB4567CDEF), _mm256_extracti128_si256(vacc2x012389AB4567CDEF, 1)); + const __m128i vout3x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc3x012389AB4567CDEF), _mm256_extracti128_si256(vacc3x012389AB4567CDEF, 1)); + const __m128i vout4x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc4x012389AB4567CDEF), _mm256_extracti128_si256(vacc4x012389AB4567CDEF, 1)); + const __m128i vout5x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc5x012389AB4567CDEF), _mm256_extracti128_si256(vacc5x012389AB4567CDEF, 1)); + const __m128i vout6x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc6x012389AB4567CDEF), _mm256_extracti128_si256(vacc6x012389AB4567CDEF, 1)); + const __m128i vout7x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc7x012389AB4567CDEF), _mm256_extracti128_si256(vacc7x012389AB4567CDEF, 1)); - __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi8(vout0x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi8(vout1x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi8(vout2x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout3x0123456789ABCDEF = _mm_shuffle_epi8(vout3x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout4x0123456789ABCDEF = _mm_shuffle_epi8(vout4x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout5x0123456789ABCDEF = _mm_shuffle_epi8(vout5x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout6x0123456789ABCDEF = _mm_shuffle_epi8(vout6x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout7x0123456789ABCDEF = _mm_shuffle_epi8(vout7x0123456789AB4567CDEF, vshuffle_control_mask); + __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi32(vout0x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi32(vout1x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi32(vout2x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout3x0123456789ABCDEF = _mm_shuffle_epi32(vout3x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout4x0123456789ABCDEF = _mm_shuffle_epi32(vout4x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout5x0123456789ABCDEF = _mm_shuffle_epi32(vout5x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout6x0123456789ABCDEF = _mm_shuffle_epi32(vout6x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout7x0123456789ABCDEF = _mm_shuffle_epi32(vout7x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); vout0x0123456789ABCDEF = _mm_max_epi8(vout0x0123456789ABCDEF, voutput_min); vout1x0123456789ABCDEF = _mm_max_epi8(vout1x0123456789ABCDEF, voutput_min); @@ -338,31 +337,29 @@ void xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_8x16c8__avx512vnni_prfm( if (nc >= 16) { _mm_storeu_si128((__m128i*) c0, vout0x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c1, vout1x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c2, vout2x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c3, vout3x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c4, vout4x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c5, vout5x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c6, vout6x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c7, vout7x0123456789ABCDEF); - - a0 = (const int8_t*) ((uintptr_t) a0 - kc); - a1 = (const int8_t*) ((uintptr_t) a1 - kc); - a2 = (const int8_t*) ((uintptr_t) a2 - kc); - a3 = (const int8_t*) ((uintptr_t) a3 - kc); - a4 = (const int8_t*) ((uintptr_t) a4 - kc); - a5 = (const int8_t*) ((uintptr_t) a5 - kc); - a6 = (const int8_t*) ((uintptr_t) a6 - kc); - a7 = (const int8_t*) ((uintptr_t) a7 - kc); - c0 = (int8_t*) ((uintptr_t) c0 + cn_stride); + a0 = (const int8_t*) ((uintptr_t) a0 - kc); + _mm_storeu_si128((__m128i*) c1, vout1x0123456789ABCDEF); c1 = (int8_t*) ((uintptr_t) c1 + cn_stride); + a1 = (const int8_t*) ((uintptr_t) a1 - kc); + _mm_storeu_si128((__m128i*) c2, vout2x0123456789ABCDEF); c2 = (int8_t*) ((uintptr_t) c2 + cn_stride); + a2 = (const int8_t*) ((uintptr_t) a2 - kc); + _mm_storeu_si128((__m128i*) c3, vout3x0123456789ABCDEF); c3 = (int8_t*) ((uintptr_t) c3 + cn_stride); + a3 = (const int8_t*) ((uintptr_t) a3 - kc); + _mm_storeu_si128((__m128i*) c4, vout4x0123456789ABCDEF); c4 = (int8_t*) ((uintptr_t) c4 + cn_stride); + a4 = (const int8_t*) ((uintptr_t) a4 - kc); + _mm_storeu_si128((__m128i*) c5, vout5x0123456789ABCDEF); c5 = (int8_t*) ((uintptr_t) c5 + cn_stride); + a5 = (const int8_t*) ((uintptr_t) a5 - kc); + _mm_storeu_si128((__m128i*) c6, vout6x0123456789ABCDEF); c6 = (int8_t*) ((uintptr_t) c6 + cn_stride); + a6 = (const int8_t*) ((uintptr_t) a6 - kc); + _mm_storeu_si128((__m128i*) c7, vout7x0123456789ABCDEF); c7 = (int8_t*) ((uintptr_t) c7 + cn_stride); + a7 = (const int8_t*) ((uintptr_t) a7 - kc); nc -= 16; } else { diff --git a/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-8x16c8-minmax-fp32-avx512vnni.c b/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-8x16c8-minmax-fp32-avx512vnni.c index 97a33c2300f..a7fba47cb31 100644 --- a/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-8x16c8-minmax-fp32-avx512vnni.c +++ b/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-8x16c8-minmax-fp32-avx512vnni.c @@ -88,7 +88,6 @@ void xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_8x16c8__avx512vnni( const __m512 voutput_max_less_zero_point = _mm512_load_ps(params->fp32_avx512vnni.output_max_less_zero_point); const __m256i voutput_zero_point = _mm256_load_si256((const __m256i*) params->fp32_avx512vnni.output_zero_point); const __m128i voutput_min = _mm_load_si128((const __m128i*) params->fp32_avx512vnni.output_min); - const __m128i vshuffle_control_mask = _mm_loadu_si128((const __m128i*) params->fp32_avx512vnni.shuffle_control_mask); do { __m512i vacc0x01234567 = _mm512_cvtepu32_epi64(_mm256_load_si256((const __m256i*) w)); __m512i vacc0x89ABCDEF = _mm512_cvtepu32_epi64(_mm256_load_si256((const __m256i*) ((const int32_t*) w + 8))); @@ -284,41 +283,41 @@ void xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_8x16c8__avx512vnni( vacc6x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled6x0123456789ABCDEF); vacc7x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled7x0123456789ABCDEF); - __m256i vacc0x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); - __m256i vacc1x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); - __m256i vacc2x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc2x0123456789ABCDEF, 1)); - __m256i vacc3x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc3x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc3x0123456789ABCDEF, 1)); - __m256i vacc4x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc4x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc4x0123456789ABCDEF, 1)); - __m256i vacc5x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc5x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc5x0123456789ABCDEF, 1)); - __m256i vacc6x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc6x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc6x0123456789ABCDEF, 1)); - __m256i vacc7x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc7x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc7x0123456789ABCDEF, 1)); + __m256i vacc0x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); + __m256i vacc1x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); + __m256i vacc2x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc2x0123456789ABCDEF, 1)); + __m256i vacc3x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc3x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc3x0123456789ABCDEF, 1)); + __m256i vacc4x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc4x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc4x0123456789ABCDEF, 1)); + __m256i vacc5x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc5x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc5x0123456789ABCDEF, 1)); + __m256i vacc6x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc6x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc6x0123456789ABCDEF, 1)); + __m256i vacc7x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc7x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc7x0123456789ABCDEF, 1)); - vacc0x0123456789AB4567CDEF = _mm256_adds_epi16(vacc0x0123456789AB4567CDEF, voutput_zero_point); - vacc1x0123456789AB4567CDEF = _mm256_adds_epi16(vacc1x0123456789AB4567CDEF, voutput_zero_point); - vacc2x0123456789AB4567CDEF = _mm256_adds_epi16(vacc2x0123456789AB4567CDEF, voutput_zero_point); - vacc3x0123456789AB4567CDEF = _mm256_adds_epi16(vacc3x0123456789AB4567CDEF, voutput_zero_point); - vacc4x0123456789AB4567CDEF = _mm256_adds_epi16(vacc4x0123456789AB4567CDEF, voutput_zero_point); - vacc5x0123456789AB4567CDEF = _mm256_adds_epi16(vacc5x0123456789AB4567CDEF, voutput_zero_point); - vacc6x0123456789AB4567CDEF = _mm256_adds_epi16(vacc6x0123456789AB4567CDEF, voutput_zero_point); - vacc7x0123456789AB4567CDEF = _mm256_adds_epi16(vacc7x0123456789AB4567CDEF, voutput_zero_point); + vacc0x012389AB4567CDEF = _mm256_adds_epi16(vacc0x012389AB4567CDEF, voutput_zero_point); + vacc1x012389AB4567CDEF = _mm256_adds_epi16(vacc1x012389AB4567CDEF, voutput_zero_point); + vacc2x012389AB4567CDEF = _mm256_adds_epi16(vacc2x012389AB4567CDEF, voutput_zero_point); + vacc3x012389AB4567CDEF = _mm256_adds_epi16(vacc3x012389AB4567CDEF, voutput_zero_point); + vacc4x012389AB4567CDEF = _mm256_adds_epi16(vacc4x012389AB4567CDEF, voutput_zero_point); + vacc5x012389AB4567CDEF = _mm256_adds_epi16(vacc5x012389AB4567CDEF, voutput_zero_point); + vacc6x012389AB4567CDEF = _mm256_adds_epi16(vacc6x012389AB4567CDEF, voutput_zero_point); + vacc7x012389AB4567CDEF = _mm256_adds_epi16(vacc7x012389AB4567CDEF, voutput_zero_point); - const __m128i vout0x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc0x0123456789AB4567CDEF, 1)); - const __m128i vout1x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc1x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc1x0123456789AB4567CDEF, 1)); - const __m128i vout2x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc2x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc2x0123456789AB4567CDEF, 1)); - const __m128i vout3x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc3x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc3x0123456789AB4567CDEF, 1)); - const __m128i vout4x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc4x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc4x0123456789AB4567CDEF, 1)); - const __m128i vout5x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc5x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc5x0123456789AB4567CDEF, 1)); - const __m128i vout6x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc6x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc6x0123456789AB4567CDEF, 1)); - const __m128i vout7x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc7x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc7x0123456789AB4567CDEF, 1)); + const __m128i vout0x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x012389AB4567CDEF, 1)); + const __m128i vout1x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc1x012389AB4567CDEF), _mm256_extracti128_si256(vacc1x012389AB4567CDEF, 1)); + const __m128i vout2x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc2x012389AB4567CDEF), _mm256_extracti128_si256(vacc2x012389AB4567CDEF, 1)); + const __m128i vout3x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc3x012389AB4567CDEF), _mm256_extracti128_si256(vacc3x012389AB4567CDEF, 1)); + const __m128i vout4x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc4x012389AB4567CDEF), _mm256_extracti128_si256(vacc4x012389AB4567CDEF, 1)); + const __m128i vout5x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc5x012389AB4567CDEF), _mm256_extracti128_si256(vacc5x012389AB4567CDEF, 1)); + const __m128i vout6x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc6x012389AB4567CDEF), _mm256_extracti128_si256(vacc6x012389AB4567CDEF, 1)); + const __m128i vout7x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc7x012389AB4567CDEF), _mm256_extracti128_si256(vacc7x012389AB4567CDEF, 1)); - __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi8(vout0x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi8(vout1x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi8(vout2x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout3x0123456789ABCDEF = _mm_shuffle_epi8(vout3x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout4x0123456789ABCDEF = _mm_shuffle_epi8(vout4x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout5x0123456789ABCDEF = _mm_shuffle_epi8(vout5x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout6x0123456789ABCDEF = _mm_shuffle_epi8(vout6x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout7x0123456789ABCDEF = _mm_shuffle_epi8(vout7x0123456789AB4567CDEF, vshuffle_control_mask); + __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi32(vout0x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi32(vout1x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi32(vout2x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout3x0123456789ABCDEF = _mm_shuffle_epi32(vout3x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout4x0123456789ABCDEF = _mm_shuffle_epi32(vout4x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout5x0123456789ABCDEF = _mm_shuffle_epi32(vout5x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout6x0123456789ABCDEF = _mm_shuffle_epi32(vout6x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout7x0123456789ABCDEF = _mm_shuffle_epi32(vout7x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); vout0x0123456789ABCDEF = _mm_max_epi8(vout0x0123456789ABCDEF, voutput_min); vout1x0123456789ABCDEF = _mm_max_epi8(vout1x0123456789ABCDEF, voutput_min); @@ -331,31 +330,29 @@ void xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_8x16c8__avx512vnni( if (nc >= 16) { _mm_storeu_si128((__m128i*) c0, vout0x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c1, vout1x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c2, vout2x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c3, vout3x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c4, vout4x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c5, vout5x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c6, vout6x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c7, vout7x0123456789ABCDEF); - - a0 = (const int8_t*) ((uintptr_t) a0 - kc); - a1 = (const int8_t*) ((uintptr_t) a1 - kc); - a2 = (const int8_t*) ((uintptr_t) a2 - kc); - a3 = (const int8_t*) ((uintptr_t) a3 - kc); - a4 = (const int8_t*) ((uintptr_t) a4 - kc); - a5 = (const int8_t*) ((uintptr_t) a5 - kc); - a6 = (const int8_t*) ((uintptr_t) a6 - kc); - a7 = (const int8_t*) ((uintptr_t) a7 - kc); - c0 = (int8_t*) ((uintptr_t) c0 + cn_stride); + a0 = (const int8_t*) ((uintptr_t) a0 - kc); + _mm_storeu_si128((__m128i*) c1, vout1x0123456789ABCDEF); c1 = (int8_t*) ((uintptr_t) c1 + cn_stride); + a1 = (const int8_t*) ((uintptr_t) a1 - kc); + _mm_storeu_si128((__m128i*) c2, vout2x0123456789ABCDEF); c2 = (int8_t*) ((uintptr_t) c2 + cn_stride); + a2 = (const int8_t*) ((uintptr_t) a2 - kc); + _mm_storeu_si128((__m128i*) c3, vout3x0123456789ABCDEF); c3 = (int8_t*) ((uintptr_t) c3 + cn_stride); + a3 = (const int8_t*) ((uintptr_t) a3 - kc); + _mm_storeu_si128((__m128i*) c4, vout4x0123456789ABCDEF); c4 = (int8_t*) ((uintptr_t) c4 + cn_stride); + a4 = (const int8_t*) ((uintptr_t) a4 - kc); + _mm_storeu_si128((__m128i*) c5, vout5x0123456789ABCDEF); c5 = (int8_t*) ((uintptr_t) c5 + cn_stride); + a5 = (const int8_t*) ((uintptr_t) a5 - kc); + _mm_storeu_si128((__m128i*) c6, vout6x0123456789ABCDEF); c6 = (int8_t*) ((uintptr_t) c6 + cn_stride); + a6 = (const int8_t*) ((uintptr_t) a6 - kc); + _mm_storeu_si128((__m128i*) c7, vout7x0123456789ABCDEF); c7 = (int8_t*) ((uintptr_t) c7 + cn_stride); + a7 = (const int8_t*) ((uintptr_t) a7 - kc); nc -= 16; } else { diff --git a/src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-1x16c4-minmax-avx512vnni-prfm.c b/src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-1x16c4-minmax-avx512vnni-prfm.c index 2e346351bbd..2772effa231 100644 --- a/src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-1x16c4-minmax-avx512vnni-prfm.c +++ b/src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-1x16c4-minmax-avx512vnni-prfm.c @@ -48,7 +48,6 @@ void xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_1x16c4__avx512vnni_prfm( const __m512 voutput_max_less_zero_point = _mm512_load_ps(params->fp32_avx512vnni.output_max_less_zero_point); const __m256i voutput_zero_point = _mm256_load_si256((const __m256i*) params->fp32_avx512vnni.output_zero_point); const __m128i voutput_min = _mm_load_si128((const __m128i*) params->fp32_avx512vnni.output_min); - const __m128i vshuffle_control_mask = _mm_loadu_si128((const __m128i*) params->fp32_avx512vnni.shuffle_control_mask); do { __m512i vacc0x0123456789ABCDEF = _mm512_load_epi32(w); w = (const int32_t*) w + 16; @@ -102,13 +101,13 @@ void xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_1x16c4__avx512vnni_prfm( vacc0x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled0x0123456789ABCDEF); - __m256i vacc0x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); + __m256i vacc0x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); - vacc0x0123456789AB4567CDEF = _mm256_adds_epi16(vacc0x0123456789AB4567CDEF, voutput_zero_point); + vacc0x012389AB4567CDEF = _mm256_adds_epi16(vacc0x012389AB4567CDEF, voutput_zero_point); - const __m128i vout0x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc0x0123456789AB4567CDEF, 1)); + const __m128i vout0x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x012389AB4567CDEF, 1)); - __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi8(vout0x0123456789AB4567CDEF, vshuffle_control_mask); + __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi32(vout0x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); vout0x0123456789ABCDEF = _mm_max_epi8(vout0x0123456789ABCDEF, voutput_min); diff --git a/src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-1x16c4-minmax-avx512vnni.c b/src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-1x16c4-minmax-avx512vnni.c index ed4220a075f..a9a5f0d2660 100644 --- a/src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-1x16c4-minmax-avx512vnni.c +++ b/src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-1x16c4-minmax-avx512vnni.c @@ -47,7 +47,6 @@ void xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_1x16c4__avx512vnni( const __m512 voutput_max_less_zero_point = _mm512_load_ps(params->fp32_avx512vnni.output_max_less_zero_point); const __m256i voutput_zero_point = _mm256_load_si256((const __m256i*) params->fp32_avx512vnni.output_zero_point); const __m128i voutput_min = _mm_load_si128((const __m128i*) params->fp32_avx512vnni.output_min); - const __m128i vshuffle_control_mask = _mm_loadu_si128((const __m128i*) params->fp32_avx512vnni.shuffle_control_mask); do { __m512i vacc0x0123456789ABCDEF = _mm512_load_epi32(w); w = (const int32_t*) w + 16; @@ -98,13 +97,13 @@ void xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_1x16c4__avx512vnni( vacc0x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled0x0123456789ABCDEF); - __m256i vacc0x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); + __m256i vacc0x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); - vacc0x0123456789AB4567CDEF = _mm256_adds_epi16(vacc0x0123456789AB4567CDEF, voutput_zero_point); + vacc0x012389AB4567CDEF = _mm256_adds_epi16(vacc0x012389AB4567CDEF, voutput_zero_point); - const __m128i vout0x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc0x0123456789AB4567CDEF, 1)); + const __m128i vout0x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x012389AB4567CDEF, 1)); - __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi8(vout0x0123456789AB4567CDEF, vshuffle_control_mask); + __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi32(vout0x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); vout0x0123456789ABCDEF = _mm_max_epi8(vout0x0123456789ABCDEF, voutput_min); diff --git a/src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-1x16c8-minmax-avx512vnni-prfm.c b/src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-1x16c8-minmax-avx512vnni-prfm.c index affc10717c6..3f17577aca6 100644 --- a/src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-1x16c8-minmax-avx512vnni-prfm.c +++ b/src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-1x16c8-minmax-avx512vnni-prfm.c @@ -48,7 +48,6 @@ void xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_1x16c8__avx512vnni_prfm( const __m512 voutput_max_less_zero_point = _mm512_load_ps(params->fp32_avx512vnni.output_max_less_zero_point); const __m256i voutput_zero_point = _mm256_load_si256((const __m256i*) params->fp32_avx512vnni.output_zero_point); const __m128i voutput_min = _mm_load_si128((const __m128i*) params->fp32_avx512vnni.output_min); - const __m128i vshuffle_control_mask = _mm_loadu_si128((const __m128i*) params->fp32_avx512vnni.shuffle_control_mask); do { __m512i vacc0x01234567 = _mm512_cvtepu32_epi64(_mm256_load_si256((const __m256i*) w)); __m512i vacc0x89ABCDEF = _mm512_cvtepu32_epi64(_mm256_load_si256((const __m256i*) ((const int32_t*) w + 8))); @@ -121,13 +120,13 @@ void xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_1x16c8__avx512vnni_prfm( vacc0x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled0x0123456789ABCDEF); - __m256i vacc0x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); + __m256i vacc0x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); - vacc0x0123456789AB4567CDEF = _mm256_adds_epi16(vacc0x0123456789AB4567CDEF, voutput_zero_point); + vacc0x012389AB4567CDEF = _mm256_adds_epi16(vacc0x012389AB4567CDEF, voutput_zero_point); - const __m128i vout0x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc0x0123456789AB4567CDEF, 1)); + const __m128i vout0x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x012389AB4567CDEF, 1)); - __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi8(vout0x0123456789AB4567CDEF, vshuffle_control_mask); + __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi32(vout0x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); vout0x0123456789ABCDEF = _mm_max_epi8(vout0x0123456789ABCDEF, voutput_min); diff --git a/src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-1x16c8-minmax-avx512vnni.c b/src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-1x16c8-minmax-avx512vnni.c index e181b6be100..94531c9e203 100644 --- a/src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-1x16c8-minmax-avx512vnni.c +++ b/src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-1x16c8-minmax-avx512vnni.c @@ -47,7 +47,6 @@ void xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_1x16c8__avx512vnni( const __m512 voutput_max_less_zero_point = _mm512_load_ps(params->fp32_avx512vnni.output_max_less_zero_point); const __m256i voutput_zero_point = _mm256_load_si256((const __m256i*) params->fp32_avx512vnni.output_zero_point); const __m128i voutput_min = _mm_load_si128((const __m128i*) params->fp32_avx512vnni.output_min); - const __m128i vshuffle_control_mask = _mm_loadu_si128((const __m128i*) params->fp32_avx512vnni.shuffle_control_mask); do { __m512i vacc0x01234567 = _mm512_cvtepu32_epi64(_mm256_load_si256((const __m256i*) w)); __m512i vacc0x89ABCDEF = _mm512_cvtepu32_epi64(_mm256_load_si256((const __m256i*) ((const int32_t*) w + 8))); @@ -114,13 +113,13 @@ void xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_1x16c8__avx512vnni( vacc0x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled0x0123456789ABCDEF); - __m256i vacc0x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); + __m256i vacc0x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); - vacc0x0123456789AB4567CDEF = _mm256_adds_epi16(vacc0x0123456789AB4567CDEF, voutput_zero_point); + vacc0x012389AB4567CDEF = _mm256_adds_epi16(vacc0x012389AB4567CDEF, voutput_zero_point); - const __m128i vout0x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc0x0123456789AB4567CDEF, 1)); + const __m128i vout0x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x012389AB4567CDEF, 1)); - __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi8(vout0x0123456789AB4567CDEF, vshuffle_control_mask); + __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi32(vout0x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); vout0x0123456789ABCDEF = _mm_max_epi8(vout0x0123456789ABCDEF, voutput_min); diff --git a/src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-1x16c8-minmax-fp32-avx512skx-prfm.c b/src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-1x16c8-minmax-fp32-avx512skx-prfm.c index 8875a2a57e3..cb98a3463f1 100644 --- a/src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-1x16c8-minmax-fp32-avx512skx-prfm.c +++ b/src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-1x16c8-minmax-fp32-avx512skx-prfm.c @@ -106,13 +106,13 @@ void xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_1x16c8__avx512skx_prfm( vacc0x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled0x0123456789ABCDEF); - __m256i vacc0x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); + __m256i vacc0x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); - vacc0x0123456789AB4567CDEF = _mm256_adds_epi16(vacc0x0123456789AB4567CDEF, voutput_zero_point); + vacc0x012389AB4567CDEF = _mm256_adds_epi16(vacc0x012389AB4567CDEF, voutput_zero_point); - const __m128i vout0x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc0x0123456789AB4567CDEF, 1)); + const __m128i vout0x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x012389AB4567CDEF, 1)); - __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi32(vout0x0123456789AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi32(vout0x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); vout0x0123456789ABCDEF = _mm_max_epi8(vout0x0123456789ABCDEF, voutput_min); diff --git a/src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-1x16c8-minmax-fp32-avx512skx.c b/src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-1x16c8-minmax-fp32-avx512skx.c index d68729cb62d..10672227431 100644 --- a/src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-1x16c8-minmax-fp32-avx512skx.c +++ b/src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-1x16c8-minmax-fp32-avx512skx.c @@ -103,13 +103,13 @@ void xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_1x16c8__avx512skx( vacc0x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled0x0123456789ABCDEF); - __m256i vacc0x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); + __m256i vacc0x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); - vacc0x0123456789AB4567CDEF = _mm256_adds_epi16(vacc0x0123456789AB4567CDEF, voutput_zero_point); + vacc0x012389AB4567CDEF = _mm256_adds_epi16(vacc0x012389AB4567CDEF, voutput_zero_point); - const __m128i vout0x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc0x0123456789AB4567CDEF, 1)); + const __m128i vout0x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x012389AB4567CDEF, 1)); - __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi32(vout0x0123456789AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi32(vout0x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); vout0x0123456789ABCDEF = _mm_max_epi8(vout0x0123456789ABCDEF, voutput_min); diff --git a/src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-2x16c4-minmax-avx512vnni-prfm.c b/src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-2x16c4-minmax-avx512vnni-prfm.c index 57eecb0517c..4555501a671 100644 --- a/src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-2x16c4-minmax-avx512vnni-prfm.c +++ b/src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-2x16c4-minmax-avx512vnni-prfm.c @@ -52,7 +52,6 @@ void xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_2x16c4__avx512vnni_prfm( const __m512 voutput_max_less_zero_point = _mm512_load_ps(params->fp32_avx512vnni.output_max_less_zero_point); const __m256i voutput_zero_point = _mm256_load_si256((const __m256i*) params->fp32_avx512vnni.output_zero_point); const __m128i voutput_min = _mm_load_si128((const __m128i*) params->fp32_avx512vnni.output_min); - const __m128i vshuffle_control_mask = _mm_loadu_si128((const __m128i*) params->fp32_avx512vnni.shuffle_control_mask); do { __m512i vacc0x0123456789ABCDEF = _mm512_load_epi32(w); __m512i vacc1x0123456789ABCDEF = vacc0x0123456789ABCDEF; @@ -123,17 +122,17 @@ void xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_2x16c4__avx512vnni_prfm( vacc0x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled0x0123456789ABCDEF); vacc1x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled1x0123456789ABCDEF); - __m256i vacc0x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); - __m256i vacc1x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); + __m256i vacc0x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); + __m256i vacc1x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); - vacc0x0123456789AB4567CDEF = _mm256_adds_epi16(vacc0x0123456789AB4567CDEF, voutput_zero_point); - vacc1x0123456789AB4567CDEF = _mm256_adds_epi16(vacc1x0123456789AB4567CDEF, voutput_zero_point); + vacc0x012389AB4567CDEF = _mm256_adds_epi16(vacc0x012389AB4567CDEF, voutput_zero_point); + vacc1x012389AB4567CDEF = _mm256_adds_epi16(vacc1x012389AB4567CDEF, voutput_zero_point); - const __m128i vout0x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc0x0123456789AB4567CDEF, 1)); - const __m128i vout1x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc1x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc1x0123456789AB4567CDEF, 1)); + const __m128i vout0x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x012389AB4567CDEF, 1)); + const __m128i vout1x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc1x012389AB4567CDEF), _mm256_extracti128_si256(vacc1x012389AB4567CDEF, 1)); - __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi8(vout0x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi8(vout1x0123456789AB4567CDEF, vshuffle_control_mask); + __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi32(vout0x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi32(vout1x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); vout0x0123456789ABCDEF = _mm_max_epi8(vout0x0123456789ABCDEF, voutput_min); vout1x0123456789ABCDEF = _mm_max_epi8(vout1x0123456789ABCDEF, voutput_min); diff --git a/src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-2x16c4-minmax-avx512vnni.c b/src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-2x16c4-minmax-avx512vnni.c index ef43ad3a7ca..d86fcb2bfc1 100644 --- a/src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-2x16c4-minmax-avx512vnni.c +++ b/src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-2x16c4-minmax-avx512vnni.c @@ -51,7 +51,6 @@ void xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_2x16c4__avx512vnni( const __m512 voutput_max_less_zero_point = _mm512_load_ps(params->fp32_avx512vnni.output_max_less_zero_point); const __m256i voutput_zero_point = _mm256_load_si256((const __m256i*) params->fp32_avx512vnni.output_zero_point); const __m128i voutput_min = _mm_load_si128((const __m128i*) params->fp32_avx512vnni.output_min); - const __m128i vshuffle_control_mask = _mm_loadu_si128((const __m128i*) params->fp32_avx512vnni.shuffle_control_mask); do { __m512i vacc0x0123456789ABCDEF = _mm512_load_epi32(w); __m512i vacc1x0123456789ABCDEF = vacc0x0123456789ABCDEF; @@ -119,17 +118,17 @@ void xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_2x16c4__avx512vnni( vacc0x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled0x0123456789ABCDEF); vacc1x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled1x0123456789ABCDEF); - __m256i vacc0x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); - __m256i vacc1x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); + __m256i vacc0x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); + __m256i vacc1x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); - vacc0x0123456789AB4567CDEF = _mm256_adds_epi16(vacc0x0123456789AB4567CDEF, voutput_zero_point); - vacc1x0123456789AB4567CDEF = _mm256_adds_epi16(vacc1x0123456789AB4567CDEF, voutput_zero_point); + vacc0x012389AB4567CDEF = _mm256_adds_epi16(vacc0x012389AB4567CDEF, voutput_zero_point); + vacc1x012389AB4567CDEF = _mm256_adds_epi16(vacc1x012389AB4567CDEF, voutput_zero_point); - const __m128i vout0x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc0x0123456789AB4567CDEF, 1)); - const __m128i vout1x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc1x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc1x0123456789AB4567CDEF, 1)); + const __m128i vout0x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x012389AB4567CDEF, 1)); + const __m128i vout1x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc1x012389AB4567CDEF), _mm256_extracti128_si256(vacc1x012389AB4567CDEF, 1)); - __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi8(vout0x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi8(vout1x0123456789AB4567CDEF, vshuffle_control_mask); + __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi32(vout0x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi32(vout1x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); vout0x0123456789ABCDEF = _mm_max_epi8(vout0x0123456789ABCDEF, voutput_min); vout1x0123456789ABCDEF = _mm_max_epi8(vout1x0123456789ABCDEF, voutput_min); diff --git a/src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-2x16c8-minmax-avx512vnni-prfm.c b/src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-2x16c8-minmax-avx512vnni-prfm.c index c1de5c3e032..8257a30ad46 100644 --- a/src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-2x16c8-minmax-avx512vnni-prfm.c +++ b/src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-2x16c8-minmax-avx512vnni-prfm.c @@ -52,7 +52,6 @@ void xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_2x16c8__avx512vnni_prfm( const __m512 voutput_max_less_zero_point = _mm512_load_ps(params->fp32_avx512vnni.output_max_less_zero_point); const __m256i voutput_zero_point = _mm256_load_si256((const __m256i*) params->fp32_avx512vnni.output_zero_point); const __m128i voutput_min = _mm_load_si128((const __m128i*) params->fp32_avx512vnni.output_min); - const __m128i vshuffle_control_mask = _mm_loadu_si128((const __m128i*) params->fp32_avx512vnni.shuffle_control_mask); do { __m512i vacc0x01234567 = _mm512_cvtepu32_epi64(_mm256_load_si256((const __m256i*) w)); __m512i vacc0x89ABCDEF = _mm512_cvtepu32_epi64(_mm256_load_si256((const __m256i*) ((const int32_t*) w + 8))); @@ -149,17 +148,17 @@ void xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_2x16c8__avx512vnni_prfm( vacc0x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled0x0123456789ABCDEF); vacc1x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled1x0123456789ABCDEF); - __m256i vacc0x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); - __m256i vacc1x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); + __m256i vacc0x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); + __m256i vacc1x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); - vacc0x0123456789AB4567CDEF = _mm256_adds_epi16(vacc0x0123456789AB4567CDEF, voutput_zero_point); - vacc1x0123456789AB4567CDEF = _mm256_adds_epi16(vacc1x0123456789AB4567CDEF, voutput_zero_point); + vacc0x012389AB4567CDEF = _mm256_adds_epi16(vacc0x012389AB4567CDEF, voutput_zero_point); + vacc1x012389AB4567CDEF = _mm256_adds_epi16(vacc1x012389AB4567CDEF, voutput_zero_point); - const __m128i vout0x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc0x0123456789AB4567CDEF, 1)); - const __m128i vout1x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc1x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc1x0123456789AB4567CDEF, 1)); + const __m128i vout0x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x012389AB4567CDEF, 1)); + const __m128i vout1x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc1x012389AB4567CDEF), _mm256_extracti128_si256(vacc1x012389AB4567CDEF, 1)); - __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi8(vout0x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi8(vout1x0123456789AB4567CDEF, vshuffle_control_mask); + __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi32(vout0x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi32(vout1x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); vout0x0123456789ABCDEF = _mm_max_epi8(vout0x0123456789ABCDEF, voutput_min); vout1x0123456789ABCDEF = _mm_max_epi8(vout1x0123456789ABCDEF, voutput_min); diff --git a/src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-2x16c8-minmax-avx512vnni.c b/src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-2x16c8-minmax-avx512vnni.c index 24d7697f5b9..e857053e6b5 100644 --- a/src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-2x16c8-minmax-avx512vnni.c +++ b/src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-2x16c8-minmax-avx512vnni.c @@ -51,7 +51,6 @@ void xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_2x16c8__avx512vnni( const __m512 voutput_max_less_zero_point = _mm512_load_ps(params->fp32_avx512vnni.output_max_less_zero_point); const __m256i voutput_zero_point = _mm256_load_si256((const __m256i*) params->fp32_avx512vnni.output_zero_point); const __m128i voutput_min = _mm_load_si128((const __m128i*) params->fp32_avx512vnni.output_min); - const __m128i vshuffle_control_mask = _mm_loadu_si128((const __m128i*) params->fp32_avx512vnni.shuffle_control_mask); do { __m512i vacc0x01234567 = _mm512_cvtepu32_epi64(_mm256_load_si256((const __m256i*) w)); __m512i vacc0x89ABCDEF = _mm512_cvtepu32_epi64(_mm256_load_si256((const __m256i*) ((const int32_t*) w + 8))); @@ -142,17 +141,17 @@ void xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_2x16c8__avx512vnni( vacc0x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled0x0123456789ABCDEF); vacc1x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled1x0123456789ABCDEF); - __m256i vacc0x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); - __m256i vacc1x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); + __m256i vacc0x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); + __m256i vacc1x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); - vacc0x0123456789AB4567CDEF = _mm256_adds_epi16(vacc0x0123456789AB4567CDEF, voutput_zero_point); - vacc1x0123456789AB4567CDEF = _mm256_adds_epi16(vacc1x0123456789AB4567CDEF, voutput_zero_point); + vacc0x012389AB4567CDEF = _mm256_adds_epi16(vacc0x012389AB4567CDEF, voutput_zero_point); + vacc1x012389AB4567CDEF = _mm256_adds_epi16(vacc1x012389AB4567CDEF, voutput_zero_point); - const __m128i vout0x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc0x0123456789AB4567CDEF, 1)); - const __m128i vout1x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc1x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc1x0123456789AB4567CDEF, 1)); + const __m128i vout0x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x012389AB4567CDEF, 1)); + const __m128i vout1x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc1x012389AB4567CDEF), _mm256_extracti128_si256(vacc1x012389AB4567CDEF, 1)); - __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi8(vout0x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi8(vout1x0123456789AB4567CDEF, vshuffle_control_mask); + __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi32(vout0x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi32(vout1x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); vout0x0123456789ABCDEF = _mm_max_epi8(vout0x0123456789ABCDEF, voutput_min); vout1x0123456789ABCDEF = _mm_max_epi8(vout1x0123456789ABCDEF, voutput_min); diff --git a/src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-2x16c8-minmax-fp32-avx512skx-prfm.c b/src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-2x16c8-minmax-fp32-avx512skx-prfm.c index 887abe2c029..f24d2cb9d46 100644 --- a/src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-2x16c8-minmax-fp32-avx512skx-prfm.c +++ b/src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-2x16c8-minmax-fp32-avx512skx-prfm.c @@ -132,17 +132,17 @@ void xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_2x16c8__avx512skx_prfm( vacc0x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled0x0123456789ABCDEF); vacc1x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled1x0123456789ABCDEF); - __m256i vacc0x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); - __m256i vacc1x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); + __m256i vacc0x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); + __m256i vacc1x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); - vacc0x0123456789AB4567CDEF = _mm256_adds_epi16(vacc0x0123456789AB4567CDEF, voutput_zero_point); - vacc1x0123456789AB4567CDEF = _mm256_adds_epi16(vacc1x0123456789AB4567CDEF, voutput_zero_point); + vacc0x012389AB4567CDEF = _mm256_adds_epi16(vacc0x012389AB4567CDEF, voutput_zero_point); + vacc1x012389AB4567CDEF = _mm256_adds_epi16(vacc1x012389AB4567CDEF, voutput_zero_point); - const __m128i vout0x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc0x0123456789AB4567CDEF, 1)); - const __m128i vout1x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc1x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc1x0123456789AB4567CDEF, 1)); + const __m128i vout0x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x012389AB4567CDEF, 1)); + const __m128i vout1x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc1x012389AB4567CDEF), _mm256_extracti128_si256(vacc1x012389AB4567CDEF, 1)); - __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi32(vout0x0123456789AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); - __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi32(vout1x0123456789AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi32(vout0x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi32(vout1x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); vout0x0123456789ABCDEF = _mm_max_epi8(vout0x0123456789ABCDEF, voutput_min); vout1x0123456789ABCDEF = _mm_max_epi8(vout1x0123456789ABCDEF, voutput_min); diff --git a/src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-2x16c8-minmax-fp32-avx512skx.c b/src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-2x16c8-minmax-fp32-avx512skx.c index 4400c71794f..8e3b6757c2c 100644 --- a/src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-2x16c8-minmax-fp32-avx512skx.c +++ b/src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-2x16c8-minmax-fp32-avx512skx.c @@ -129,17 +129,17 @@ void xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_2x16c8__avx512skx( vacc0x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled0x0123456789ABCDEF); vacc1x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled1x0123456789ABCDEF); - __m256i vacc0x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); - __m256i vacc1x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); + __m256i vacc0x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); + __m256i vacc1x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); - vacc0x0123456789AB4567CDEF = _mm256_adds_epi16(vacc0x0123456789AB4567CDEF, voutput_zero_point); - vacc1x0123456789AB4567CDEF = _mm256_adds_epi16(vacc1x0123456789AB4567CDEF, voutput_zero_point); + vacc0x012389AB4567CDEF = _mm256_adds_epi16(vacc0x012389AB4567CDEF, voutput_zero_point); + vacc1x012389AB4567CDEF = _mm256_adds_epi16(vacc1x012389AB4567CDEF, voutput_zero_point); - const __m128i vout0x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc0x0123456789AB4567CDEF, 1)); - const __m128i vout1x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc1x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc1x0123456789AB4567CDEF, 1)); + const __m128i vout0x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x012389AB4567CDEF, 1)); + const __m128i vout1x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc1x012389AB4567CDEF), _mm256_extracti128_si256(vacc1x012389AB4567CDEF, 1)); - __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi32(vout0x0123456789AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); - __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi32(vout1x0123456789AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi32(vout0x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi32(vout1x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); vout0x0123456789ABCDEF = _mm_max_epi8(vout0x0123456789ABCDEF, voutput_min); vout1x0123456789ABCDEF = _mm_max_epi8(vout1x0123456789ABCDEF, voutput_min); diff --git a/src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-3x16c4-minmax-avx512vnni-prfm.c b/src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-3x16c4-minmax-avx512vnni-prfm.c index b0fa04518a6..4100601800f 100644 --- a/src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-3x16c4-minmax-avx512vnni-prfm.c +++ b/src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-3x16c4-minmax-avx512vnni-prfm.c @@ -56,7 +56,6 @@ void xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_3x16c4__avx512vnni_prfm( const __m512 voutput_max_less_zero_point = _mm512_load_ps(params->fp32_avx512vnni.output_max_less_zero_point); const __m256i voutput_zero_point = _mm256_load_si256((const __m256i*) params->fp32_avx512vnni.output_zero_point); const __m128i voutput_min = _mm_load_si128((const __m128i*) params->fp32_avx512vnni.output_min); - const __m128i vshuffle_control_mask = _mm_loadu_si128((const __m128i*) params->fp32_avx512vnni.shuffle_control_mask); do { __m512i vacc0x0123456789ABCDEF = _mm512_load_epi32(w); __m512i vacc1x0123456789ABCDEF = vacc0x0123456789ABCDEF; @@ -144,21 +143,21 @@ void xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_3x16c4__avx512vnni_prfm( vacc1x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled1x0123456789ABCDEF); vacc2x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled2x0123456789ABCDEF); - __m256i vacc0x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); - __m256i vacc1x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); - __m256i vacc2x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc2x0123456789ABCDEF, 1)); + __m256i vacc0x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); + __m256i vacc1x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); + __m256i vacc2x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc2x0123456789ABCDEF, 1)); - vacc0x0123456789AB4567CDEF = _mm256_adds_epi16(vacc0x0123456789AB4567CDEF, voutput_zero_point); - vacc1x0123456789AB4567CDEF = _mm256_adds_epi16(vacc1x0123456789AB4567CDEF, voutput_zero_point); - vacc2x0123456789AB4567CDEF = _mm256_adds_epi16(vacc2x0123456789AB4567CDEF, voutput_zero_point); + vacc0x012389AB4567CDEF = _mm256_adds_epi16(vacc0x012389AB4567CDEF, voutput_zero_point); + vacc1x012389AB4567CDEF = _mm256_adds_epi16(vacc1x012389AB4567CDEF, voutput_zero_point); + vacc2x012389AB4567CDEF = _mm256_adds_epi16(vacc2x012389AB4567CDEF, voutput_zero_point); - const __m128i vout0x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc0x0123456789AB4567CDEF, 1)); - const __m128i vout1x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc1x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc1x0123456789AB4567CDEF, 1)); - const __m128i vout2x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc2x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc2x0123456789AB4567CDEF, 1)); + const __m128i vout0x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x012389AB4567CDEF, 1)); + const __m128i vout1x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc1x012389AB4567CDEF), _mm256_extracti128_si256(vacc1x012389AB4567CDEF, 1)); + const __m128i vout2x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc2x012389AB4567CDEF), _mm256_extracti128_si256(vacc2x012389AB4567CDEF, 1)); - __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi8(vout0x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi8(vout1x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi8(vout2x0123456789AB4567CDEF, vshuffle_control_mask); + __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi32(vout0x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi32(vout1x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi32(vout2x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); vout0x0123456789ABCDEF = _mm_max_epi8(vout0x0123456789ABCDEF, voutput_min); vout1x0123456789ABCDEF = _mm_max_epi8(vout1x0123456789ABCDEF, voutput_min); diff --git a/src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-3x16c4-minmax-avx512vnni.c b/src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-3x16c4-minmax-avx512vnni.c index 9ae03714da7..0a0700d5859 100644 --- a/src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-3x16c4-minmax-avx512vnni.c +++ b/src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-3x16c4-minmax-avx512vnni.c @@ -55,7 +55,6 @@ void xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_3x16c4__avx512vnni( const __m512 voutput_max_less_zero_point = _mm512_load_ps(params->fp32_avx512vnni.output_max_less_zero_point); const __m256i voutput_zero_point = _mm256_load_si256((const __m256i*) params->fp32_avx512vnni.output_zero_point); const __m128i voutput_min = _mm_load_si128((const __m128i*) params->fp32_avx512vnni.output_min); - const __m128i vshuffle_control_mask = _mm_loadu_si128((const __m128i*) params->fp32_avx512vnni.shuffle_control_mask); do { __m512i vacc0x0123456789ABCDEF = _mm512_load_epi32(w); __m512i vacc1x0123456789ABCDEF = vacc0x0123456789ABCDEF; @@ -140,21 +139,21 @@ void xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_3x16c4__avx512vnni( vacc1x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled1x0123456789ABCDEF); vacc2x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled2x0123456789ABCDEF); - __m256i vacc0x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); - __m256i vacc1x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); - __m256i vacc2x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc2x0123456789ABCDEF, 1)); + __m256i vacc0x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); + __m256i vacc1x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); + __m256i vacc2x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc2x0123456789ABCDEF, 1)); - vacc0x0123456789AB4567CDEF = _mm256_adds_epi16(vacc0x0123456789AB4567CDEF, voutput_zero_point); - vacc1x0123456789AB4567CDEF = _mm256_adds_epi16(vacc1x0123456789AB4567CDEF, voutput_zero_point); - vacc2x0123456789AB4567CDEF = _mm256_adds_epi16(vacc2x0123456789AB4567CDEF, voutput_zero_point); + vacc0x012389AB4567CDEF = _mm256_adds_epi16(vacc0x012389AB4567CDEF, voutput_zero_point); + vacc1x012389AB4567CDEF = _mm256_adds_epi16(vacc1x012389AB4567CDEF, voutput_zero_point); + vacc2x012389AB4567CDEF = _mm256_adds_epi16(vacc2x012389AB4567CDEF, voutput_zero_point); - const __m128i vout0x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc0x0123456789AB4567CDEF, 1)); - const __m128i vout1x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc1x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc1x0123456789AB4567CDEF, 1)); - const __m128i vout2x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc2x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc2x0123456789AB4567CDEF, 1)); + const __m128i vout0x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x012389AB4567CDEF, 1)); + const __m128i vout1x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc1x012389AB4567CDEF), _mm256_extracti128_si256(vacc1x012389AB4567CDEF, 1)); + const __m128i vout2x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc2x012389AB4567CDEF), _mm256_extracti128_si256(vacc2x012389AB4567CDEF, 1)); - __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi8(vout0x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi8(vout1x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi8(vout2x0123456789AB4567CDEF, vshuffle_control_mask); + __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi32(vout0x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi32(vout1x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi32(vout2x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); vout0x0123456789ABCDEF = _mm_max_epi8(vout0x0123456789ABCDEF, voutput_min); vout1x0123456789ABCDEF = _mm_max_epi8(vout1x0123456789ABCDEF, voutput_min); diff --git a/src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-3x16c8-minmax-avx512vnni-prfm.c b/src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-3x16c8-minmax-avx512vnni-prfm.c index 2484d9c06aa..ae0a2eb59af 100644 --- a/src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-3x16c8-minmax-avx512vnni-prfm.c +++ b/src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-3x16c8-minmax-avx512vnni-prfm.c @@ -56,7 +56,6 @@ void xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_3x16c8__avx512vnni_prfm( const __m512 voutput_max_less_zero_point = _mm512_load_ps(params->fp32_avx512vnni.output_max_less_zero_point); const __m256i voutput_zero_point = _mm256_load_si256((const __m256i*) params->fp32_avx512vnni.output_zero_point); const __m128i voutput_min = _mm_load_si128((const __m128i*) params->fp32_avx512vnni.output_min); - const __m128i vshuffle_control_mask = _mm_loadu_si128((const __m128i*) params->fp32_avx512vnni.shuffle_control_mask); do { __m512i vacc0x01234567 = _mm512_cvtepu32_epi64(_mm256_load_si256((const __m256i*) w)); __m512i vacc0x89ABCDEF = _mm512_cvtepu32_epi64(_mm256_load_si256((const __m256i*) ((const int32_t*) w + 8))); @@ -177,21 +176,21 @@ void xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_3x16c8__avx512vnni_prfm( vacc1x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled1x0123456789ABCDEF); vacc2x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled2x0123456789ABCDEF); - __m256i vacc0x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); - __m256i vacc1x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); - __m256i vacc2x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc2x0123456789ABCDEF, 1)); + __m256i vacc0x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); + __m256i vacc1x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); + __m256i vacc2x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc2x0123456789ABCDEF, 1)); - vacc0x0123456789AB4567CDEF = _mm256_adds_epi16(vacc0x0123456789AB4567CDEF, voutput_zero_point); - vacc1x0123456789AB4567CDEF = _mm256_adds_epi16(vacc1x0123456789AB4567CDEF, voutput_zero_point); - vacc2x0123456789AB4567CDEF = _mm256_adds_epi16(vacc2x0123456789AB4567CDEF, voutput_zero_point); + vacc0x012389AB4567CDEF = _mm256_adds_epi16(vacc0x012389AB4567CDEF, voutput_zero_point); + vacc1x012389AB4567CDEF = _mm256_adds_epi16(vacc1x012389AB4567CDEF, voutput_zero_point); + vacc2x012389AB4567CDEF = _mm256_adds_epi16(vacc2x012389AB4567CDEF, voutput_zero_point); - const __m128i vout0x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc0x0123456789AB4567CDEF, 1)); - const __m128i vout1x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc1x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc1x0123456789AB4567CDEF, 1)); - const __m128i vout2x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc2x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc2x0123456789AB4567CDEF, 1)); + const __m128i vout0x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x012389AB4567CDEF, 1)); + const __m128i vout1x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc1x012389AB4567CDEF), _mm256_extracti128_si256(vacc1x012389AB4567CDEF, 1)); + const __m128i vout2x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc2x012389AB4567CDEF), _mm256_extracti128_si256(vacc2x012389AB4567CDEF, 1)); - __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi8(vout0x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi8(vout1x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi8(vout2x0123456789AB4567CDEF, vshuffle_control_mask); + __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi32(vout0x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi32(vout1x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi32(vout2x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); vout0x0123456789ABCDEF = _mm_max_epi8(vout0x0123456789ABCDEF, voutput_min); vout1x0123456789ABCDEF = _mm_max_epi8(vout1x0123456789ABCDEF, voutput_min); diff --git a/src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-3x16c8-minmax-avx512vnni.c b/src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-3x16c8-minmax-avx512vnni.c index 40cefd591c5..0c02a528bd7 100644 --- a/src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-3x16c8-minmax-avx512vnni.c +++ b/src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-3x16c8-minmax-avx512vnni.c @@ -55,7 +55,6 @@ void xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_3x16c8__avx512vnni( const __m512 voutput_max_less_zero_point = _mm512_load_ps(params->fp32_avx512vnni.output_max_less_zero_point); const __m256i voutput_zero_point = _mm256_load_si256((const __m256i*) params->fp32_avx512vnni.output_zero_point); const __m128i voutput_min = _mm_load_si128((const __m128i*) params->fp32_avx512vnni.output_min); - const __m128i vshuffle_control_mask = _mm_loadu_si128((const __m128i*) params->fp32_avx512vnni.shuffle_control_mask); do { __m512i vacc0x01234567 = _mm512_cvtepu32_epi64(_mm256_load_si256((const __m256i*) w)); __m512i vacc0x89ABCDEF = _mm512_cvtepu32_epi64(_mm256_load_si256((const __m256i*) ((const int32_t*) w + 8))); @@ -170,21 +169,21 @@ void xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_3x16c8__avx512vnni( vacc1x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled1x0123456789ABCDEF); vacc2x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled2x0123456789ABCDEF); - __m256i vacc0x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); - __m256i vacc1x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); - __m256i vacc2x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc2x0123456789ABCDEF, 1)); + __m256i vacc0x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); + __m256i vacc1x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); + __m256i vacc2x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc2x0123456789ABCDEF, 1)); - vacc0x0123456789AB4567CDEF = _mm256_adds_epi16(vacc0x0123456789AB4567CDEF, voutput_zero_point); - vacc1x0123456789AB4567CDEF = _mm256_adds_epi16(vacc1x0123456789AB4567CDEF, voutput_zero_point); - vacc2x0123456789AB4567CDEF = _mm256_adds_epi16(vacc2x0123456789AB4567CDEF, voutput_zero_point); + vacc0x012389AB4567CDEF = _mm256_adds_epi16(vacc0x012389AB4567CDEF, voutput_zero_point); + vacc1x012389AB4567CDEF = _mm256_adds_epi16(vacc1x012389AB4567CDEF, voutput_zero_point); + vacc2x012389AB4567CDEF = _mm256_adds_epi16(vacc2x012389AB4567CDEF, voutput_zero_point); - const __m128i vout0x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc0x0123456789AB4567CDEF, 1)); - const __m128i vout1x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc1x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc1x0123456789AB4567CDEF, 1)); - const __m128i vout2x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc2x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc2x0123456789AB4567CDEF, 1)); + const __m128i vout0x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x012389AB4567CDEF, 1)); + const __m128i vout1x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc1x012389AB4567CDEF), _mm256_extracti128_si256(vacc1x012389AB4567CDEF, 1)); + const __m128i vout2x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc2x012389AB4567CDEF), _mm256_extracti128_si256(vacc2x012389AB4567CDEF, 1)); - __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi8(vout0x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi8(vout1x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi8(vout2x0123456789AB4567CDEF, vshuffle_control_mask); + __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi32(vout0x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi32(vout1x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi32(vout2x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); vout0x0123456789ABCDEF = _mm_max_epi8(vout0x0123456789ABCDEF, voutput_min); vout1x0123456789ABCDEF = _mm_max_epi8(vout1x0123456789ABCDEF, voutput_min); diff --git a/src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-3x16c8-minmax-fp32-avx512skx-prfm.c b/src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-3x16c8-minmax-fp32-avx512skx-prfm.c index 830453009cd..3dd33e936b0 100644 --- a/src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-3x16c8-minmax-fp32-avx512skx-prfm.c +++ b/src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-3x16c8-minmax-fp32-avx512skx-prfm.c @@ -158,21 +158,21 @@ void xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_3x16c8__avx512skx_prfm( vacc1x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled1x0123456789ABCDEF); vacc2x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled2x0123456789ABCDEF); - __m256i vacc0x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); - __m256i vacc1x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); - __m256i vacc2x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc2x0123456789ABCDEF, 1)); + __m256i vacc0x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); + __m256i vacc1x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); + __m256i vacc2x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc2x0123456789ABCDEF, 1)); - vacc0x0123456789AB4567CDEF = _mm256_adds_epi16(vacc0x0123456789AB4567CDEF, voutput_zero_point); - vacc1x0123456789AB4567CDEF = _mm256_adds_epi16(vacc1x0123456789AB4567CDEF, voutput_zero_point); - vacc2x0123456789AB4567CDEF = _mm256_adds_epi16(vacc2x0123456789AB4567CDEF, voutput_zero_point); + vacc0x012389AB4567CDEF = _mm256_adds_epi16(vacc0x012389AB4567CDEF, voutput_zero_point); + vacc1x012389AB4567CDEF = _mm256_adds_epi16(vacc1x012389AB4567CDEF, voutput_zero_point); + vacc2x012389AB4567CDEF = _mm256_adds_epi16(vacc2x012389AB4567CDEF, voutput_zero_point); - const __m128i vout0x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc0x0123456789AB4567CDEF, 1)); - const __m128i vout1x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc1x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc1x0123456789AB4567CDEF, 1)); - const __m128i vout2x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc2x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc2x0123456789AB4567CDEF, 1)); + const __m128i vout0x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x012389AB4567CDEF, 1)); + const __m128i vout1x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc1x012389AB4567CDEF), _mm256_extracti128_si256(vacc1x012389AB4567CDEF, 1)); + const __m128i vout2x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc2x012389AB4567CDEF), _mm256_extracti128_si256(vacc2x012389AB4567CDEF, 1)); - __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi32(vout0x0123456789AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); - __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi32(vout1x0123456789AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); - __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi32(vout2x0123456789AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi32(vout0x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi32(vout1x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi32(vout2x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); vout0x0123456789ABCDEF = _mm_max_epi8(vout0x0123456789ABCDEF, voutput_min); vout1x0123456789ABCDEF = _mm_max_epi8(vout1x0123456789ABCDEF, voutput_min); diff --git a/src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-3x16c8-minmax-fp32-avx512skx.c b/src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-3x16c8-minmax-fp32-avx512skx.c index d4a391dcea0..b9a679a6b7f 100644 --- a/src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-3x16c8-minmax-fp32-avx512skx.c +++ b/src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-3x16c8-minmax-fp32-avx512skx.c @@ -155,21 +155,21 @@ void xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_3x16c8__avx512skx( vacc1x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled1x0123456789ABCDEF); vacc2x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled2x0123456789ABCDEF); - __m256i vacc0x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); - __m256i vacc1x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); - __m256i vacc2x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc2x0123456789ABCDEF, 1)); + __m256i vacc0x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); + __m256i vacc1x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); + __m256i vacc2x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc2x0123456789ABCDEF, 1)); - vacc0x0123456789AB4567CDEF = _mm256_adds_epi16(vacc0x0123456789AB4567CDEF, voutput_zero_point); - vacc1x0123456789AB4567CDEF = _mm256_adds_epi16(vacc1x0123456789AB4567CDEF, voutput_zero_point); - vacc2x0123456789AB4567CDEF = _mm256_adds_epi16(vacc2x0123456789AB4567CDEF, voutput_zero_point); + vacc0x012389AB4567CDEF = _mm256_adds_epi16(vacc0x012389AB4567CDEF, voutput_zero_point); + vacc1x012389AB4567CDEF = _mm256_adds_epi16(vacc1x012389AB4567CDEF, voutput_zero_point); + vacc2x012389AB4567CDEF = _mm256_adds_epi16(vacc2x012389AB4567CDEF, voutput_zero_point); - const __m128i vout0x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc0x0123456789AB4567CDEF, 1)); - const __m128i vout1x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc1x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc1x0123456789AB4567CDEF, 1)); - const __m128i vout2x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc2x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc2x0123456789AB4567CDEF, 1)); + const __m128i vout0x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x012389AB4567CDEF, 1)); + const __m128i vout1x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc1x012389AB4567CDEF), _mm256_extracti128_si256(vacc1x012389AB4567CDEF, 1)); + const __m128i vout2x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc2x012389AB4567CDEF), _mm256_extracti128_si256(vacc2x012389AB4567CDEF, 1)); - __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi32(vout0x0123456789AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); - __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi32(vout1x0123456789AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); - __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi32(vout2x0123456789AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi32(vout0x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi32(vout1x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi32(vout2x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); vout0x0123456789ABCDEF = _mm_max_epi8(vout0x0123456789ABCDEF, voutput_min); vout1x0123456789ABCDEF = _mm_max_epi8(vout1x0123456789ABCDEF, voutput_min); diff --git a/src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-4x16c4-minmax-avx512vnni-prfm.c b/src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-4x16c4-minmax-avx512vnni-prfm.c index 98d3daaefa4..2b3c6e5f532 100644 --- a/src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-4x16c4-minmax-avx512vnni-prfm.c +++ b/src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-4x16c4-minmax-avx512vnni-prfm.c @@ -60,7 +60,6 @@ void xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_4x16c4__avx512vnni_prfm( const __m512 voutput_max_less_zero_point = _mm512_load_ps(params->fp32_avx512vnni.output_max_less_zero_point); const __m256i voutput_zero_point = _mm256_load_si256((const __m256i*) params->fp32_avx512vnni.output_zero_point); const __m128i voutput_min = _mm_load_si128((const __m128i*) params->fp32_avx512vnni.output_min); - const __m128i vshuffle_control_mask = _mm_loadu_si128((const __m128i*) params->fp32_avx512vnni.shuffle_control_mask); do { __m512i vacc0x0123456789ABCDEF = _mm512_load_epi32(w); __m512i vacc1x0123456789ABCDEF = vacc0x0123456789ABCDEF; @@ -165,25 +164,25 @@ void xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_4x16c4__avx512vnni_prfm( vacc2x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled2x0123456789ABCDEF); vacc3x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled3x0123456789ABCDEF); - __m256i vacc0x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); - __m256i vacc1x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); - __m256i vacc2x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc2x0123456789ABCDEF, 1)); - __m256i vacc3x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc3x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc3x0123456789ABCDEF, 1)); - - vacc0x0123456789AB4567CDEF = _mm256_adds_epi16(vacc0x0123456789AB4567CDEF, voutput_zero_point); - vacc1x0123456789AB4567CDEF = _mm256_adds_epi16(vacc1x0123456789AB4567CDEF, voutput_zero_point); - vacc2x0123456789AB4567CDEF = _mm256_adds_epi16(vacc2x0123456789AB4567CDEF, voutput_zero_point); - vacc3x0123456789AB4567CDEF = _mm256_adds_epi16(vacc3x0123456789AB4567CDEF, voutput_zero_point); - - const __m128i vout0x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc0x0123456789AB4567CDEF, 1)); - const __m128i vout1x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc1x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc1x0123456789AB4567CDEF, 1)); - const __m128i vout2x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc2x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc2x0123456789AB4567CDEF, 1)); - const __m128i vout3x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc3x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc3x0123456789AB4567CDEF, 1)); - - __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi8(vout0x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi8(vout1x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi8(vout2x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout3x0123456789ABCDEF = _mm_shuffle_epi8(vout3x0123456789AB4567CDEF, vshuffle_control_mask); + __m256i vacc0x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); + __m256i vacc1x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); + __m256i vacc2x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc2x0123456789ABCDEF, 1)); + __m256i vacc3x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc3x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc3x0123456789ABCDEF, 1)); + + vacc0x012389AB4567CDEF = _mm256_adds_epi16(vacc0x012389AB4567CDEF, voutput_zero_point); + vacc1x012389AB4567CDEF = _mm256_adds_epi16(vacc1x012389AB4567CDEF, voutput_zero_point); + vacc2x012389AB4567CDEF = _mm256_adds_epi16(vacc2x012389AB4567CDEF, voutput_zero_point); + vacc3x012389AB4567CDEF = _mm256_adds_epi16(vacc3x012389AB4567CDEF, voutput_zero_point); + + const __m128i vout0x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x012389AB4567CDEF, 1)); + const __m128i vout1x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc1x012389AB4567CDEF), _mm256_extracti128_si256(vacc1x012389AB4567CDEF, 1)); + const __m128i vout2x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc2x012389AB4567CDEF), _mm256_extracti128_si256(vacc2x012389AB4567CDEF, 1)); + const __m128i vout3x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc3x012389AB4567CDEF), _mm256_extracti128_si256(vacc3x012389AB4567CDEF, 1)); + + __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi32(vout0x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi32(vout1x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi32(vout2x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout3x0123456789ABCDEF = _mm_shuffle_epi32(vout3x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); vout0x0123456789ABCDEF = _mm_max_epi8(vout0x0123456789ABCDEF, voutput_min); vout1x0123456789ABCDEF = _mm_max_epi8(vout1x0123456789ABCDEF, voutput_min); diff --git a/src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-4x16c4-minmax-avx512vnni.c b/src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-4x16c4-minmax-avx512vnni.c index 4b2eacc630f..b983d62dac3 100644 --- a/src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-4x16c4-minmax-avx512vnni.c +++ b/src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-4x16c4-minmax-avx512vnni.c @@ -59,7 +59,6 @@ void xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_4x16c4__avx512vnni( const __m512 voutput_max_less_zero_point = _mm512_load_ps(params->fp32_avx512vnni.output_max_less_zero_point); const __m256i voutput_zero_point = _mm256_load_si256((const __m256i*) params->fp32_avx512vnni.output_zero_point); const __m128i voutput_min = _mm_load_si128((const __m128i*) params->fp32_avx512vnni.output_min); - const __m128i vshuffle_control_mask = _mm_loadu_si128((const __m128i*) params->fp32_avx512vnni.shuffle_control_mask); do { __m512i vacc0x0123456789ABCDEF = _mm512_load_epi32(w); __m512i vacc1x0123456789ABCDEF = vacc0x0123456789ABCDEF; @@ -161,25 +160,25 @@ void xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_4x16c4__avx512vnni( vacc2x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled2x0123456789ABCDEF); vacc3x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled3x0123456789ABCDEF); - __m256i vacc0x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); - __m256i vacc1x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); - __m256i vacc2x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc2x0123456789ABCDEF, 1)); - __m256i vacc3x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc3x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc3x0123456789ABCDEF, 1)); - - vacc0x0123456789AB4567CDEF = _mm256_adds_epi16(vacc0x0123456789AB4567CDEF, voutput_zero_point); - vacc1x0123456789AB4567CDEF = _mm256_adds_epi16(vacc1x0123456789AB4567CDEF, voutput_zero_point); - vacc2x0123456789AB4567CDEF = _mm256_adds_epi16(vacc2x0123456789AB4567CDEF, voutput_zero_point); - vacc3x0123456789AB4567CDEF = _mm256_adds_epi16(vacc3x0123456789AB4567CDEF, voutput_zero_point); - - const __m128i vout0x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc0x0123456789AB4567CDEF, 1)); - const __m128i vout1x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc1x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc1x0123456789AB4567CDEF, 1)); - const __m128i vout2x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc2x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc2x0123456789AB4567CDEF, 1)); - const __m128i vout3x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc3x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc3x0123456789AB4567CDEF, 1)); - - __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi8(vout0x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi8(vout1x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi8(vout2x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout3x0123456789ABCDEF = _mm_shuffle_epi8(vout3x0123456789AB4567CDEF, vshuffle_control_mask); + __m256i vacc0x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); + __m256i vacc1x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); + __m256i vacc2x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc2x0123456789ABCDEF, 1)); + __m256i vacc3x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc3x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc3x0123456789ABCDEF, 1)); + + vacc0x012389AB4567CDEF = _mm256_adds_epi16(vacc0x012389AB4567CDEF, voutput_zero_point); + vacc1x012389AB4567CDEF = _mm256_adds_epi16(vacc1x012389AB4567CDEF, voutput_zero_point); + vacc2x012389AB4567CDEF = _mm256_adds_epi16(vacc2x012389AB4567CDEF, voutput_zero_point); + vacc3x012389AB4567CDEF = _mm256_adds_epi16(vacc3x012389AB4567CDEF, voutput_zero_point); + + const __m128i vout0x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x012389AB4567CDEF, 1)); + const __m128i vout1x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc1x012389AB4567CDEF), _mm256_extracti128_si256(vacc1x012389AB4567CDEF, 1)); + const __m128i vout2x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc2x012389AB4567CDEF), _mm256_extracti128_si256(vacc2x012389AB4567CDEF, 1)); + const __m128i vout3x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc3x012389AB4567CDEF), _mm256_extracti128_si256(vacc3x012389AB4567CDEF, 1)); + + __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi32(vout0x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi32(vout1x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi32(vout2x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout3x0123456789ABCDEF = _mm_shuffle_epi32(vout3x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); vout0x0123456789ABCDEF = _mm_max_epi8(vout0x0123456789ABCDEF, voutput_min); vout1x0123456789ABCDEF = _mm_max_epi8(vout1x0123456789ABCDEF, voutput_min); diff --git a/src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-4x16c8-minmax-avx512vnni-prfm.c b/src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-4x16c8-minmax-avx512vnni-prfm.c index 3ccfc003663..89bcdd259cc 100644 --- a/src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-4x16c8-minmax-avx512vnni-prfm.c +++ b/src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-4x16c8-minmax-avx512vnni-prfm.c @@ -60,7 +60,6 @@ void xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_4x16c8__avx512vnni_prfm( const __m512 voutput_max_less_zero_point = _mm512_load_ps(params->fp32_avx512vnni.output_max_less_zero_point); const __m256i voutput_zero_point = _mm256_load_si256((const __m256i*) params->fp32_avx512vnni.output_zero_point); const __m128i voutput_min = _mm_load_si128((const __m128i*) params->fp32_avx512vnni.output_min); - const __m128i vshuffle_control_mask = _mm_loadu_si128((const __m128i*) params->fp32_avx512vnni.shuffle_control_mask); do { __m512i vacc0x01234567 = _mm512_cvtepu32_epi64(_mm256_load_si256((const __m256i*) w)); __m512i vacc0x89ABCDEF = _mm512_cvtepu32_epi64(_mm256_load_si256((const __m256i*) ((const int32_t*) w + 8))); @@ -205,25 +204,25 @@ void xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_4x16c8__avx512vnni_prfm( vacc2x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled2x0123456789ABCDEF); vacc3x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled3x0123456789ABCDEF); - __m256i vacc0x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); - __m256i vacc1x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); - __m256i vacc2x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc2x0123456789ABCDEF, 1)); - __m256i vacc3x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc3x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc3x0123456789ABCDEF, 1)); - - vacc0x0123456789AB4567CDEF = _mm256_adds_epi16(vacc0x0123456789AB4567CDEF, voutput_zero_point); - vacc1x0123456789AB4567CDEF = _mm256_adds_epi16(vacc1x0123456789AB4567CDEF, voutput_zero_point); - vacc2x0123456789AB4567CDEF = _mm256_adds_epi16(vacc2x0123456789AB4567CDEF, voutput_zero_point); - vacc3x0123456789AB4567CDEF = _mm256_adds_epi16(vacc3x0123456789AB4567CDEF, voutput_zero_point); - - const __m128i vout0x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc0x0123456789AB4567CDEF, 1)); - const __m128i vout1x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc1x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc1x0123456789AB4567CDEF, 1)); - const __m128i vout2x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc2x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc2x0123456789AB4567CDEF, 1)); - const __m128i vout3x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc3x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc3x0123456789AB4567CDEF, 1)); - - __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi8(vout0x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi8(vout1x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi8(vout2x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout3x0123456789ABCDEF = _mm_shuffle_epi8(vout3x0123456789AB4567CDEF, vshuffle_control_mask); + __m256i vacc0x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); + __m256i vacc1x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); + __m256i vacc2x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc2x0123456789ABCDEF, 1)); + __m256i vacc3x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc3x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc3x0123456789ABCDEF, 1)); + + vacc0x012389AB4567CDEF = _mm256_adds_epi16(vacc0x012389AB4567CDEF, voutput_zero_point); + vacc1x012389AB4567CDEF = _mm256_adds_epi16(vacc1x012389AB4567CDEF, voutput_zero_point); + vacc2x012389AB4567CDEF = _mm256_adds_epi16(vacc2x012389AB4567CDEF, voutput_zero_point); + vacc3x012389AB4567CDEF = _mm256_adds_epi16(vacc3x012389AB4567CDEF, voutput_zero_point); + + const __m128i vout0x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x012389AB4567CDEF, 1)); + const __m128i vout1x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc1x012389AB4567CDEF), _mm256_extracti128_si256(vacc1x012389AB4567CDEF, 1)); + const __m128i vout2x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc2x012389AB4567CDEF), _mm256_extracti128_si256(vacc2x012389AB4567CDEF, 1)); + const __m128i vout3x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc3x012389AB4567CDEF), _mm256_extracti128_si256(vacc3x012389AB4567CDEF, 1)); + + __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi32(vout0x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi32(vout1x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi32(vout2x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout3x0123456789ABCDEF = _mm_shuffle_epi32(vout3x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); vout0x0123456789ABCDEF = _mm_max_epi8(vout0x0123456789ABCDEF, voutput_min); vout1x0123456789ABCDEF = _mm_max_epi8(vout1x0123456789ABCDEF, voutput_min); diff --git a/src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-4x16c8-minmax-avx512vnni.c b/src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-4x16c8-minmax-avx512vnni.c index 6c776ba4396..9ba85cfd164 100644 --- a/src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-4x16c8-minmax-avx512vnni.c +++ b/src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-4x16c8-minmax-avx512vnni.c @@ -59,7 +59,6 @@ void xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_4x16c8__avx512vnni( const __m512 voutput_max_less_zero_point = _mm512_load_ps(params->fp32_avx512vnni.output_max_less_zero_point); const __m256i voutput_zero_point = _mm256_load_si256((const __m256i*) params->fp32_avx512vnni.output_zero_point); const __m128i voutput_min = _mm_load_si128((const __m128i*) params->fp32_avx512vnni.output_min); - const __m128i vshuffle_control_mask = _mm_loadu_si128((const __m128i*) params->fp32_avx512vnni.shuffle_control_mask); do { __m512i vacc0x01234567 = _mm512_cvtepu32_epi64(_mm256_load_si256((const __m256i*) w)); __m512i vacc0x89ABCDEF = _mm512_cvtepu32_epi64(_mm256_load_si256((const __m256i*) ((const int32_t*) w + 8))); @@ -198,25 +197,25 @@ void xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_4x16c8__avx512vnni( vacc2x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled2x0123456789ABCDEF); vacc3x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled3x0123456789ABCDEF); - __m256i vacc0x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); - __m256i vacc1x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); - __m256i vacc2x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc2x0123456789ABCDEF, 1)); - __m256i vacc3x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc3x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc3x0123456789ABCDEF, 1)); - - vacc0x0123456789AB4567CDEF = _mm256_adds_epi16(vacc0x0123456789AB4567CDEF, voutput_zero_point); - vacc1x0123456789AB4567CDEF = _mm256_adds_epi16(vacc1x0123456789AB4567CDEF, voutput_zero_point); - vacc2x0123456789AB4567CDEF = _mm256_adds_epi16(vacc2x0123456789AB4567CDEF, voutput_zero_point); - vacc3x0123456789AB4567CDEF = _mm256_adds_epi16(vacc3x0123456789AB4567CDEF, voutput_zero_point); - - const __m128i vout0x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc0x0123456789AB4567CDEF, 1)); - const __m128i vout1x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc1x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc1x0123456789AB4567CDEF, 1)); - const __m128i vout2x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc2x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc2x0123456789AB4567CDEF, 1)); - const __m128i vout3x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc3x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc3x0123456789AB4567CDEF, 1)); - - __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi8(vout0x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi8(vout1x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi8(vout2x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout3x0123456789ABCDEF = _mm_shuffle_epi8(vout3x0123456789AB4567CDEF, vshuffle_control_mask); + __m256i vacc0x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); + __m256i vacc1x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); + __m256i vacc2x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc2x0123456789ABCDEF, 1)); + __m256i vacc3x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc3x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc3x0123456789ABCDEF, 1)); + + vacc0x012389AB4567CDEF = _mm256_adds_epi16(vacc0x012389AB4567CDEF, voutput_zero_point); + vacc1x012389AB4567CDEF = _mm256_adds_epi16(vacc1x012389AB4567CDEF, voutput_zero_point); + vacc2x012389AB4567CDEF = _mm256_adds_epi16(vacc2x012389AB4567CDEF, voutput_zero_point); + vacc3x012389AB4567CDEF = _mm256_adds_epi16(vacc3x012389AB4567CDEF, voutput_zero_point); + + const __m128i vout0x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x012389AB4567CDEF, 1)); + const __m128i vout1x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc1x012389AB4567CDEF), _mm256_extracti128_si256(vacc1x012389AB4567CDEF, 1)); + const __m128i vout2x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc2x012389AB4567CDEF), _mm256_extracti128_si256(vacc2x012389AB4567CDEF, 1)); + const __m128i vout3x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc3x012389AB4567CDEF), _mm256_extracti128_si256(vacc3x012389AB4567CDEF, 1)); + + __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi32(vout0x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi32(vout1x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi32(vout2x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout3x0123456789ABCDEF = _mm_shuffle_epi32(vout3x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); vout0x0123456789ABCDEF = _mm_max_epi8(vout0x0123456789ABCDEF, voutput_min); vout1x0123456789ABCDEF = _mm_max_epi8(vout1x0123456789ABCDEF, voutput_min); diff --git a/src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-4x16c8-minmax-fp32-avx512skx-prfm.c b/src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-4x16c8-minmax-fp32-avx512skx-prfm.c index 2bada619ff4..9491d131c50 100644 --- a/src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-4x16c8-minmax-fp32-avx512skx-prfm.c +++ b/src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-4x16c8-minmax-fp32-avx512skx-prfm.c @@ -184,25 +184,25 @@ void xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_4x16c8__avx512skx_prfm( vacc2x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled2x0123456789ABCDEF); vacc3x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled3x0123456789ABCDEF); - __m256i vacc0x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); - __m256i vacc1x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); - __m256i vacc2x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc2x0123456789ABCDEF, 1)); - __m256i vacc3x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc3x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc3x0123456789ABCDEF, 1)); - - vacc0x0123456789AB4567CDEF = _mm256_adds_epi16(vacc0x0123456789AB4567CDEF, voutput_zero_point); - vacc1x0123456789AB4567CDEF = _mm256_adds_epi16(vacc1x0123456789AB4567CDEF, voutput_zero_point); - vacc2x0123456789AB4567CDEF = _mm256_adds_epi16(vacc2x0123456789AB4567CDEF, voutput_zero_point); - vacc3x0123456789AB4567CDEF = _mm256_adds_epi16(vacc3x0123456789AB4567CDEF, voutput_zero_point); - - const __m128i vout0x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc0x0123456789AB4567CDEF, 1)); - const __m128i vout1x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc1x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc1x0123456789AB4567CDEF, 1)); - const __m128i vout2x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc2x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc2x0123456789AB4567CDEF, 1)); - const __m128i vout3x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc3x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc3x0123456789AB4567CDEF, 1)); - - __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi32(vout0x0123456789AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); - __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi32(vout1x0123456789AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); - __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi32(vout2x0123456789AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); - __m128i vout3x0123456789ABCDEF = _mm_shuffle_epi32(vout3x0123456789AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m256i vacc0x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); + __m256i vacc1x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); + __m256i vacc2x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc2x0123456789ABCDEF, 1)); + __m256i vacc3x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc3x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc3x0123456789ABCDEF, 1)); + + vacc0x012389AB4567CDEF = _mm256_adds_epi16(vacc0x012389AB4567CDEF, voutput_zero_point); + vacc1x012389AB4567CDEF = _mm256_adds_epi16(vacc1x012389AB4567CDEF, voutput_zero_point); + vacc2x012389AB4567CDEF = _mm256_adds_epi16(vacc2x012389AB4567CDEF, voutput_zero_point); + vacc3x012389AB4567CDEF = _mm256_adds_epi16(vacc3x012389AB4567CDEF, voutput_zero_point); + + const __m128i vout0x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x012389AB4567CDEF, 1)); + const __m128i vout1x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc1x012389AB4567CDEF), _mm256_extracti128_si256(vacc1x012389AB4567CDEF, 1)); + const __m128i vout2x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc2x012389AB4567CDEF), _mm256_extracti128_si256(vacc2x012389AB4567CDEF, 1)); + const __m128i vout3x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc3x012389AB4567CDEF), _mm256_extracti128_si256(vacc3x012389AB4567CDEF, 1)); + + __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi32(vout0x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi32(vout1x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi32(vout2x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout3x0123456789ABCDEF = _mm_shuffle_epi32(vout3x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); vout0x0123456789ABCDEF = _mm_max_epi8(vout0x0123456789ABCDEF, voutput_min); vout1x0123456789ABCDEF = _mm_max_epi8(vout1x0123456789ABCDEF, voutput_min); diff --git a/src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-4x16c8-minmax-fp32-avx512skx.c b/src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-4x16c8-minmax-fp32-avx512skx.c index df06b6b7f1c..eb8f9bfe190 100644 --- a/src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-4x16c8-minmax-fp32-avx512skx.c +++ b/src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-4x16c8-minmax-fp32-avx512skx.c @@ -181,25 +181,25 @@ void xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_4x16c8__avx512skx( vacc2x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled2x0123456789ABCDEF); vacc3x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled3x0123456789ABCDEF); - __m256i vacc0x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); - __m256i vacc1x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); - __m256i vacc2x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc2x0123456789ABCDEF, 1)); - __m256i vacc3x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc3x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc3x0123456789ABCDEF, 1)); - - vacc0x0123456789AB4567CDEF = _mm256_adds_epi16(vacc0x0123456789AB4567CDEF, voutput_zero_point); - vacc1x0123456789AB4567CDEF = _mm256_adds_epi16(vacc1x0123456789AB4567CDEF, voutput_zero_point); - vacc2x0123456789AB4567CDEF = _mm256_adds_epi16(vacc2x0123456789AB4567CDEF, voutput_zero_point); - vacc3x0123456789AB4567CDEF = _mm256_adds_epi16(vacc3x0123456789AB4567CDEF, voutput_zero_point); - - const __m128i vout0x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc0x0123456789AB4567CDEF, 1)); - const __m128i vout1x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc1x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc1x0123456789AB4567CDEF, 1)); - const __m128i vout2x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc2x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc2x0123456789AB4567CDEF, 1)); - const __m128i vout3x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc3x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc3x0123456789AB4567CDEF, 1)); - - __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi32(vout0x0123456789AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); - __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi32(vout1x0123456789AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); - __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi32(vout2x0123456789AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); - __m128i vout3x0123456789ABCDEF = _mm_shuffle_epi32(vout3x0123456789AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m256i vacc0x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); + __m256i vacc1x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); + __m256i vacc2x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc2x0123456789ABCDEF, 1)); + __m256i vacc3x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc3x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc3x0123456789ABCDEF, 1)); + + vacc0x012389AB4567CDEF = _mm256_adds_epi16(vacc0x012389AB4567CDEF, voutput_zero_point); + vacc1x012389AB4567CDEF = _mm256_adds_epi16(vacc1x012389AB4567CDEF, voutput_zero_point); + vacc2x012389AB4567CDEF = _mm256_adds_epi16(vacc2x012389AB4567CDEF, voutput_zero_point); + vacc3x012389AB4567CDEF = _mm256_adds_epi16(vacc3x012389AB4567CDEF, voutput_zero_point); + + const __m128i vout0x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x012389AB4567CDEF, 1)); + const __m128i vout1x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc1x012389AB4567CDEF), _mm256_extracti128_si256(vacc1x012389AB4567CDEF, 1)); + const __m128i vout2x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc2x012389AB4567CDEF), _mm256_extracti128_si256(vacc2x012389AB4567CDEF, 1)); + const __m128i vout3x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc3x012389AB4567CDEF), _mm256_extracti128_si256(vacc3x012389AB4567CDEF, 1)); + + __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi32(vout0x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi32(vout1x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi32(vout2x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout3x0123456789ABCDEF = _mm_shuffle_epi32(vout3x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); vout0x0123456789ABCDEF = _mm_max_epi8(vout0x0123456789ABCDEF, voutput_min); vout1x0123456789ABCDEF = _mm_max_epi8(vout1x0123456789ABCDEF, voutput_min); diff --git a/src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-5x16c4-minmax-avx512vnni-prfm.c b/src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-5x16c4-minmax-avx512vnni-prfm.c index bc7de9bc279..9864dbca657 100644 --- a/src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-5x16c4-minmax-avx512vnni-prfm.c +++ b/src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-5x16c4-minmax-avx512vnni-prfm.c @@ -64,7 +64,6 @@ void xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_5x16c4__avx512vnni_prfm( const __m512 voutput_max_less_zero_point = _mm512_load_ps(params->fp32_avx512vnni.output_max_less_zero_point); const __m256i voutput_zero_point = _mm256_load_si256((const __m256i*) params->fp32_avx512vnni.output_zero_point); const __m128i voutput_min = _mm_load_si128((const __m128i*) params->fp32_avx512vnni.output_min); - const __m128i vshuffle_control_mask = _mm_loadu_si128((const __m128i*) params->fp32_avx512vnni.shuffle_control_mask); do { __m512i vacc0x0123456789ABCDEF = _mm512_load_epi32(w); __m512i vacc1x0123456789ABCDEF = vacc0x0123456789ABCDEF; @@ -186,29 +185,29 @@ void xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_5x16c4__avx512vnni_prfm( vacc3x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled3x0123456789ABCDEF); vacc4x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled4x0123456789ABCDEF); - __m256i vacc0x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); - __m256i vacc1x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); - __m256i vacc2x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc2x0123456789ABCDEF, 1)); - __m256i vacc3x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc3x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc3x0123456789ABCDEF, 1)); - __m256i vacc4x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc4x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc4x0123456789ABCDEF, 1)); - - vacc0x0123456789AB4567CDEF = _mm256_adds_epi16(vacc0x0123456789AB4567CDEF, voutput_zero_point); - vacc1x0123456789AB4567CDEF = _mm256_adds_epi16(vacc1x0123456789AB4567CDEF, voutput_zero_point); - vacc2x0123456789AB4567CDEF = _mm256_adds_epi16(vacc2x0123456789AB4567CDEF, voutput_zero_point); - vacc3x0123456789AB4567CDEF = _mm256_adds_epi16(vacc3x0123456789AB4567CDEF, voutput_zero_point); - vacc4x0123456789AB4567CDEF = _mm256_adds_epi16(vacc4x0123456789AB4567CDEF, voutput_zero_point); - - const __m128i vout0x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc0x0123456789AB4567CDEF, 1)); - const __m128i vout1x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc1x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc1x0123456789AB4567CDEF, 1)); - const __m128i vout2x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc2x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc2x0123456789AB4567CDEF, 1)); - const __m128i vout3x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc3x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc3x0123456789AB4567CDEF, 1)); - const __m128i vout4x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc4x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc4x0123456789AB4567CDEF, 1)); - - __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi8(vout0x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi8(vout1x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi8(vout2x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout3x0123456789ABCDEF = _mm_shuffle_epi8(vout3x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout4x0123456789ABCDEF = _mm_shuffle_epi8(vout4x0123456789AB4567CDEF, vshuffle_control_mask); + __m256i vacc0x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); + __m256i vacc1x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); + __m256i vacc2x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc2x0123456789ABCDEF, 1)); + __m256i vacc3x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc3x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc3x0123456789ABCDEF, 1)); + __m256i vacc4x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc4x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc4x0123456789ABCDEF, 1)); + + vacc0x012389AB4567CDEF = _mm256_adds_epi16(vacc0x012389AB4567CDEF, voutput_zero_point); + vacc1x012389AB4567CDEF = _mm256_adds_epi16(vacc1x012389AB4567CDEF, voutput_zero_point); + vacc2x012389AB4567CDEF = _mm256_adds_epi16(vacc2x012389AB4567CDEF, voutput_zero_point); + vacc3x012389AB4567CDEF = _mm256_adds_epi16(vacc3x012389AB4567CDEF, voutput_zero_point); + vacc4x012389AB4567CDEF = _mm256_adds_epi16(vacc4x012389AB4567CDEF, voutput_zero_point); + + const __m128i vout0x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x012389AB4567CDEF, 1)); + const __m128i vout1x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc1x012389AB4567CDEF), _mm256_extracti128_si256(vacc1x012389AB4567CDEF, 1)); + const __m128i vout2x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc2x012389AB4567CDEF), _mm256_extracti128_si256(vacc2x012389AB4567CDEF, 1)); + const __m128i vout3x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc3x012389AB4567CDEF), _mm256_extracti128_si256(vacc3x012389AB4567CDEF, 1)); + const __m128i vout4x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc4x012389AB4567CDEF), _mm256_extracti128_si256(vacc4x012389AB4567CDEF, 1)); + + __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi32(vout0x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi32(vout1x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi32(vout2x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout3x0123456789ABCDEF = _mm_shuffle_epi32(vout3x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout4x0123456789ABCDEF = _mm_shuffle_epi32(vout4x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); vout0x0123456789ABCDEF = _mm_max_epi8(vout0x0123456789ABCDEF, voutput_min); vout1x0123456789ABCDEF = _mm_max_epi8(vout1x0123456789ABCDEF, voutput_min); diff --git a/src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-5x16c4-minmax-avx512vnni.c b/src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-5x16c4-minmax-avx512vnni.c index 460df3f8dc4..33cfe57ce99 100644 --- a/src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-5x16c4-minmax-avx512vnni.c +++ b/src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-5x16c4-minmax-avx512vnni.c @@ -63,7 +63,6 @@ void xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_5x16c4__avx512vnni( const __m512 voutput_max_less_zero_point = _mm512_load_ps(params->fp32_avx512vnni.output_max_less_zero_point); const __m256i voutput_zero_point = _mm256_load_si256((const __m256i*) params->fp32_avx512vnni.output_zero_point); const __m128i voutput_min = _mm_load_si128((const __m128i*) params->fp32_avx512vnni.output_min); - const __m128i vshuffle_control_mask = _mm_loadu_si128((const __m128i*) params->fp32_avx512vnni.shuffle_control_mask); do { __m512i vacc0x0123456789ABCDEF = _mm512_load_epi32(w); __m512i vacc1x0123456789ABCDEF = vacc0x0123456789ABCDEF; @@ -182,29 +181,29 @@ void xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_5x16c4__avx512vnni( vacc3x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled3x0123456789ABCDEF); vacc4x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled4x0123456789ABCDEF); - __m256i vacc0x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); - __m256i vacc1x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); - __m256i vacc2x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc2x0123456789ABCDEF, 1)); - __m256i vacc3x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc3x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc3x0123456789ABCDEF, 1)); - __m256i vacc4x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc4x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc4x0123456789ABCDEF, 1)); - - vacc0x0123456789AB4567CDEF = _mm256_adds_epi16(vacc0x0123456789AB4567CDEF, voutput_zero_point); - vacc1x0123456789AB4567CDEF = _mm256_adds_epi16(vacc1x0123456789AB4567CDEF, voutput_zero_point); - vacc2x0123456789AB4567CDEF = _mm256_adds_epi16(vacc2x0123456789AB4567CDEF, voutput_zero_point); - vacc3x0123456789AB4567CDEF = _mm256_adds_epi16(vacc3x0123456789AB4567CDEF, voutput_zero_point); - vacc4x0123456789AB4567CDEF = _mm256_adds_epi16(vacc4x0123456789AB4567CDEF, voutput_zero_point); - - const __m128i vout0x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc0x0123456789AB4567CDEF, 1)); - const __m128i vout1x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc1x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc1x0123456789AB4567CDEF, 1)); - const __m128i vout2x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc2x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc2x0123456789AB4567CDEF, 1)); - const __m128i vout3x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc3x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc3x0123456789AB4567CDEF, 1)); - const __m128i vout4x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc4x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc4x0123456789AB4567CDEF, 1)); - - __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi8(vout0x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi8(vout1x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi8(vout2x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout3x0123456789ABCDEF = _mm_shuffle_epi8(vout3x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout4x0123456789ABCDEF = _mm_shuffle_epi8(vout4x0123456789AB4567CDEF, vshuffle_control_mask); + __m256i vacc0x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); + __m256i vacc1x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); + __m256i vacc2x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc2x0123456789ABCDEF, 1)); + __m256i vacc3x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc3x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc3x0123456789ABCDEF, 1)); + __m256i vacc4x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc4x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc4x0123456789ABCDEF, 1)); + + vacc0x012389AB4567CDEF = _mm256_adds_epi16(vacc0x012389AB4567CDEF, voutput_zero_point); + vacc1x012389AB4567CDEF = _mm256_adds_epi16(vacc1x012389AB4567CDEF, voutput_zero_point); + vacc2x012389AB4567CDEF = _mm256_adds_epi16(vacc2x012389AB4567CDEF, voutput_zero_point); + vacc3x012389AB4567CDEF = _mm256_adds_epi16(vacc3x012389AB4567CDEF, voutput_zero_point); + vacc4x012389AB4567CDEF = _mm256_adds_epi16(vacc4x012389AB4567CDEF, voutput_zero_point); + + const __m128i vout0x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x012389AB4567CDEF, 1)); + const __m128i vout1x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc1x012389AB4567CDEF), _mm256_extracti128_si256(vacc1x012389AB4567CDEF, 1)); + const __m128i vout2x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc2x012389AB4567CDEF), _mm256_extracti128_si256(vacc2x012389AB4567CDEF, 1)); + const __m128i vout3x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc3x012389AB4567CDEF), _mm256_extracti128_si256(vacc3x012389AB4567CDEF, 1)); + const __m128i vout4x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc4x012389AB4567CDEF), _mm256_extracti128_si256(vacc4x012389AB4567CDEF, 1)); + + __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi32(vout0x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi32(vout1x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi32(vout2x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout3x0123456789ABCDEF = _mm_shuffle_epi32(vout3x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout4x0123456789ABCDEF = _mm_shuffle_epi32(vout4x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); vout0x0123456789ABCDEF = _mm_max_epi8(vout0x0123456789ABCDEF, voutput_min); vout1x0123456789ABCDEF = _mm_max_epi8(vout1x0123456789ABCDEF, voutput_min); diff --git a/src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-5x16c8-minmax-avx512vnni-prfm.c b/src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-5x16c8-minmax-avx512vnni-prfm.c index 5269107631b..5fff41fcf1d 100644 --- a/src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-5x16c8-minmax-avx512vnni-prfm.c +++ b/src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-5x16c8-minmax-avx512vnni-prfm.c @@ -64,7 +64,6 @@ void xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_5x16c8__avx512vnni_prfm( const __m512 voutput_max_less_zero_point = _mm512_load_ps(params->fp32_avx512vnni.output_max_less_zero_point); const __m256i voutput_zero_point = _mm256_load_si256((const __m256i*) params->fp32_avx512vnni.output_zero_point); const __m128i voutput_min = _mm_load_si128((const __m128i*) params->fp32_avx512vnni.output_min); - const __m128i vshuffle_control_mask = _mm_loadu_si128((const __m128i*) params->fp32_avx512vnni.shuffle_control_mask); do { __m512i vacc0x01234567 = _mm512_cvtepu32_epi64(_mm256_load_si256((const __m256i*) w)); __m512i vacc0x89ABCDEF = _mm512_cvtepu32_epi64(_mm256_load_si256((const __m256i*) ((const int32_t*) w + 8))); @@ -233,29 +232,29 @@ void xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_5x16c8__avx512vnni_prfm( vacc3x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled3x0123456789ABCDEF); vacc4x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled4x0123456789ABCDEF); - __m256i vacc0x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); - __m256i vacc1x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); - __m256i vacc2x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc2x0123456789ABCDEF, 1)); - __m256i vacc3x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc3x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc3x0123456789ABCDEF, 1)); - __m256i vacc4x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc4x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc4x0123456789ABCDEF, 1)); + __m256i vacc0x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); + __m256i vacc1x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); + __m256i vacc2x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc2x0123456789ABCDEF, 1)); + __m256i vacc3x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc3x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc3x0123456789ABCDEF, 1)); + __m256i vacc4x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc4x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc4x0123456789ABCDEF, 1)); - vacc0x0123456789AB4567CDEF = _mm256_adds_epi16(vacc0x0123456789AB4567CDEF, voutput_zero_point); - vacc1x0123456789AB4567CDEF = _mm256_adds_epi16(vacc1x0123456789AB4567CDEF, voutput_zero_point); - vacc2x0123456789AB4567CDEF = _mm256_adds_epi16(vacc2x0123456789AB4567CDEF, voutput_zero_point); - vacc3x0123456789AB4567CDEF = _mm256_adds_epi16(vacc3x0123456789AB4567CDEF, voutput_zero_point); - vacc4x0123456789AB4567CDEF = _mm256_adds_epi16(vacc4x0123456789AB4567CDEF, voutput_zero_point); + vacc0x012389AB4567CDEF = _mm256_adds_epi16(vacc0x012389AB4567CDEF, voutput_zero_point); + vacc1x012389AB4567CDEF = _mm256_adds_epi16(vacc1x012389AB4567CDEF, voutput_zero_point); + vacc2x012389AB4567CDEF = _mm256_adds_epi16(vacc2x012389AB4567CDEF, voutput_zero_point); + vacc3x012389AB4567CDEF = _mm256_adds_epi16(vacc3x012389AB4567CDEF, voutput_zero_point); + vacc4x012389AB4567CDEF = _mm256_adds_epi16(vacc4x012389AB4567CDEF, voutput_zero_point); - const __m128i vout0x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc0x0123456789AB4567CDEF, 1)); - const __m128i vout1x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc1x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc1x0123456789AB4567CDEF, 1)); - const __m128i vout2x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc2x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc2x0123456789AB4567CDEF, 1)); - const __m128i vout3x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc3x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc3x0123456789AB4567CDEF, 1)); - const __m128i vout4x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc4x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc4x0123456789AB4567CDEF, 1)); + const __m128i vout0x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x012389AB4567CDEF, 1)); + const __m128i vout1x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc1x012389AB4567CDEF), _mm256_extracti128_si256(vacc1x012389AB4567CDEF, 1)); + const __m128i vout2x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc2x012389AB4567CDEF), _mm256_extracti128_si256(vacc2x012389AB4567CDEF, 1)); + const __m128i vout3x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc3x012389AB4567CDEF), _mm256_extracti128_si256(vacc3x012389AB4567CDEF, 1)); + const __m128i vout4x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc4x012389AB4567CDEF), _mm256_extracti128_si256(vacc4x012389AB4567CDEF, 1)); - __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi8(vout0x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi8(vout1x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi8(vout2x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout3x0123456789ABCDEF = _mm_shuffle_epi8(vout3x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout4x0123456789ABCDEF = _mm_shuffle_epi8(vout4x0123456789AB4567CDEF, vshuffle_control_mask); + __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi32(vout0x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi32(vout1x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi32(vout2x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout3x0123456789ABCDEF = _mm_shuffle_epi32(vout3x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout4x0123456789ABCDEF = _mm_shuffle_epi32(vout4x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); vout0x0123456789ABCDEF = _mm_max_epi8(vout0x0123456789ABCDEF, voutput_min); vout1x0123456789ABCDEF = _mm_max_epi8(vout1x0123456789ABCDEF, voutput_min); diff --git a/src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-5x16c8-minmax-avx512vnni.c b/src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-5x16c8-minmax-avx512vnni.c index d8e78f0bad0..63c9787b743 100644 --- a/src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-5x16c8-minmax-avx512vnni.c +++ b/src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-5x16c8-minmax-avx512vnni.c @@ -63,7 +63,6 @@ void xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_5x16c8__avx512vnni( const __m512 voutput_max_less_zero_point = _mm512_load_ps(params->fp32_avx512vnni.output_max_less_zero_point); const __m256i voutput_zero_point = _mm256_load_si256((const __m256i*) params->fp32_avx512vnni.output_zero_point); const __m128i voutput_min = _mm_load_si128((const __m128i*) params->fp32_avx512vnni.output_min); - const __m128i vshuffle_control_mask = _mm_loadu_si128((const __m128i*) params->fp32_avx512vnni.shuffle_control_mask); do { __m512i vacc0x01234567 = _mm512_cvtepu32_epi64(_mm256_load_si256((const __m256i*) w)); __m512i vacc0x89ABCDEF = _mm512_cvtepu32_epi64(_mm256_load_si256((const __m256i*) ((const int32_t*) w + 8))); @@ -226,29 +225,29 @@ void xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_5x16c8__avx512vnni( vacc3x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled3x0123456789ABCDEF); vacc4x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled4x0123456789ABCDEF); - __m256i vacc0x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); - __m256i vacc1x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); - __m256i vacc2x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc2x0123456789ABCDEF, 1)); - __m256i vacc3x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc3x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc3x0123456789ABCDEF, 1)); - __m256i vacc4x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc4x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc4x0123456789ABCDEF, 1)); + __m256i vacc0x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); + __m256i vacc1x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); + __m256i vacc2x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc2x0123456789ABCDEF, 1)); + __m256i vacc3x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc3x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc3x0123456789ABCDEF, 1)); + __m256i vacc4x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc4x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc4x0123456789ABCDEF, 1)); - vacc0x0123456789AB4567CDEF = _mm256_adds_epi16(vacc0x0123456789AB4567CDEF, voutput_zero_point); - vacc1x0123456789AB4567CDEF = _mm256_adds_epi16(vacc1x0123456789AB4567CDEF, voutput_zero_point); - vacc2x0123456789AB4567CDEF = _mm256_adds_epi16(vacc2x0123456789AB4567CDEF, voutput_zero_point); - vacc3x0123456789AB4567CDEF = _mm256_adds_epi16(vacc3x0123456789AB4567CDEF, voutput_zero_point); - vacc4x0123456789AB4567CDEF = _mm256_adds_epi16(vacc4x0123456789AB4567CDEF, voutput_zero_point); + vacc0x012389AB4567CDEF = _mm256_adds_epi16(vacc0x012389AB4567CDEF, voutput_zero_point); + vacc1x012389AB4567CDEF = _mm256_adds_epi16(vacc1x012389AB4567CDEF, voutput_zero_point); + vacc2x012389AB4567CDEF = _mm256_adds_epi16(vacc2x012389AB4567CDEF, voutput_zero_point); + vacc3x012389AB4567CDEF = _mm256_adds_epi16(vacc3x012389AB4567CDEF, voutput_zero_point); + vacc4x012389AB4567CDEF = _mm256_adds_epi16(vacc4x012389AB4567CDEF, voutput_zero_point); - const __m128i vout0x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc0x0123456789AB4567CDEF, 1)); - const __m128i vout1x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc1x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc1x0123456789AB4567CDEF, 1)); - const __m128i vout2x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc2x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc2x0123456789AB4567CDEF, 1)); - const __m128i vout3x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc3x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc3x0123456789AB4567CDEF, 1)); - const __m128i vout4x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc4x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc4x0123456789AB4567CDEF, 1)); + const __m128i vout0x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x012389AB4567CDEF, 1)); + const __m128i vout1x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc1x012389AB4567CDEF), _mm256_extracti128_si256(vacc1x012389AB4567CDEF, 1)); + const __m128i vout2x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc2x012389AB4567CDEF), _mm256_extracti128_si256(vacc2x012389AB4567CDEF, 1)); + const __m128i vout3x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc3x012389AB4567CDEF), _mm256_extracti128_si256(vacc3x012389AB4567CDEF, 1)); + const __m128i vout4x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc4x012389AB4567CDEF), _mm256_extracti128_si256(vacc4x012389AB4567CDEF, 1)); - __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi8(vout0x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi8(vout1x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi8(vout2x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout3x0123456789ABCDEF = _mm_shuffle_epi8(vout3x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout4x0123456789ABCDEF = _mm_shuffle_epi8(vout4x0123456789AB4567CDEF, vshuffle_control_mask); + __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi32(vout0x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi32(vout1x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi32(vout2x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout3x0123456789ABCDEF = _mm_shuffle_epi32(vout3x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout4x0123456789ABCDEF = _mm_shuffle_epi32(vout4x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); vout0x0123456789ABCDEF = _mm_max_epi8(vout0x0123456789ABCDEF, voutput_min); vout1x0123456789ABCDEF = _mm_max_epi8(vout1x0123456789ABCDEF, voutput_min); diff --git a/src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-5x16c8-minmax-fp32-avx512skx-prfm.c b/src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-5x16c8-minmax-fp32-avx512skx-prfm.c index ba09c972119..ca8ac456c4c 100644 --- a/src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-5x16c8-minmax-fp32-avx512skx-prfm.c +++ b/src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-5x16c8-minmax-fp32-avx512skx-prfm.c @@ -210,29 +210,29 @@ void xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_5x16c8__avx512skx_prfm( vacc3x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled3x0123456789ABCDEF); vacc4x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled4x0123456789ABCDEF); - __m256i vacc0x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); - __m256i vacc1x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); - __m256i vacc2x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc2x0123456789ABCDEF, 1)); - __m256i vacc3x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc3x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc3x0123456789ABCDEF, 1)); - __m256i vacc4x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc4x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc4x0123456789ABCDEF, 1)); + __m256i vacc0x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); + __m256i vacc1x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); + __m256i vacc2x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc2x0123456789ABCDEF, 1)); + __m256i vacc3x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc3x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc3x0123456789ABCDEF, 1)); + __m256i vacc4x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc4x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc4x0123456789ABCDEF, 1)); - vacc0x0123456789AB4567CDEF = _mm256_adds_epi16(vacc0x0123456789AB4567CDEF, voutput_zero_point); - vacc1x0123456789AB4567CDEF = _mm256_adds_epi16(vacc1x0123456789AB4567CDEF, voutput_zero_point); - vacc2x0123456789AB4567CDEF = _mm256_adds_epi16(vacc2x0123456789AB4567CDEF, voutput_zero_point); - vacc3x0123456789AB4567CDEF = _mm256_adds_epi16(vacc3x0123456789AB4567CDEF, voutput_zero_point); - vacc4x0123456789AB4567CDEF = _mm256_adds_epi16(vacc4x0123456789AB4567CDEF, voutput_zero_point); + vacc0x012389AB4567CDEF = _mm256_adds_epi16(vacc0x012389AB4567CDEF, voutput_zero_point); + vacc1x012389AB4567CDEF = _mm256_adds_epi16(vacc1x012389AB4567CDEF, voutput_zero_point); + vacc2x012389AB4567CDEF = _mm256_adds_epi16(vacc2x012389AB4567CDEF, voutput_zero_point); + vacc3x012389AB4567CDEF = _mm256_adds_epi16(vacc3x012389AB4567CDEF, voutput_zero_point); + vacc4x012389AB4567CDEF = _mm256_adds_epi16(vacc4x012389AB4567CDEF, voutput_zero_point); - const __m128i vout0x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc0x0123456789AB4567CDEF, 1)); - const __m128i vout1x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc1x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc1x0123456789AB4567CDEF, 1)); - const __m128i vout2x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc2x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc2x0123456789AB4567CDEF, 1)); - const __m128i vout3x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc3x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc3x0123456789AB4567CDEF, 1)); - const __m128i vout4x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc4x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc4x0123456789AB4567CDEF, 1)); + const __m128i vout0x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x012389AB4567CDEF, 1)); + const __m128i vout1x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc1x012389AB4567CDEF), _mm256_extracti128_si256(vacc1x012389AB4567CDEF, 1)); + const __m128i vout2x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc2x012389AB4567CDEF), _mm256_extracti128_si256(vacc2x012389AB4567CDEF, 1)); + const __m128i vout3x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc3x012389AB4567CDEF), _mm256_extracti128_si256(vacc3x012389AB4567CDEF, 1)); + const __m128i vout4x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc4x012389AB4567CDEF), _mm256_extracti128_si256(vacc4x012389AB4567CDEF, 1)); - __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi32(vout0x0123456789AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); - __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi32(vout1x0123456789AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); - __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi32(vout2x0123456789AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); - __m128i vout3x0123456789ABCDEF = _mm_shuffle_epi32(vout3x0123456789AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); - __m128i vout4x0123456789ABCDEF = _mm_shuffle_epi32(vout4x0123456789AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi32(vout0x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi32(vout1x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi32(vout2x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout3x0123456789ABCDEF = _mm_shuffle_epi32(vout3x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout4x0123456789ABCDEF = _mm_shuffle_epi32(vout4x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); vout0x0123456789ABCDEF = _mm_max_epi8(vout0x0123456789ABCDEF, voutput_min); vout1x0123456789ABCDEF = _mm_max_epi8(vout1x0123456789ABCDEF, voutput_min); diff --git a/src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-5x16c8-minmax-fp32-avx512skx.c b/src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-5x16c8-minmax-fp32-avx512skx.c index 8a35b141b35..b83ec7517f4 100644 --- a/src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-5x16c8-minmax-fp32-avx512skx.c +++ b/src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-5x16c8-minmax-fp32-avx512skx.c @@ -207,29 +207,29 @@ void xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_5x16c8__avx512skx( vacc3x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled3x0123456789ABCDEF); vacc4x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled4x0123456789ABCDEF); - __m256i vacc0x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); - __m256i vacc1x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); - __m256i vacc2x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc2x0123456789ABCDEF, 1)); - __m256i vacc3x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc3x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc3x0123456789ABCDEF, 1)); - __m256i vacc4x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc4x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc4x0123456789ABCDEF, 1)); + __m256i vacc0x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); + __m256i vacc1x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); + __m256i vacc2x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc2x0123456789ABCDEF, 1)); + __m256i vacc3x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc3x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc3x0123456789ABCDEF, 1)); + __m256i vacc4x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc4x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc4x0123456789ABCDEF, 1)); - vacc0x0123456789AB4567CDEF = _mm256_adds_epi16(vacc0x0123456789AB4567CDEF, voutput_zero_point); - vacc1x0123456789AB4567CDEF = _mm256_adds_epi16(vacc1x0123456789AB4567CDEF, voutput_zero_point); - vacc2x0123456789AB4567CDEF = _mm256_adds_epi16(vacc2x0123456789AB4567CDEF, voutput_zero_point); - vacc3x0123456789AB4567CDEF = _mm256_adds_epi16(vacc3x0123456789AB4567CDEF, voutput_zero_point); - vacc4x0123456789AB4567CDEF = _mm256_adds_epi16(vacc4x0123456789AB4567CDEF, voutput_zero_point); + vacc0x012389AB4567CDEF = _mm256_adds_epi16(vacc0x012389AB4567CDEF, voutput_zero_point); + vacc1x012389AB4567CDEF = _mm256_adds_epi16(vacc1x012389AB4567CDEF, voutput_zero_point); + vacc2x012389AB4567CDEF = _mm256_adds_epi16(vacc2x012389AB4567CDEF, voutput_zero_point); + vacc3x012389AB4567CDEF = _mm256_adds_epi16(vacc3x012389AB4567CDEF, voutput_zero_point); + vacc4x012389AB4567CDEF = _mm256_adds_epi16(vacc4x012389AB4567CDEF, voutput_zero_point); - const __m128i vout0x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc0x0123456789AB4567CDEF, 1)); - const __m128i vout1x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc1x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc1x0123456789AB4567CDEF, 1)); - const __m128i vout2x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc2x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc2x0123456789AB4567CDEF, 1)); - const __m128i vout3x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc3x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc3x0123456789AB4567CDEF, 1)); - const __m128i vout4x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc4x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc4x0123456789AB4567CDEF, 1)); + const __m128i vout0x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x012389AB4567CDEF, 1)); + const __m128i vout1x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc1x012389AB4567CDEF), _mm256_extracti128_si256(vacc1x012389AB4567CDEF, 1)); + const __m128i vout2x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc2x012389AB4567CDEF), _mm256_extracti128_si256(vacc2x012389AB4567CDEF, 1)); + const __m128i vout3x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc3x012389AB4567CDEF), _mm256_extracti128_si256(vacc3x012389AB4567CDEF, 1)); + const __m128i vout4x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc4x012389AB4567CDEF), _mm256_extracti128_si256(vacc4x012389AB4567CDEF, 1)); - __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi32(vout0x0123456789AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); - __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi32(vout1x0123456789AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); - __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi32(vout2x0123456789AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); - __m128i vout3x0123456789ABCDEF = _mm_shuffle_epi32(vout3x0123456789AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); - __m128i vout4x0123456789ABCDEF = _mm_shuffle_epi32(vout4x0123456789AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi32(vout0x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi32(vout1x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi32(vout2x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout3x0123456789ABCDEF = _mm_shuffle_epi32(vout3x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout4x0123456789ABCDEF = _mm_shuffle_epi32(vout4x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); vout0x0123456789ABCDEF = _mm_max_epi8(vout0x0123456789ABCDEF, voutput_min); vout1x0123456789ABCDEF = _mm_max_epi8(vout1x0123456789ABCDEF, voutput_min); diff --git a/src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-6x16c4-minmax-avx512vnni-prfm.c b/src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-6x16c4-minmax-avx512vnni-prfm.c index 611a602c966..244044cbc62 100644 --- a/src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-6x16c4-minmax-avx512vnni-prfm.c +++ b/src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-6x16c4-minmax-avx512vnni-prfm.c @@ -68,7 +68,6 @@ void xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_6x16c4__avx512vnni_prfm( const __m512 voutput_max_less_zero_point = _mm512_load_ps(params->fp32_avx512vnni.output_max_less_zero_point); const __m256i voutput_zero_point = _mm256_load_si256((const __m256i*) params->fp32_avx512vnni.output_zero_point); const __m128i voutput_min = _mm_load_si128((const __m128i*) params->fp32_avx512vnni.output_min); - const __m128i vshuffle_control_mask = _mm_loadu_si128((const __m128i*) params->fp32_avx512vnni.shuffle_control_mask); do { __m512i vacc0x0123456789ABCDEF = _mm512_load_epi32(w); __m512i vacc1x0123456789ABCDEF = vacc0x0123456789ABCDEF; @@ -207,33 +206,33 @@ void xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_6x16c4__avx512vnni_prfm( vacc4x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled4x0123456789ABCDEF); vacc5x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled5x0123456789ABCDEF); - __m256i vacc0x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); - __m256i vacc1x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); - __m256i vacc2x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc2x0123456789ABCDEF, 1)); - __m256i vacc3x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc3x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc3x0123456789ABCDEF, 1)); - __m256i vacc4x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc4x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc4x0123456789ABCDEF, 1)); - __m256i vacc5x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc5x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc5x0123456789ABCDEF, 1)); + __m256i vacc0x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); + __m256i vacc1x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); + __m256i vacc2x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc2x0123456789ABCDEF, 1)); + __m256i vacc3x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc3x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc3x0123456789ABCDEF, 1)); + __m256i vacc4x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc4x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc4x0123456789ABCDEF, 1)); + __m256i vacc5x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc5x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc5x0123456789ABCDEF, 1)); - vacc0x0123456789AB4567CDEF = _mm256_adds_epi16(vacc0x0123456789AB4567CDEF, voutput_zero_point); - vacc1x0123456789AB4567CDEF = _mm256_adds_epi16(vacc1x0123456789AB4567CDEF, voutput_zero_point); - vacc2x0123456789AB4567CDEF = _mm256_adds_epi16(vacc2x0123456789AB4567CDEF, voutput_zero_point); - vacc3x0123456789AB4567CDEF = _mm256_adds_epi16(vacc3x0123456789AB4567CDEF, voutput_zero_point); - vacc4x0123456789AB4567CDEF = _mm256_adds_epi16(vacc4x0123456789AB4567CDEF, voutput_zero_point); - vacc5x0123456789AB4567CDEF = _mm256_adds_epi16(vacc5x0123456789AB4567CDEF, voutput_zero_point); + vacc0x012389AB4567CDEF = _mm256_adds_epi16(vacc0x012389AB4567CDEF, voutput_zero_point); + vacc1x012389AB4567CDEF = _mm256_adds_epi16(vacc1x012389AB4567CDEF, voutput_zero_point); + vacc2x012389AB4567CDEF = _mm256_adds_epi16(vacc2x012389AB4567CDEF, voutput_zero_point); + vacc3x012389AB4567CDEF = _mm256_adds_epi16(vacc3x012389AB4567CDEF, voutput_zero_point); + vacc4x012389AB4567CDEF = _mm256_adds_epi16(vacc4x012389AB4567CDEF, voutput_zero_point); + vacc5x012389AB4567CDEF = _mm256_adds_epi16(vacc5x012389AB4567CDEF, voutput_zero_point); - const __m128i vout0x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc0x0123456789AB4567CDEF, 1)); - const __m128i vout1x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc1x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc1x0123456789AB4567CDEF, 1)); - const __m128i vout2x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc2x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc2x0123456789AB4567CDEF, 1)); - const __m128i vout3x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc3x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc3x0123456789AB4567CDEF, 1)); - const __m128i vout4x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc4x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc4x0123456789AB4567CDEF, 1)); - const __m128i vout5x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc5x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc5x0123456789AB4567CDEF, 1)); + const __m128i vout0x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x012389AB4567CDEF, 1)); + const __m128i vout1x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc1x012389AB4567CDEF), _mm256_extracti128_si256(vacc1x012389AB4567CDEF, 1)); + const __m128i vout2x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc2x012389AB4567CDEF), _mm256_extracti128_si256(vacc2x012389AB4567CDEF, 1)); + const __m128i vout3x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc3x012389AB4567CDEF), _mm256_extracti128_si256(vacc3x012389AB4567CDEF, 1)); + const __m128i vout4x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc4x012389AB4567CDEF), _mm256_extracti128_si256(vacc4x012389AB4567CDEF, 1)); + const __m128i vout5x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc5x012389AB4567CDEF), _mm256_extracti128_si256(vacc5x012389AB4567CDEF, 1)); - __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi8(vout0x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi8(vout1x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi8(vout2x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout3x0123456789ABCDEF = _mm_shuffle_epi8(vout3x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout4x0123456789ABCDEF = _mm_shuffle_epi8(vout4x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout5x0123456789ABCDEF = _mm_shuffle_epi8(vout5x0123456789AB4567CDEF, vshuffle_control_mask); + __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi32(vout0x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi32(vout1x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi32(vout2x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout3x0123456789ABCDEF = _mm_shuffle_epi32(vout3x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout4x0123456789ABCDEF = _mm_shuffle_epi32(vout4x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout5x0123456789ABCDEF = _mm_shuffle_epi32(vout5x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); vout0x0123456789ABCDEF = _mm_max_epi8(vout0x0123456789ABCDEF, voutput_min); vout1x0123456789ABCDEF = _mm_max_epi8(vout1x0123456789ABCDEF, voutput_min); diff --git a/src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-6x16c4-minmax-avx512vnni.c b/src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-6x16c4-minmax-avx512vnni.c index e3ae284eec0..bad01c0bd2a 100644 --- a/src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-6x16c4-minmax-avx512vnni.c +++ b/src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-6x16c4-minmax-avx512vnni.c @@ -67,7 +67,6 @@ void xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_6x16c4__avx512vnni( const __m512 voutput_max_less_zero_point = _mm512_load_ps(params->fp32_avx512vnni.output_max_less_zero_point); const __m256i voutput_zero_point = _mm256_load_si256((const __m256i*) params->fp32_avx512vnni.output_zero_point); const __m128i voutput_min = _mm_load_si128((const __m128i*) params->fp32_avx512vnni.output_min); - const __m128i vshuffle_control_mask = _mm_loadu_si128((const __m128i*) params->fp32_avx512vnni.shuffle_control_mask); do { __m512i vacc0x0123456789ABCDEF = _mm512_load_epi32(w); __m512i vacc1x0123456789ABCDEF = vacc0x0123456789ABCDEF; @@ -203,33 +202,33 @@ void xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_6x16c4__avx512vnni( vacc4x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled4x0123456789ABCDEF); vacc5x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled5x0123456789ABCDEF); - __m256i vacc0x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); - __m256i vacc1x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); - __m256i vacc2x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc2x0123456789ABCDEF, 1)); - __m256i vacc3x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc3x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc3x0123456789ABCDEF, 1)); - __m256i vacc4x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc4x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc4x0123456789ABCDEF, 1)); - __m256i vacc5x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc5x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc5x0123456789ABCDEF, 1)); + __m256i vacc0x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); + __m256i vacc1x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); + __m256i vacc2x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc2x0123456789ABCDEF, 1)); + __m256i vacc3x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc3x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc3x0123456789ABCDEF, 1)); + __m256i vacc4x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc4x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc4x0123456789ABCDEF, 1)); + __m256i vacc5x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc5x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc5x0123456789ABCDEF, 1)); - vacc0x0123456789AB4567CDEF = _mm256_adds_epi16(vacc0x0123456789AB4567CDEF, voutput_zero_point); - vacc1x0123456789AB4567CDEF = _mm256_adds_epi16(vacc1x0123456789AB4567CDEF, voutput_zero_point); - vacc2x0123456789AB4567CDEF = _mm256_adds_epi16(vacc2x0123456789AB4567CDEF, voutput_zero_point); - vacc3x0123456789AB4567CDEF = _mm256_adds_epi16(vacc3x0123456789AB4567CDEF, voutput_zero_point); - vacc4x0123456789AB4567CDEF = _mm256_adds_epi16(vacc4x0123456789AB4567CDEF, voutput_zero_point); - vacc5x0123456789AB4567CDEF = _mm256_adds_epi16(vacc5x0123456789AB4567CDEF, voutput_zero_point); + vacc0x012389AB4567CDEF = _mm256_adds_epi16(vacc0x012389AB4567CDEF, voutput_zero_point); + vacc1x012389AB4567CDEF = _mm256_adds_epi16(vacc1x012389AB4567CDEF, voutput_zero_point); + vacc2x012389AB4567CDEF = _mm256_adds_epi16(vacc2x012389AB4567CDEF, voutput_zero_point); + vacc3x012389AB4567CDEF = _mm256_adds_epi16(vacc3x012389AB4567CDEF, voutput_zero_point); + vacc4x012389AB4567CDEF = _mm256_adds_epi16(vacc4x012389AB4567CDEF, voutput_zero_point); + vacc5x012389AB4567CDEF = _mm256_adds_epi16(vacc5x012389AB4567CDEF, voutput_zero_point); - const __m128i vout0x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc0x0123456789AB4567CDEF, 1)); - const __m128i vout1x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc1x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc1x0123456789AB4567CDEF, 1)); - const __m128i vout2x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc2x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc2x0123456789AB4567CDEF, 1)); - const __m128i vout3x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc3x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc3x0123456789AB4567CDEF, 1)); - const __m128i vout4x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc4x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc4x0123456789AB4567CDEF, 1)); - const __m128i vout5x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc5x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc5x0123456789AB4567CDEF, 1)); + const __m128i vout0x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x012389AB4567CDEF, 1)); + const __m128i vout1x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc1x012389AB4567CDEF), _mm256_extracti128_si256(vacc1x012389AB4567CDEF, 1)); + const __m128i vout2x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc2x012389AB4567CDEF), _mm256_extracti128_si256(vacc2x012389AB4567CDEF, 1)); + const __m128i vout3x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc3x012389AB4567CDEF), _mm256_extracti128_si256(vacc3x012389AB4567CDEF, 1)); + const __m128i vout4x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc4x012389AB4567CDEF), _mm256_extracti128_si256(vacc4x012389AB4567CDEF, 1)); + const __m128i vout5x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc5x012389AB4567CDEF), _mm256_extracti128_si256(vacc5x012389AB4567CDEF, 1)); - __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi8(vout0x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi8(vout1x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi8(vout2x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout3x0123456789ABCDEF = _mm_shuffle_epi8(vout3x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout4x0123456789ABCDEF = _mm_shuffle_epi8(vout4x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout5x0123456789ABCDEF = _mm_shuffle_epi8(vout5x0123456789AB4567CDEF, vshuffle_control_mask); + __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi32(vout0x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi32(vout1x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi32(vout2x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout3x0123456789ABCDEF = _mm_shuffle_epi32(vout3x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout4x0123456789ABCDEF = _mm_shuffle_epi32(vout4x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout5x0123456789ABCDEF = _mm_shuffle_epi32(vout5x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); vout0x0123456789ABCDEF = _mm_max_epi8(vout0x0123456789ABCDEF, voutput_min); vout1x0123456789ABCDEF = _mm_max_epi8(vout1x0123456789ABCDEF, voutput_min); diff --git a/src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-6x16c8-minmax-avx512vnni-prfm.c b/src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-6x16c8-minmax-avx512vnni-prfm.c index 271b3f150da..dafe4cd961a 100644 --- a/src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-6x16c8-minmax-avx512vnni-prfm.c +++ b/src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-6x16c8-minmax-avx512vnni-prfm.c @@ -68,7 +68,6 @@ void xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_6x16c8__avx512vnni_prfm( const __m512 voutput_max_less_zero_point = _mm512_load_ps(params->fp32_avx512vnni.output_max_less_zero_point); const __m256i voutput_zero_point = _mm256_load_si256((const __m256i*) params->fp32_avx512vnni.output_zero_point); const __m128i voutput_min = _mm_load_si128((const __m128i*) params->fp32_avx512vnni.output_min); - const __m128i vshuffle_control_mask = _mm_loadu_si128((const __m128i*) params->fp32_avx512vnni.shuffle_control_mask); do { __m512i vacc0x01234567 = _mm512_cvtepu32_epi64(_mm256_load_si256((const __m256i*) w)); __m512i vacc0x89ABCDEF = _mm512_cvtepu32_epi64(_mm256_load_si256((const __m256i*) ((const int32_t*) w + 8))); @@ -261,33 +260,33 @@ void xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_6x16c8__avx512vnni_prfm( vacc4x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled4x0123456789ABCDEF); vacc5x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled5x0123456789ABCDEF); - __m256i vacc0x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); - __m256i vacc1x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); - __m256i vacc2x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc2x0123456789ABCDEF, 1)); - __m256i vacc3x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc3x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc3x0123456789ABCDEF, 1)); - __m256i vacc4x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc4x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc4x0123456789ABCDEF, 1)); - __m256i vacc5x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc5x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc5x0123456789ABCDEF, 1)); + __m256i vacc0x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); + __m256i vacc1x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); + __m256i vacc2x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc2x0123456789ABCDEF, 1)); + __m256i vacc3x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc3x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc3x0123456789ABCDEF, 1)); + __m256i vacc4x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc4x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc4x0123456789ABCDEF, 1)); + __m256i vacc5x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc5x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc5x0123456789ABCDEF, 1)); - vacc0x0123456789AB4567CDEF = _mm256_adds_epi16(vacc0x0123456789AB4567CDEF, voutput_zero_point); - vacc1x0123456789AB4567CDEF = _mm256_adds_epi16(vacc1x0123456789AB4567CDEF, voutput_zero_point); - vacc2x0123456789AB4567CDEF = _mm256_adds_epi16(vacc2x0123456789AB4567CDEF, voutput_zero_point); - vacc3x0123456789AB4567CDEF = _mm256_adds_epi16(vacc3x0123456789AB4567CDEF, voutput_zero_point); - vacc4x0123456789AB4567CDEF = _mm256_adds_epi16(vacc4x0123456789AB4567CDEF, voutput_zero_point); - vacc5x0123456789AB4567CDEF = _mm256_adds_epi16(vacc5x0123456789AB4567CDEF, voutput_zero_point); + vacc0x012389AB4567CDEF = _mm256_adds_epi16(vacc0x012389AB4567CDEF, voutput_zero_point); + vacc1x012389AB4567CDEF = _mm256_adds_epi16(vacc1x012389AB4567CDEF, voutput_zero_point); + vacc2x012389AB4567CDEF = _mm256_adds_epi16(vacc2x012389AB4567CDEF, voutput_zero_point); + vacc3x012389AB4567CDEF = _mm256_adds_epi16(vacc3x012389AB4567CDEF, voutput_zero_point); + vacc4x012389AB4567CDEF = _mm256_adds_epi16(vacc4x012389AB4567CDEF, voutput_zero_point); + vacc5x012389AB4567CDEF = _mm256_adds_epi16(vacc5x012389AB4567CDEF, voutput_zero_point); - const __m128i vout0x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc0x0123456789AB4567CDEF, 1)); - const __m128i vout1x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc1x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc1x0123456789AB4567CDEF, 1)); - const __m128i vout2x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc2x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc2x0123456789AB4567CDEF, 1)); - const __m128i vout3x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc3x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc3x0123456789AB4567CDEF, 1)); - const __m128i vout4x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc4x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc4x0123456789AB4567CDEF, 1)); - const __m128i vout5x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc5x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc5x0123456789AB4567CDEF, 1)); + const __m128i vout0x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x012389AB4567CDEF, 1)); + const __m128i vout1x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc1x012389AB4567CDEF), _mm256_extracti128_si256(vacc1x012389AB4567CDEF, 1)); + const __m128i vout2x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc2x012389AB4567CDEF), _mm256_extracti128_si256(vacc2x012389AB4567CDEF, 1)); + const __m128i vout3x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc3x012389AB4567CDEF), _mm256_extracti128_si256(vacc3x012389AB4567CDEF, 1)); + const __m128i vout4x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc4x012389AB4567CDEF), _mm256_extracti128_si256(vacc4x012389AB4567CDEF, 1)); + const __m128i vout5x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc5x012389AB4567CDEF), _mm256_extracti128_si256(vacc5x012389AB4567CDEF, 1)); - __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi8(vout0x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi8(vout1x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi8(vout2x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout3x0123456789ABCDEF = _mm_shuffle_epi8(vout3x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout4x0123456789ABCDEF = _mm_shuffle_epi8(vout4x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout5x0123456789ABCDEF = _mm_shuffle_epi8(vout5x0123456789AB4567CDEF, vshuffle_control_mask); + __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi32(vout0x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi32(vout1x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi32(vout2x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout3x0123456789ABCDEF = _mm_shuffle_epi32(vout3x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout4x0123456789ABCDEF = _mm_shuffle_epi32(vout4x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout5x0123456789ABCDEF = _mm_shuffle_epi32(vout5x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); vout0x0123456789ABCDEF = _mm_max_epi8(vout0x0123456789ABCDEF, voutput_min); vout1x0123456789ABCDEF = _mm_max_epi8(vout1x0123456789ABCDEF, voutput_min); diff --git a/src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-6x16c8-minmax-avx512vnni.c b/src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-6x16c8-minmax-avx512vnni.c index 84edb552ff8..611828ac8c0 100644 --- a/src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-6x16c8-minmax-avx512vnni.c +++ b/src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-6x16c8-minmax-avx512vnni.c @@ -67,7 +67,6 @@ void xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_6x16c8__avx512vnni( const __m512 voutput_max_less_zero_point = _mm512_load_ps(params->fp32_avx512vnni.output_max_less_zero_point); const __m256i voutput_zero_point = _mm256_load_si256((const __m256i*) params->fp32_avx512vnni.output_zero_point); const __m128i voutput_min = _mm_load_si128((const __m128i*) params->fp32_avx512vnni.output_min); - const __m128i vshuffle_control_mask = _mm_loadu_si128((const __m128i*) params->fp32_avx512vnni.shuffle_control_mask); do { __m512i vacc0x01234567 = _mm512_cvtepu32_epi64(_mm256_load_si256((const __m256i*) w)); __m512i vacc0x89ABCDEF = _mm512_cvtepu32_epi64(_mm256_load_si256((const __m256i*) ((const int32_t*) w + 8))); @@ -254,33 +253,33 @@ void xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_6x16c8__avx512vnni( vacc4x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled4x0123456789ABCDEF); vacc5x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled5x0123456789ABCDEF); - __m256i vacc0x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); - __m256i vacc1x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); - __m256i vacc2x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc2x0123456789ABCDEF, 1)); - __m256i vacc3x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc3x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc3x0123456789ABCDEF, 1)); - __m256i vacc4x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc4x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc4x0123456789ABCDEF, 1)); - __m256i vacc5x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc5x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc5x0123456789ABCDEF, 1)); + __m256i vacc0x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); + __m256i vacc1x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); + __m256i vacc2x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc2x0123456789ABCDEF, 1)); + __m256i vacc3x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc3x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc3x0123456789ABCDEF, 1)); + __m256i vacc4x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc4x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc4x0123456789ABCDEF, 1)); + __m256i vacc5x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc5x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc5x0123456789ABCDEF, 1)); - vacc0x0123456789AB4567CDEF = _mm256_adds_epi16(vacc0x0123456789AB4567CDEF, voutput_zero_point); - vacc1x0123456789AB4567CDEF = _mm256_adds_epi16(vacc1x0123456789AB4567CDEF, voutput_zero_point); - vacc2x0123456789AB4567CDEF = _mm256_adds_epi16(vacc2x0123456789AB4567CDEF, voutput_zero_point); - vacc3x0123456789AB4567CDEF = _mm256_adds_epi16(vacc3x0123456789AB4567CDEF, voutput_zero_point); - vacc4x0123456789AB4567CDEF = _mm256_adds_epi16(vacc4x0123456789AB4567CDEF, voutput_zero_point); - vacc5x0123456789AB4567CDEF = _mm256_adds_epi16(vacc5x0123456789AB4567CDEF, voutput_zero_point); + vacc0x012389AB4567CDEF = _mm256_adds_epi16(vacc0x012389AB4567CDEF, voutput_zero_point); + vacc1x012389AB4567CDEF = _mm256_adds_epi16(vacc1x012389AB4567CDEF, voutput_zero_point); + vacc2x012389AB4567CDEF = _mm256_adds_epi16(vacc2x012389AB4567CDEF, voutput_zero_point); + vacc3x012389AB4567CDEF = _mm256_adds_epi16(vacc3x012389AB4567CDEF, voutput_zero_point); + vacc4x012389AB4567CDEF = _mm256_adds_epi16(vacc4x012389AB4567CDEF, voutput_zero_point); + vacc5x012389AB4567CDEF = _mm256_adds_epi16(vacc5x012389AB4567CDEF, voutput_zero_point); - const __m128i vout0x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc0x0123456789AB4567CDEF, 1)); - const __m128i vout1x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc1x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc1x0123456789AB4567CDEF, 1)); - const __m128i vout2x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc2x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc2x0123456789AB4567CDEF, 1)); - const __m128i vout3x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc3x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc3x0123456789AB4567CDEF, 1)); - const __m128i vout4x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc4x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc4x0123456789AB4567CDEF, 1)); - const __m128i vout5x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc5x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc5x0123456789AB4567CDEF, 1)); + const __m128i vout0x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x012389AB4567CDEF, 1)); + const __m128i vout1x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc1x012389AB4567CDEF), _mm256_extracti128_si256(vacc1x012389AB4567CDEF, 1)); + const __m128i vout2x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc2x012389AB4567CDEF), _mm256_extracti128_si256(vacc2x012389AB4567CDEF, 1)); + const __m128i vout3x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc3x012389AB4567CDEF), _mm256_extracti128_si256(vacc3x012389AB4567CDEF, 1)); + const __m128i vout4x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc4x012389AB4567CDEF), _mm256_extracti128_si256(vacc4x012389AB4567CDEF, 1)); + const __m128i vout5x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc5x012389AB4567CDEF), _mm256_extracti128_si256(vacc5x012389AB4567CDEF, 1)); - __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi8(vout0x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi8(vout1x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi8(vout2x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout3x0123456789ABCDEF = _mm_shuffle_epi8(vout3x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout4x0123456789ABCDEF = _mm_shuffle_epi8(vout4x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout5x0123456789ABCDEF = _mm_shuffle_epi8(vout5x0123456789AB4567CDEF, vshuffle_control_mask); + __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi32(vout0x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi32(vout1x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi32(vout2x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout3x0123456789ABCDEF = _mm_shuffle_epi32(vout3x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout4x0123456789ABCDEF = _mm_shuffle_epi32(vout4x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout5x0123456789ABCDEF = _mm_shuffle_epi32(vout5x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); vout0x0123456789ABCDEF = _mm_max_epi8(vout0x0123456789ABCDEF, voutput_min); vout1x0123456789ABCDEF = _mm_max_epi8(vout1x0123456789ABCDEF, voutput_min); diff --git a/src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-6x16c8-minmax-fp32-avx512skx-prfm.c b/src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-6x16c8-minmax-fp32-avx512skx-prfm.c index b1c47ab476a..75c2cd2f7b9 100644 --- a/src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-6x16c8-minmax-fp32-avx512skx-prfm.c +++ b/src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-6x16c8-minmax-fp32-avx512skx-prfm.c @@ -236,33 +236,33 @@ void xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_6x16c8__avx512skx_prfm( vacc4x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled4x0123456789ABCDEF); vacc5x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled5x0123456789ABCDEF); - __m256i vacc0x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); - __m256i vacc1x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); - __m256i vacc2x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc2x0123456789ABCDEF, 1)); - __m256i vacc3x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc3x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc3x0123456789ABCDEF, 1)); - __m256i vacc4x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc4x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc4x0123456789ABCDEF, 1)); - __m256i vacc5x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc5x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc5x0123456789ABCDEF, 1)); + __m256i vacc0x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); + __m256i vacc1x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); + __m256i vacc2x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc2x0123456789ABCDEF, 1)); + __m256i vacc3x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc3x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc3x0123456789ABCDEF, 1)); + __m256i vacc4x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc4x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc4x0123456789ABCDEF, 1)); + __m256i vacc5x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc5x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc5x0123456789ABCDEF, 1)); - vacc0x0123456789AB4567CDEF = _mm256_adds_epi16(vacc0x0123456789AB4567CDEF, voutput_zero_point); - vacc1x0123456789AB4567CDEF = _mm256_adds_epi16(vacc1x0123456789AB4567CDEF, voutput_zero_point); - vacc2x0123456789AB4567CDEF = _mm256_adds_epi16(vacc2x0123456789AB4567CDEF, voutput_zero_point); - vacc3x0123456789AB4567CDEF = _mm256_adds_epi16(vacc3x0123456789AB4567CDEF, voutput_zero_point); - vacc4x0123456789AB4567CDEF = _mm256_adds_epi16(vacc4x0123456789AB4567CDEF, voutput_zero_point); - vacc5x0123456789AB4567CDEF = _mm256_adds_epi16(vacc5x0123456789AB4567CDEF, voutput_zero_point); + vacc0x012389AB4567CDEF = _mm256_adds_epi16(vacc0x012389AB4567CDEF, voutput_zero_point); + vacc1x012389AB4567CDEF = _mm256_adds_epi16(vacc1x012389AB4567CDEF, voutput_zero_point); + vacc2x012389AB4567CDEF = _mm256_adds_epi16(vacc2x012389AB4567CDEF, voutput_zero_point); + vacc3x012389AB4567CDEF = _mm256_adds_epi16(vacc3x012389AB4567CDEF, voutput_zero_point); + vacc4x012389AB4567CDEF = _mm256_adds_epi16(vacc4x012389AB4567CDEF, voutput_zero_point); + vacc5x012389AB4567CDEF = _mm256_adds_epi16(vacc5x012389AB4567CDEF, voutput_zero_point); - const __m128i vout0x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc0x0123456789AB4567CDEF, 1)); - const __m128i vout1x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc1x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc1x0123456789AB4567CDEF, 1)); - const __m128i vout2x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc2x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc2x0123456789AB4567CDEF, 1)); - const __m128i vout3x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc3x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc3x0123456789AB4567CDEF, 1)); - const __m128i vout4x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc4x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc4x0123456789AB4567CDEF, 1)); - const __m128i vout5x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc5x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc5x0123456789AB4567CDEF, 1)); + const __m128i vout0x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x012389AB4567CDEF, 1)); + const __m128i vout1x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc1x012389AB4567CDEF), _mm256_extracti128_si256(vacc1x012389AB4567CDEF, 1)); + const __m128i vout2x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc2x012389AB4567CDEF), _mm256_extracti128_si256(vacc2x012389AB4567CDEF, 1)); + const __m128i vout3x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc3x012389AB4567CDEF), _mm256_extracti128_si256(vacc3x012389AB4567CDEF, 1)); + const __m128i vout4x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc4x012389AB4567CDEF), _mm256_extracti128_si256(vacc4x012389AB4567CDEF, 1)); + const __m128i vout5x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc5x012389AB4567CDEF), _mm256_extracti128_si256(vacc5x012389AB4567CDEF, 1)); - __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi32(vout0x0123456789AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); - __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi32(vout1x0123456789AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); - __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi32(vout2x0123456789AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); - __m128i vout3x0123456789ABCDEF = _mm_shuffle_epi32(vout3x0123456789AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); - __m128i vout4x0123456789ABCDEF = _mm_shuffle_epi32(vout4x0123456789AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); - __m128i vout5x0123456789ABCDEF = _mm_shuffle_epi32(vout5x0123456789AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi32(vout0x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi32(vout1x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi32(vout2x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout3x0123456789ABCDEF = _mm_shuffle_epi32(vout3x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout4x0123456789ABCDEF = _mm_shuffle_epi32(vout4x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout5x0123456789ABCDEF = _mm_shuffle_epi32(vout5x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); vout0x0123456789ABCDEF = _mm_max_epi8(vout0x0123456789ABCDEF, voutput_min); vout1x0123456789ABCDEF = _mm_max_epi8(vout1x0123456789ABCDEF, voutput_min); diff --git a/src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-6x16c8-minmax-fp32-avx512skx.c b/src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-6x16c8-minmax-fp32-avx512skx.c index 31fe1a861d7..f3df215a0bd 100644 --- a/src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-6x16c8-minmax-fp32-avx512skx.c +++ b/src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-6x16c8-minmax-fp32-avx512skx.c @@ -233,33 +233,33 @@ void xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_6x16c8__avx512skx( vacc4x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled4x0123456789ABCDEF); vacc5x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled5x0123456789ABCDEF); - __m256i vacc0x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); - __m256i vacc1x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); - __m256i vacc2x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc2x0123456789ABCDEF, 1)); - __m256i vacc3x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc3x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc3x0123456789ABCDEF, 1)); - __m256i vacc4x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc4x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc4x0123456789ABCDEF, 1)); - __m256i vacc5x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc5x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc5x0123456789ABCDEF, 1)); + __m256i vacc0x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); + __m256i vacc1x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); + __m256i vacc2x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc2x0123456789ABCDEF, 1)); + __m256i vacc3x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc3x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc3x0123456789ABCDEF, 1)); + __m256i vacc4x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc4x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc4x0123456789ABCDEF, 1)); + __m256i vacc5x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc5x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc5x0123456789ABCDEF, 1)); - vacc0x0123456789AB4567CDEF = _mm256_adds_epi16(vacc0x0123456789AB4567CDEF, voutput_zero_point); - vacc1x0123456789AB4567CDEF = _mm256_adds_epi16(vacc1x0123456789AB4567CDEF, voutput_zero_point); - vacc2x0123456789AB4567CDEF = _mm256_adds_epi16(vacc2x0123456789AB4567CDEF, voutput_zero_point); - vacc3x0123456789AB4567CDEF = _mm256_adds_epi16(vacc3x0123456789AB4567CDEF, voutput_zero_point); - vacc4x0123456789AB4567CDEF = _mm256_adds_epi16(vacc4x0123456789AB4567CDEF, voutput_zero_point); - vacc5x0123456789AB4567CDEF = _mm256_adds_epi16(vacc5x0123456789AB4567CDEF, voutput_zero_point); + vacc0x012389AB4567CDEF = _mm256_adds_epi16(vacc0x012389AB4567CDEF, voutput_zero_point); + vacc1x012389AB4567CDEF = _mm256_adds_epi16(vacc1x012389AB4567CDEF, voutput_zero_point); + vacc2x012389AB4567CDEF = _mm256_adds_epi16(vacc2x012389AB4567CDEF, voutput_zero_point); + vacc3x012389AB4567CDEF = _mm256_adds_epi16(vacc3x012389AB4567CDEF, voutput_zero_point); + vacc4x012389AB4567CDEF = _mm256_adds_epi16(vacc4x012389AB4567CDEF, voutput_zero_point); + vacc5x012389AB4567CDEF = _mm256_adds_epi16(vacc5x012389AB4567CDEF, voutput_zero_point); - const __m128i vout0x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc0x0123456789AB4567CDEF, 1)); - const __m128i vout1x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc1x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc1x0123456789AB4567CDEF, 1)); - const __m128i vout2x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc2x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc2x0123456789AB4567CDEF, 1)); - const __m128i vout3x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc3x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc3x0123456789AB4567CDEF, 1)); - const __m128i vout4x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc4x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc4x0123456789AB4567CDEF, 1)); - const __m128i vout5x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc5x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc5x0123456789AB4567CDEF, 1)); + const __m128i vout0x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x012389AB4567CDEF, 1)); + const __m128i vout1x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc1x012389AB4567CDEF), _mm256_extracti128_si256(vacc1x012389AB4567CDEF, 1)); + const __m128i vout2x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc2x012389AB4567CDEF), _mm256_extracti128_si256(vacc2x012389AB4567CDEF, 1)); + const __m128i vout3x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc3x012389AB4567CDEF), _mm256_extracti128_si256(vacc3x012389AB4567CDEF, 1)); + const __m128i vout4x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc4x012389AB4567CDEF), _mm256_extracti128_si256(vacc4x012389AB4567CDEF, 1)); + const __m128i vout5x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc5x012389AB4567CDEF), _mm256_extracti128_si256(vacc5x012389AB4567CDEF, 1)); - __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi32(vout0x0123456789AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); - __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi32(vout1x0123456789AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); - __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi32(vout2x0123456789AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); - __m128i vout3x0123456789ABCDEF = _mm_shuffle_epi32(vout3x0123456789AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); - __m128i vout4x0123456789ABCDEF = _mm_shuffle_epi32(vout4x0123456789AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); - __m128i vout5x0123456789ABCDEF = _mm_shuffle_epi32(vout5x0123456789AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi32(vout0x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi32(vout1x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi32(vout2x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout3x0123456789ABCDEF = _mm_shuffle_epi32(vout3x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout4x0123456789ABCDEF = _mm_shuffle_epi32(vout4x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout5x0123456789ABCDEF = _mm_shuffle_epi32(vout5x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); vout0x0123456789ABCDEF = _mm_max_epi8(vout0x0123456789ABCDEF, voutput_min); vout1x0123456789ABCDEF = _mm_max_epi8(vout1x0123456789ABCDEF, voutput_min); diff --git a/src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-7x16c4-minmax-avx512vnni-prfm.c b/src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-7x16c4-minmax-avx512vnni-prfm.c index 6a42dab7418..560756a1d5a 100644 --- a/src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-7x16c4-minmax-avx512vnni-prfm.c +++ b/src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-7x16c4-minmax-avx512vnni-prfm.c @@ -72,7 +72,6 @@ void xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_7x16c4__avx512vnni_prfm( const __m512 voutput_max_less_zero_point = _mm512_load_ps(params->fp32_avx512vnni.output_max_less_zero_point); const __m256i voutput_zero_point = _mm256_load_si256((const __m256i*) params->fp32_avx512vnni.output_zero_point); const __m128i voutput_min = _mm_load_si128((const __m128i*) params->fp32_avx512vnni.output_min); - const __m128i vshuffle_control_mask = _mm_loadu_si128((const __m128i*) params->fp32_avx512vnni.shuffle_control_mask); do { __m512i vacc0x0123456789ABCDEF = _mm512_load_epi32(w); __m512i vacc1x0123456789ABCDEF = vacc0x0123456789ABCDEF; @@ -228,37 +227,37 @@ void xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_7x16c4__avx512vnni_prfm( vacc5x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled5x0123456789ABCDEF); vacc6x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled6x0123456789ABCDEF); - __m256i vacc0x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); - __m256i vacc1x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); - __m256i vacc2x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc2x0123456789ABCDEF, 1)); - __m256i vacc3x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc3x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc3x0123456789ABCDEF, 1)); - __m256i vacc4x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc4x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc4x0123456789ABCDEF, 1)); - __m256i vacc5x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc5x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc5x0123456789ABCDEF, 1)); - __m256i vacc6x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc6x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc6x0123456789ABCDEF, 1)); + __m256i vacc0x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); + __m256i vacc1x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); + __m256i vacc2x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc2x0123456789ABCDEF, 1)); + __m256i vacc3x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc3x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc3x0123456789ABCDEF, 1)); + __m256i vacc4x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc4x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc4x0123456789ABCDEF, 1)); + __m256i vacc5x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc5x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc5x0123456789ABCDEF, 1)); + __m256i vacc6x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc6x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc6x0123456789ABCDEF, 1)); - vacc0x0123456789AB4567CDEF = _mm256_adds_epi16(vacc0x0123456789AB4567CDEF, voutput_zero_point); - vacc1x0123456789AB4567CDEF = _mm256_adds_epi16(vacc1x0123456789AB4567CDEF, voutput_zero_point); - vacc2x0123456789AB4567CDEF = _mm256_adds_epi16(vacc2x0123456789AB4567CDEF, voutput_zero_point); - vacc3x0123456789AB4567CDEF = _mm256_adds_epi16(vacc3x0123456789AB4567CDEF, voutput_zero_point); - vacc4x0123456789AB4567CDEF = _mm256_adds_epi16(vacc4x0123456789AB4567CDEF, voutput_zero_point); - vacc5x0123456789AB4567CDEF = _mm256_adds_epi16(vacc5x0123456789AB4567CDEF, voutput_zero_point); - vacc6x0123456789AB4567CDEF = _mm256_adds_epi16(vacc6x0123456789AB4567CDEF, voutput_zero_point); + vacc0x012389AB4567CDEF = _mm256_adds_epi16(vacc0x012389AB4567CDEF, voutput_zero_point); + vacc1x012389AB4567CDEF = _mm256_adds_epi16(vacc1x012389AB4567CDEF, voutput_zero_point); + vacc2x012389AB4567CDEF = _mm256_adds_epi16(vacc2x012389AB4567CDEF, voutput_zero_point); + vacc3x012389AB4567CDEF = _mm256_adds_epi16(vacc3x012389AB4567CDEF, voutput_zero_point); + vacc4x012389AB4567CDEF = _mm256_adds_epi16(vacc4x012389AB4567CDEF, voutput_zero_point); + vacc5x012389AB4567CDEF = _mm256_adds_epi16(vacc5x012389AB4567CDEF, voutput_zero_point); + vacc6x012389AB4567CDEF = _mm256_adds_epi16(vacc6x012389AB4567CDEF, voutput_zero_point); - const __m128i vout0x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc0x0123456789AB4567CDEF, 1)); - const __m128i vout1x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc1x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc1x0123456789AB4567CDEF, 1)); - const __m128i vout2x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc2x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc2x0123456789AB4567CDEF, 1)); - const __m128i vout3x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc3x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc3x0123456789AB4567CDEF, 1)); - const __m128i vout4x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc4x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc4x0123456789AB4567CDEF, 1)); - const __m128i vout5x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc5x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc5x0123456789AB4567CDEF, 1)); - const __m128i vout6x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc6x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc6x0123456789AB4567CDEF, 1)); + const __m128i vout0x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x012389AB4567CDEF, 1)); + const __m128i vout1x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc1x012389AB4567CDEF), _mm256_extracti128_si256(vacc1x012389AB4567CDEF, 1)); + const __m128i vout2x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc2x012389AB4567CDEF), _mm256_extracti128_si256(vacc2x012389AB4567CDEF, 1)); + const __m128i vout3x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc3x012389AB4567CDEF), _mm256_extracti128_si256(vacc3x012389AB4567CDEF, 1)); + const __m128i vout4x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc4x012389AB4567CDEF), _mm256_extracti128_si256(vacc4x012389AB4567CDEF, 1)); + const __m128i vout5x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc5x012389AB4567CDEF), _mm256_extracti128_si256(vacc5x012389AB4567CDEF, 1)); + const __m128i vout6x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc6x012389AB4567CDEF), _mm256_extracti128_si256(vacc6x012389AB4567CDEF, 1)); - __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi8(vout0x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi8(vout1x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi8(vout2x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout3x0123456789ABCDEF = _mm_shuffle_epi8(vout3x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout4x0123456789ABCDEF = _mm_shuffle_epi8(vout4x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout5x0123456789ABCDEF = _mm_shuffle_epi8(vout5x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout6x0123456789ABCDEF = _mm_shuffle_epi8(vout6x0123456789AB4567CDEF, vshuffle_control_mask); + __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi32(vout0x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi32(vout1x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi32(vout2x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout3x0123456789ABCDEF = _mm_shuffle_epi32(vout3x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout4x0123456789ABCDEF = _mm_shuffle_epi32(vout4x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout5x0123456789ABCDEF = _mm_shuffle_epi32(vout5x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout6x0123456789ABCDEF = _mm_shuffle_epi32(vout6x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); vout0x0123456789ABCDEF = _mm_max_epi8(vout0x0123456789ABCDEF, voutput_min); vout1x0123456789ABCDEF = _mm_max_epi8(vout1x0123456789ABCDEF, voutput_min); diff --git a/src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-7x16c4-minmax-avx512vnni.c b/src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-7x16c4-minmax-avx512vnni.c index c7232a81656..5a1bafd95a0 100644 --- a/src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-7x16c4-minmax-avx512vnni.c +++ b/src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-7x16c4-minmax-avx512vnni.c @@ -71,7 +71,6 @@ void xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_7x16c4__avx512vnni( const __m512 voutput_max_less_zero_point = _mm512_load_ps(params->fp32_avx512vnni.output_max_less_zero_point); const __m256i voutput_zero_point = _mm256_load_si256((const __m256i*) params->fp32_avx512vnni.output_zero_point); const __m128i voutput_min = _mm_load_si128((const __m128i*) params->fp32_avx512vnni.output_min); - const __m128i vshuffle_control_mask = _mm_loadu_si128((const __m128i*) params->fp32_avx512vnni.shuffle_control_mask); do { __m512i vacc0x0123456789ABCDEF = _mm512_load_epi32(w); __m512i vacc1x0123456789ABCDEF = vacc0x0123456789ABCDEF; @@ -224,37 +223,37 @@ void xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_7x16c4__avx512vnni( vacc5x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled5x0123456789ABCDEF); vacc6x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled6x0123456789ABCDEF); - __m256i vacc0x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); - __m256i vacc1x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); - __m256i vacc2x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc2x0123456789ABCDEF, 1)); - __m256i vacc3x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc3x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc3x0123456789ABCDEF, 1)); - __m256i vacc4x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc4x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc4x0123456789ABCDEF, 1)); - __m256i vacc5x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc5x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc5x0123456789ABCDEF, 1)); - __m256i vacc6x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc6x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc6x0123456789ABCDEF, 1)); + __m256i vacc0x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); + __m256i vacc1x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); + __m256i vacc2x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc2x0123456789ABCDEF, 1)); + __m256i vacc3x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc3x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc3x0123456789ABCDEF, 1)); + __m256i vacc4x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc4x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc4x0123456789ABCDEF, 1)); + __m256i vacc5x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc5x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc5x0123456789ABCDEF, 1)); + __m256i vacc6x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc6x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc6x0123456789ABCDEF, 1)); - vacc0x0123456789AB4567CDEF = _mm256_adds_epi16(vacc0x0123456789AB4567CDEF, voutput_zero_point); - vacc1x0123456789AB4567CDEF = _mm256_adds_epi16(vacc1x0123456789AB4567CDEF, voutput_zero_point); - vacc2x0123456789AB4567CDEF = _mm256_adds_epi16(vacc2x0123456789AB4567CDEF, voutput_zero_point); - vacc3x0123456789AB4567CDEF = _mm256_adds_epi16(vacc3x0123456789AB4567CDEF, voutput_zero_point); - vacc4x0123456789AB4567CDEF = _mm256_adds_epi16(vacc4x0123456789AB4567CDEF, voutput_zero_point); - vacc5x0123456789AB4567CDEF = _mm256_adds_epi16(vacc5x0123456789AB4567CDEF, voutput_zero_point); - vacc6x0123456789AB4567CDEF = _mm256_adds_epi16(vacc6x0123456789AB4567CDEF, voutput_zero_point); + vacc0x012389AB4567CDEF = _mm256_adds_epi16(vacc0x012389AB4567CDEF, voutput_zero_point); + vacc1x012389AB4567CDEF = _mm256_adds_epi16(vacc1x012389AB4567CDEF, voutput_zero_point); + vacc2x012389AB4567CDEF = _mm256_adds_epi16(vacc2x012389AB4567CDEF, voutput_zero_point); + vacc3x012389AB4567CDEF = _mm256_adds_epi16(vacc3x012389AB4567CDEF, voutput_zero_point); + vacc4x012389AB4567CDEF = _mm256_adds_epi16(vacc4x012389AB4567CDEF, voutput_zero_point); + vacc5x012389AB4567CDEF = _mm256_adds_epi16(vacc5x012389AB4567CDEF, voutput_zero_point); + vacc6x012389AB4567CDEF = _mm256_adds_epi16(vacc6x012389AB4567CDEF, voutput_zero_point); - const __m128i vout0x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc0x0123456789AB4567CDEF, 1)); - const __m128i vout1x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc1x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc1x0123456789AB4567CDEF, 1)); - const __m128i vout2x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc2x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc2x0123456789AB4567CDEF, 1)); - const __m128i vout3x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc3x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc3x0123456789AB4567CDEF, 1)); - const __m128i vout4x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc4x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc4x0123456789AB4567CDEF, 1)); - const __m128i vout5x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc5x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc5x0123456789AB4567CDEF, 1)); - const __m128i vout6x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc6x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc6x0123456789AB4567CDEF, 1)); + const __m128i vout0x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x012389AB4567CDEF, 1)); + const __m128i vout1x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc1x012389AB4567CDEF), _mm256_extracti128_si256(vacc1x012389AB4567CDEF, 1)); + const __m128i vout2x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc2x012389AB4567CDEF), _mm256_extracti128_si256(vacc2x012389AB4567CDEF, 1)); + const __m128i vout3x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc3x012389AB4567CDEF), _mm256_extracti128_si256(vacc3x012389AB4567CDEF, 1)); + const __m128i vout4x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc4x012389AB4567CDEF), _mm256_extracti128_si256(vacc4x012389AB4567CDEF, 1)); + const __m128i vout5x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc5x012389AB4567CDEF), _mm256_extracti128_si256(vacc5x012389AB4567CDEF, 1)); + const __m128i vout6x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc6x012389AB4567CDEF), _mm256_extracti128_si256(vacc6x012389AB4567CDEF, 1)); - __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi8(vout0x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi8(vout1x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi8(vout2x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout3x0123456789ABCDEF = _mm_shuffle_epi8(vout3x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout4x0123456789ABCDEF = _mm_shuffle_epi8(vout4x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout5x0123456789ABCDEF = _mm_shuffle_epi8(vout5x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout6x0123456789ABCDEF = _mm_shuffle_epi8(vout6x0123456789AB4567CDEF, vshuffle_control_mask); + __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi32(vout0x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi32(vout1x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi32(vout2x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout3x0123456789ABCDEF = _mm_shuffle_epi32(vout3x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout4x0123456789ABCDEF = _mm_shuffle_epi32(vout4x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout5x0123456789ABCDEF = _mm_shuffle_epi32(vout5x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout6x0123456789ABCDEF = _mm_shuffle_epi32(vout6x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); vout0x0123456789ABCDEF = _mm_max_epi8(vout0x0123456789ABCDEF, voutput_min); vout1x0123456789ABCDEF = _mm_max_epi8(vout1x0123456789ABCDEF, voutput_min); diff --git a/src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-7x16c8-minmax-avx512vnni-prfm.c b/src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-7x16c8-minmax-avx512vnni-prfm.c index 01c21bc2e2d..53a42d871c2 100644 --- a/src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-7x16c8-minmax-avx512vnni-prfm.c +++ b/src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-7x16c8-minmax-avx512vnni-prfm.c @@ -72,7 +72,6 @@ void xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_7x16c8__avx512vnni_prfm( const __m512 voutput_max_less_zero_point = _mm512_load_ps(params->fp32_avx512vnni.output_max_less_zero_point); const __m256i voutput_zero_point = _mm256_load_si256((const __m256i*) params->fp32_avx512vnni.output_zero_point); const __m128i voutput_min = _mm_load_si128((const __m128i*) params->fp32_avx512vnni.output_min); - const __m128i vshuffle_control_mask = _mm_loadu_si128((const __m128i*) params->fp32_avx512vnni.shuffle_control_mask); do { __m512i vacc0x01234567 = _mm512_cvtepu32_epi64(_mm256_load_si256((const __m256i*) w)); __m512i vacc0x89ABCDEF = _mm512_cvtepu32_epi64(_mm256_load_si256((const __m256i*) ((const int32_t*) w + 8))); @@ -289,37 +288,37 @@ void xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_7x16c8__avx512vnni_prfm( vacc5x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled5x0123456789ABCDEF); vacc6x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled6x0123456789ABCDEF); - __m256i vacc0x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); - __m256i vacc1x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); - __m256i vacc2x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc2x0123456789ABCDEF, 1)); - __m256i vacc3x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc3x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc3x0123456789ABCDEF, 1)); - __m256i vacc4x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc4x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc4x0123456789ABCDEF, 1)); - __m256i vacc5x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc5x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc5x0123456789ABCDEF, 1)); - __m256i vacc6x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc6x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc6x0123456789ABCDEF, 1)); + __m256i vacc0x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); + __m256i vacc1x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); + __m256i vacc2x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc2x0123456789ABCDEF, 1)); + __m256i vacc3x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc3x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc3x0123456789ABCDEF, 1)); + __m256i vacc4x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc4x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc4x0123456789ABCDEF, 1)); + __m256i vacc5x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc5x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc5x0123456789ABCDEF, 1)); + __m256i vacc6x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc6x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc6x0123456789ABCDEF, 1)); - vacc0x0123456789AB4567CDEF = _mm256_adds_epi16(vacc0x0123456789AB4567CDEF, voutput_zero_point); - vacc1x0123456789AB4567CDEF = _mm256_adds_epi16(vacc1x0123456789AB4567CDEF, voutput_zero_point); - vacc2x0123456789AB4567CDEF = _mm256_adds_epi16(vacc2x0123456789AB4567CDEF, voutput_zero_point); - vacc3x0123456789AB4567CDEF = _mm256_adds_epi16(vacc3x0123456789AB4567CDEF, voutput_zero_point); - vacc4x0123456789AB4567CDEF = _mm256_adds_epi16(vacc4x0123456789AB4567CDEF, voutput_zero_point); - vacc5x0123456789AB4567CDEF = _mm256_adds_epi16(vacc5x0123456789AB4567CDEF, voutput_zero_point); - vacc6x0123456789AB4567CDEF = _mm256_adds_epi16(vacc6x0123456789AB4567CDEF, voutput_zero_point); + vacc0x012389AB4567CDEF = _mm256_adds_epi16(vacc0x012389AB4567CDEF, voutput_zero_point); + vacc1x012389AB4567CDEF = _mm256_adds_epi16(vacc1x012389AB4567CDEF, voutput_zero_point); + vacc2x012389AB4567CDEF = _mm256_adds_epi16(vacc2x012389AB4567CDEF, voutput_zero_point); + vacc3x012389AB4567CDEF = _mm256_adds_epi16(vacc3x012389AB4567CDEF, voutput_zero_point); + vacc4x012389AB4567CDEF = _mm256_adds_epi16(vacc4x012389AB4567CDEF, voutput_zero_point); + vacc5x012389AB4567CDEF = _mm256_adds_epi16(vacc5x012389AB4567CDEF, voutput_zero_point); + vacc6x012389AB4567CDEF = _mm256_adds_epi16(vacc6x012389AB4567CDEF, voutput_zero_point); - const __m128i vout0x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc0x0123456789AB4567CDEF, 1)); - const __m128i vout1x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc1x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc1x0123456789AB4567CDEF, 1)); - const __m128i vout2x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc2x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc2x0123456789AB4567CDEF, 1)); - const __m128i vout3x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc3x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc3x0123456789AB4567CDEF, 1)); - const __m128i vout4x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc4x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc4x0123456789AB4567CDEF, 1)); - const __m128i vout5x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc5x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc5x0123456789AB4567CDEF, 1)); - const __m128i vout6x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc6x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc6x0123456789AB4567CDEF, 1)); + const __m128i vout0x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x012389AB4567CDEF, 1)); + const __m128i vout1x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc1x012389AB4567CDEF), _mm256_extracti128_si256(vacc1x012389AB4567CDEF, 1)); + const __m128i vout2x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc2x012389AB4567CDEF), _mm256_extracti128_si256(vacc2x012389AB4567CDEF, 1)); + const __m128i vout3x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc3x012389AB4567CDEF), _mm256_extracti128_si256(vacc3x012389AB4567CDEF, 1)); + const __m128i vout4x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc4x012389AB4567CDEF), _mm256_extracti128_si256(vacc4x012389AB4567CDEF, 1)); + const __m128i vout5x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc5x012389AB4567CDEF), _mm256_extracti128_si256(vacc5x012389AB4567CDEF, 1)); + const __m128i vout6x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc6x012389AB4567CDEF), _mm256_extracti128_si256(vacc6x012389AB4567CDEF, 1)); - __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi8(vout0x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi8(vout1x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi8(vout2x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout3x0123456789ABCDEF = _mm_shuffle_epi8(vout3x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout4x0123456789ABCDEF = _mm_shuffle_epi8(vout4x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout5x0123456789ABCDEF = _mm_shuffle_epi8(vout5x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout6x0123456789ABCDEF = _mm_shuffle_epi8(vout6x0123456789AB4567CDEF, vshuffle_control_mask); + __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi32(vout0x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi32(vout1x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi32(vout2x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout3x0123456789ABCDEF = _mm_shuffle_epi32(vout3x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout4x0123456789ABCDEF = _mm_shuffle_epi32(vout4x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout5x0123456789ABCDEF = _mm_shuffle_epi32(vout5x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout6x0123456789ABCDEF = _mm_shuffle_epi32(vout6x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); vout0x0123456789ABCDEF = _mm_max_epi8(vout0x0123456789ABCDEF, voutput_min); vout1x0123456789ABCDEF = _mm_max_epi8(vout1x0123456789ABCDEF, voutput_min); diff --git a/src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-7x16c8-minmax-avx512vnni.c b/src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-7x16c8-minmax-avx512vnni.c index b460c919bf7..e739cdc85a8 100644 --- a/src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-7x16c8-minmax-avx512vnni.c +++ b/src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-7x16c8-minmax-avx512vnni.c @@ -71,7 +71,6 @@ void xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_7x16c8__avx512vnni( const __m512 voutput_max_less_zero_point = _mm512_load_ps(params->fp32_avx512vnni.output_max_less_zero_point); const __m256i voutput_zero_point = _mm256_load_si256((const __m256i*) params->fp32_avx512vnni.output_zero_point); const __m128i voutput_min = _mm_load_si128((const __m128i*) params->fp32_avx512vnni.output_min); - const __m128i vshuffle_control_mask = _mm_loadu_si128((const __m128i*) params->fp32_avx512vnni.shuffle_control_mask); do { __m512i vacc0x01234567 = _mm512_cvtepu32_epi64(_mm256_load_si256((const __m256i*) w)); __m512i vacc0x89ABCDEF = _mm512_cvtepu32_epi64(_mm256_load_si256((const __m256i*) ((const int32_t*) w + 8))); @@ -282,37 +281,37 @@ void xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_7x16c8__avx512vnni( vacc5x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled5x0123456789ABCDEF); vacc6x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled6x0123456789ABCDEF); - __m256i vacc0x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); - __m256i vacc1x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); - __m256i vacc2x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc2x0123456789ABCDEF, 1)); - __m256i vacc3x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc3x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc3x0123456789ABCDEF, 1)); - __m256i vacc4x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc4x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc4x0123456789ABCDEF, 1)); - __m256i vacc5x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc5x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc5x0123456789ABCDEF, 1)); - __m256i vacc6x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc6x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc6x0123456789ABCDEF, 1)); + __m256i vacc0x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); + __m256i vacc1x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); + __m256i vacc2x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc2x0123456789ABCDEF, 1)); + __m256i vacc3x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc3x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc3x0123456789ABCDEF, 1)); + __m256i vacc4x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc4x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc4x0123456789ABCDEF, 1)); + __m256i vacc5x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc5x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc5x0123456789ABCDEF, 1)); + __m256i vacc6x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc6x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc6x0123456789ABCDEF, 1)); - vacc0x0123456789AB4567CDEF = _mm256_adds_epi16(vacc0x0123456789AB4567CDEF, voutput_zero_point); - vacc1x0123456789AB4567CDEF = _mm256_adds_epi16(vacc1x0123456789AB4567CDEF, voutput_zero_point); - vacc2x0123456789AB4567CDEF = _mm256_adds_epi16(vacc2x0123456789AB4567CDEF, voutput_zero_point); - vacc3x0123456789AB4567CDEF = _mm256_adds_epi16(vacc3x0123456789AB4567CDEF, voutput_zero_point); - vacc4x0123456789AB4567CDEF = _mm256_adds_epi16(vacc4x0123456789AB4567CDEF, voutput_zero_point); - vacc5x0123456789AB4567CDEF = _mm256_adds_epi16(vacc5x0123456789AB4567CDEF, voutput_zero_point); - vacc6x0123456789AB4567CDEF = _mm256_adds_epi16(vacc6x0123456789AB4567CDEF, voutput_zero_point); + vacc0x012389AB4567CDEF = _mm256_adds_epi16(vacc0x012389AB4567CDEF, voutput_zero_point); + vacc1x012389AB4567CDEF = _mm256_adds_epi16(vacc1x012389AB4567CDEF, voutput_zero_point); + vacc2x012389AB4567CDEF = _mm256_adds_epi16(vacc2x012389AB4567CDEF, voutput_zero_point); + vacc3x012389AB4567CDEF = _mm256_adds_epi16(vacc3x012389AB4567CDEF, voutput_zero_point); + vacc4x012389AB4567CDEF = _mm256_adds_epi16(vacc4x012389AB4567CDEF, voutput_zero_point); + vacc5x012389AB4567CDEF = _mm256_adds_epi16(vacc5x012389AB4567CDEF, voutput_zero_point); + vacc6x012389AB4567CDEF = _mm256_adds_epi16(vacc6x012389AB4567CDEF, voutput_zero_point); - const __m128i vout0x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc0x0123456789AB4567CDEF, 1)); - const __m128i vout1x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc1x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc1x0123456789AB4567CDEF, 1)); - const __m128i vout2x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc2x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc2x0123456789AB4567CDEF, 1)); - const __m128i vout3x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc3x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc3x0123456789AB4567CDEF, 1)); - const __m128i vout4x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc4x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc4x0123456789AB4567CDEF, 1)); - const __m128i vout5x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc5x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc5x0123456789AB4567CDEF, 1)); - const __m128i vout6x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc6x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc6x0123456789AB4567CDEF, 1)); + const __m128i vout0x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x012389AB4567CDEF, 1)); + const __m128i vout1x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc1x012389AB4567CDEF), _mm256_extracti128_si256(vacc1x012389AB4567CDEF, 1)); + const __m128i vout2x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc2x012389AB4567CDEF), _mm256_extracti128_si256(vacc2x012389AB4567CDEF, 1)); + const __m128i vout3x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc3x012389AB4567CDEF), _mm256_extracti128_si256(vacc3x012389AB4567CDEF, 1)); + const __m128i vout4x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc4x012389AB4567CDEF), _mm256_extracti128_si256(vacc4x012389AB4567CDEF, 1)); + const __m128i vout5x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc5x012389AB4567CDEF), _mm256_extracti128_si256(vacc5x012389AB4567CDEF, 1)); + const __m128i vout6x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc6x012389AB4567CDEF), _mm256_extracti128_si256(vacc6x012389AB4567CDEF, 1)); - __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi8(vout0x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi8(vout1x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi8(vout2x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout3x0123456789ABCDEF = _mm_shuffle_epi8(vout3x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout4x0123456789ABCDEF = _mm_shuffle_epi8(vout4x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout5x0123456789ABCDEF = _mm_shuffle_epi8(vout5x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout6x0123456789ABCDEF = _mm_shuffle_epi8(vout6x0123456789AB4567CDEF, vshuffle_control_mask); + __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi32(vout0x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi32(vout1x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi32(vout2x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout3x0123456789ABCDEF = _mm_shuffle_epi32(vout3x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout4x0123456789ABCDEF = _mm_shuffle_epi32(vout4x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout5x0123456789ABCDEF = _mm_shuffle_epi32(vout5x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout6x0123456789ABCDEF = _mm_shuffle_epi32(vout6x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); vout0x0123456789ABCDEF = _mm_max_epi8(vout0x0123456789ABCDEF, voutput_min); vout1x0123456789ABCDEF = _mm_max_epi8(vout1x0123456789ABCDEF, voutput_min); diff --git a/src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-7x16c8-minmax-fp32-avx512skx-prfm.c b/src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-7x16c8-minmax-fp32-avx512skx-prfm.c index 5d386b25d92..2400931f46c 100644 --- a/src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-7x16c8-minmax-fp32-avx512skx-prfm.c +++ b/src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-7x16c8-minmax-fp32-avx512skx-prfm.c @@ -262,37 +262,37 @@ void xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_7x16c8__avx512skx_prfm( vacc5x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled5x0123456789ABCDEF); vacc6x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled6x0123456789ABCDEF); - __m256i vacc0x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); - __m256i vacc1x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); - __m256i vacc2x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc2x0123456789ABCDEF, 1)); - __m256i vacc3x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc3x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc3x0123456789ABCDEF, 1)); - __m256i vacc4x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc4x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc4x0123456789ABCDEF, 1)); - __m256i vacc5x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc5x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc5x0123456789ABCDEF, 1)); - __m256i vacc6x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc6x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc6x0123456789ABCDEF, 1)); + __m256i vacc0x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); + __m256i vacc1x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); + __m256i vacc2x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc2x0123456789ABCDEF, 1)); + __m256i vacc3x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc3x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc3x0123456789ABCDEF, 1)); + __m256i vacc4x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc4x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc4x0123456789ABCDEF, 1)); + __m256i vacc5x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc5x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc5x0123456789ABCDEF, 1)); + __m256i vacc6x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc6x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc6x0123456789ABCDEF, 1)); - vacc0x0123456789AB4567CDEF = _mm256_adds_epi16(vacc0x0123456789AB4567CDEF, voutput_zero_point); - vacc1x0123456789AB4567CDEF = _mm256_adds_epi16(vacc1x0123456789AB4567CDEF, voutput_zero_point); - vacc2x0123456789AB4567CDEF = _mm256_adds_epi16(vacc2x0123456789AB4567CDEF, voutput_zero_point); - vacc3x0123456789AB4567CDEF = _mm256_adds_epi16(vacc3x0123456789AB4567CDEF, voutput_zero_point); - vacc4x0123456789AB4567CDEF = _mm256_adds_epi16(vacc4x0123456789AB4567CDEF, voutput_zero_point); - vacc5x0123456789AB4567CDEF = _mm256_adds_epi16(vacc5x0123456789AB4567CDEF, voutput_zero_point); - vacc6x0123456789AB4567CDEF = _mm256_adds_epi16(vacc6x0123456789AB4567CDEF, voutput_zero_point); + vacc0x012389AB4567CDEF = _mm256_adds_epi16(vacc0x012389AB4567CDEF, voutput_zero_point); + vacc1x012389AB4567CDEF = _mm256_adds_epi16(vacc1x012389AB4567CDEF, voutput_zero_point); + vacc2x012389AB4567CDEF = _mm256_adds_epi16(vacc2x012389AB4567CDEF, voutput_zero_point); + vacc3x012389AB4567CDEF = _mm256_adds_epi16(vacc3x012389AB4567CDEF, voutput_zero_point); + vacc4x012389AB4567CDEF = _mm256_adds_epi16(vacc4x012389AB4567CDEF, voutput_zero_point); + vacc5x012389AB4567CDEF = _mm256_adds_epi16(vacc5x012389AB4567CDEF, voutput_zero_point); + vacc6x012389AB4567CDEF = _mm256_adds_epi16(vacc6x012389AB4567CDEF, voutput_zero_point); - const __m128i vout0x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc0x0123456789AB4567CDEF, 1)); - const __m128i vout1x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc1x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc1x0123456789AB4567CDEF, 1)); - const __m128i vout2x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc2x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc2x0123456789AB4567CDEF, 1)); - const __m128i vout3x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc3x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc3x0123456789AB4567CDEF, 1)); - const __m128i vout4x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc4x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc4x0123456789AB4567CDEF, 1)); - const __m128i vout5x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc5x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc5x0123456789AB4567CDEF, 1)); - const __m128i vout6x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc6x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc6x0123456789AB4567CDEF, 1)); + const __m128i vout0x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x012389AB4567CDEF, 1)); + const __m128i vout1x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc1x012389AB4567CDEF), _mm256_extracti128_si256(vacc1x012389AB4567CDEF, 1)); + const __m128i vout2x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc2x012389AB4567CDEF), _mm256_extracti128_si256(vacc2x012389AB4567CDEF, 1)); + const __m128i vout3x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc3x012389AB4567CDEF), _mm256_extracti128_si256(vacc3x012389AB4567CDEF, 1)); + const __m128i vout4x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc4x012389AB4567CDEF), _mm256_extracti128_si256(vacc4x012389AB4567CDEF, 1)); + const __m128i vout5x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc5x012389AB4567CDEF), _mm256_extracti128_si256(vacc5x012389AB4567CDEF, 1)); + const __m128i vout6x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc6x012389AB4567CDEF), _mm256_extracti128_si256(vacc6x012389AB4567CDEF, 1)); - __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi32(vout0x0123456789AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); - __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi32(vout1x0123456789AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); - __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi32(vout2x0123456789AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); - __m128i vout3x0123456789ABCDEF = _mm_shuffle_epi32(vout3x0123456789AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); - __m128i vout4x0123456789ABCDEF = _mm_shuffle_epi32(vout4x0123456789AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); - __m128i vout5x0123456789ABCDEF = _mm_shuffle_epi32(vout5x0123456789AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); - __m128i vout6x0123456789ABCDEF = _mm_shuffle_epi32(vout6x0123456789AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi32(vout0x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi32(vout1x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi32(vout2x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout3x0123456789ABCDEF = _mm_shuffle_epi32(vout3x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout4x0123456789ABCDEF = _mm_shuffle_epi32(vout4x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout5x0123456789ABCDEF = _mm_shuffle_epi32(vout5x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout6x0123456789ABCDEF = _mm_shuffle_epi32(vout6x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); vout0x0123456789ABCDEF = _mm_max_epi8(vout0x0123456789ABCDEF, voutput_min); vout1x0123456789ABCDEF = _mm_max_epi8(vout1x0123456789ABCDEF, voutput_min); diff --git a/src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-7x16c8-minmax-fp32-avx512skx.c b/src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-7x16c8-minmax-fp32-avx512skx.c index 8e1711cf6ae..e14d2040c19 100644 --- a/src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-7x16c8-minmax-fp32-avx512skx.c +++ b/src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-7x16c8-minmax-fp32-avx512skx.c @@ -259,37 +259,37 @@ void xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_7x16c8__avx512skx( vacc5x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled5x0123456789ABCDEF); vacc6x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled6x0123456789ABCDEF); - __m256i vacc0x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); - __m256i vacc1x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); - __m256i vacc2x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc2x0123456789ABCDEF, 1)); - __m256i vacc3x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc3x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc3x0123456789ABCDEF, 1)); - __m256i vacc4x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc4x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc4x0123456789ABCDEF, 1)); - __m256i vacc5x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc5x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc5x0123456789ABCDEF, 1)); - __m256i vacc6x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc6x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc6x0123456789ABCDEF, 1)); + __m256i vacc0x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); + __m256i vacc1x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); + __m256i vacc2x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc2x0123456789ABCDEF, 1)); + __m256i vacc3x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc3x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc3x0123456789ABCDEF, 1)); + __m256i vacc4x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc4x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc4x0123456789ABCDEF, 1)); + __m256i vacc5x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc5x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc5x0123456789ABCDEF, 1)); + __m256i vacc6x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc6x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc6x0123456789ABCDEF, 1)); - vacc0x0123456789AB4567CDEF = _mm256_adds_epi16(vacc0x0123456789AB4567CDEF, voutput_zero_point); - vacc1x0123456789AB4567CDEF = _mm256_adds_epi16(vacc1x0123456789AB4567CDEF, voutput_zero_point); - vacc2x0123456789AB4567CDEF = _mm256_adds_epi16(vacc2x0123456789AB4567CDEF, voutput_zero_point); - vacc3x0123456789AB4567CDEF = _mm256_adds_epi16(vacc3x0123456789AB4567CDEF, voutput_zero_point); - vacc4x0123456789AB4567CDEF = _mm256_adds_epi16(vacc4x0123456789AB4567CDEF, voutput_zero_point); - vacc5x0123456789AB4567CDEF = _mm256_adds_epi16(vacc5x0123456789AB4567CDEF, voutput_zero_point); - vacc6x0123456789AB4567CDEF = _mm256_adds_epi16(vacc6x0123456789AB4567CDEF, voutput_zero_point); + vacc0x012389AB4567CDEF = _mm256_adds_epi16(vacc0x012389AB4567CDEF, voutput_zero_point); + vacc1x012389AB4567CDEF = _mm256_adds_epi16(vacc1x012389AB4567CDEF, voutput_zero_point); + vacc2x012389AB4567CDEF = _mm256_adds_epi16(vacc2x012389AB4567CDEF, voutput_zero_point); + vacc3x012389AB4567CDEF = _mm256_adds_epi16(vacc3x012389AB4567CDEF, voutput_zero_point); + vacc4x012389AB4567CDEF = _mm256_adds_epi16(vacc4x012389AB4567CDEF, voutput_zero_point); + vacc5x012389AB4567CDEF = _mm256_adds_epi16(vacc5x012389AB4567CDEF, voutput_zero_point); + vacc6x012389AB4567CDEF = _mm256_adds_epi16(vacc6x012389AB4567CDEF, voutput_zero_point); - const __m128i vout0x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc0x0123456789AB4567CDEF, 1)); - const __m128i vout1x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc1x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc1x0123456789AB4567CDEF, 1)); - const __m128i vout2x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc2x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc2x0123456789AB4567CDEF, 1)); - const __m128i vout3x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc3x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc3x0123456789AB4567CDEF, 1)); - const __m128i vout4x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc4x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc4x0123456789AB4567CDEF, 1)); - const __m128i vout5x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc5x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc5x0123456789AB4567CDEF, 1)); - const __m128i vout6x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc6x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc6x0123456789AB4567CDEF, 1)); + const __m128i vout0x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x012389AB4567CDEF, 1)); + const __m128i vout1x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc1x012389AB4567CDEF), _mm256_extracti128_si256(vacc1x012389AB4567CDEF, 1)); + const __m128i vout2x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc2x012389AB4567CDEF), _mm256_extracti128_si256(vacc2x012389AB4567CDEF, 1)); + const __m128i vout3x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc3x012389AB4567CDEF), _mm256_extracti128_si256(vacc3x012389AB4567CDEF, 1)); + const __m128i vout4x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc4x012389AB4567CDEF), _mm256_extracti128_si256(vacc4x012389AB4567CDEF, 1)); + const __m128i vout5x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc5x012389AB4567CDEF), _mm256_extracti128_si256(vacc5x012389AB4567CDEF, 1)); + const __m128i vout6x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc6x012389AB4567CDEF), _mm256_extracti128_si256(vacc6x012389AB4567CDEF, 1)); - __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi32(vout0x0123456789AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); - __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi32(vout1x0123456789AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); - __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi32(vout2x0123456789AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); - __m128i vout3x0123456789ABCDEF = _mm_shuffle_epi32(vout3x0123456789AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); - __m128i vout4x0123456789ABCDEF = _mm_shuffle_epi32(vout4x0123456789AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); - __m128i vout5x0123456789ABCDEF = _mm_shuffle_epi32(vout5x0123456789AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); - __m128i vout6x0123456789ABCDEF = _mm_shuffle_epi32(vout6x0123456789AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi32(vout0x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi32(vout1x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi32(vout2x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout3x0123456789ABCDEF = _mm_shuffle_epi32(vout3x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout4x0123456789ABCDEF = _mm_shuffle_epi32(vout4x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout5x0123456789ABCDEF = _mm_shuffle_epi32(vout5x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout6x0123456789ABCDEF = _mm_shuffle_epi32(vout6x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); vout0x0123456789ABCDEF = _mm_max_epi8(vout0x0123456789ABCDEF, voutput_min); vout1x0123456789ABCDEF = _mm_max_epi8(vout1x0123456789ABCDEF, voutput_min); diff --git a/src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-8x16c4-minmax-avx512vnni-prfm.c b/src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-8x16c4-minmax-avx512vnni-prfm.c index c02415f39b8..ab4a1caf035 100644 --- a/src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-8x16c4-minmax-avx512vnni-prfm.c +++ b/src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-8x16c4-minmax-avx512vnni-prfm.c @@ -76,7 +76,6 @@ void xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_8x16c4__avx512vnni_prfm( const __m512 voutput_max_less_zero_point = _mm512_load_ps(params->fp32_avx512vnni.output_max_less_zero_point); const __m256i voutput_zero_point = _mm256_load_si256((const __m256i*) params->fp32_avx512vnni.output_zero_point); const __m128i voutput_min = _mm_load_si128((const __m128i*) params->fp32_avx512vnni.output_min); - const __m128i vshuffle_control_mask = _mm_loadu_si128((const __m128i*) params->fp32_avx512vnni.shuffle_control_mask); do { __m512i vacc0x0123456789ABCDEF = _mm512_load_epi32(w); __m512i vacc1x0123456789ABCDEF = vacc0x0123456789ABCDEF; @@ -249,41 +248,41 @@ void xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_8x16c4__avx512vnni_prfm( vacc6x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled6x0123456789ABCDEF); vacc7x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled7x0123456789ABCDEF); - __m256i vacc0x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); - __m256i vacc1x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); - __m256i vacc2x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc2x0123456789ABCDEF, 1)); - __m256i vacc3x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc3x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc3x0123456789ABCDEF, 1)); - __m256i vacc4x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc4x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc4x0123456789ABCDEF, 1)); - __m256i vacc5x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc5x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc5x0123456789ABCDEF, 1)); - __m256i vacc6x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc6x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc6x0123456789ABCDEF, 1)); - __m256i vacc7x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc7x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc7x0123456789ABCDEF, 1)); + __m256i vacc0x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); + __m256i vacc1x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); + __m256i vacc2x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc2x0123456789ABCDEF, 1)); + __m256i vacc3x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc3x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc3x0123456789ABCDEF, 1)); + __m256i vacc4x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc4x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc4x0123456789ABCDEF, 1)); + __m256i vacc5x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc5x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc5x0123456789ABCDEF, 1)); + __m256i vacc6x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc6x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc6x0123456789ABCDEF, 1)); + __m256i vacc7x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc7x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc7x0123456789ABCDEF, 1)); - vacc0x0123456789AB4567CDEF = _mm256_adds_epi16(vacc0x0123456789AB4567CDEF, voutput_zero_point); - vacc1x0123456789AB4567CDEF = _mm256_adds_epi16(vacc1x0123456789AB4567CDEF, voutput_zero_point); - vacc2x0123456789AB4567CDEF = _mm256_adds_epi16(vacc2x0123456789AB4567CDEF, voutput_zero_point); - vacc3x0123456789AB4567CDEF = _mm256_adds_epi16(vacc3x0123456789AB4567CDEF, voutput_zero_point); - vacc4x0123456789AB4567CDEF = _mm256_adds_epi16(vacc4x0123456789AB4567CDEF, voutput_zero_point); - vacc5x0123456789AB4567CDEF = _mm256_adds_epi16(vacc5x0123456789AB4567CDEF, voutput_zero_point); - vacc6x0123456789AB4567CDEF = _mm256_adds_epi16(vacc6x0123456789AB4567CDEF, voutput_zero_point); - vacc7x0123456789AB4567CDEF = _mm256_adds_epi16(vacc7x0123456789AB4567CDEF, voutput_zero_point); + vacc0x012389AB4567CDEF = _mm256_adds_epi16(vacc0x012389AB4567CDEF, voutput_zero_point); + vacc1x012389AB4567CDEF = _mm256_adds_epi16(vacc1x012389AB4567CDEF, voutput_zero_point); + vacc2x012389AB4567CDEF = _mm256_adds_epi16(vacc2x012389AB4567CDEF, voutput_zero_point); + vacc3x012389AB4567CDEF = _mm256_adds_epi16(vacc3x012389AB4567CDEF, voutput_zero_point); + vacc4x012389AB4567CDEF = _mm256_adds_epi16(vacc4x012389AB4567CDEF, voutput_zero_point); + vacc5x012389AB4567CDEF = _mm256_adds_epi16(vacc5x012389AB4567CDEF, voutput_zero_point); + vacc6x012389AB4567CDEF = _mm256_adds_epi16(vacc6x012389AB4567CDEF, voutput_zero_point); + vacc7x012389AB4567CDEF = _mm256_adds_epi16(vacc7x012389AB4567CDEF, voutput_zero_point); - const __m128i vout0x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc0x0123456789AB4567CDEF, 1)); - const __m128i vout1x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc1x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc1x0123456789AB4567CDEF, 1)); - const __m128i vout2x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc2x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc2x0123456789AB4567CDEF, 1)); - const __m128i vout3x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc3x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc3x0123456789AB4567CDEF, 1)); - const __m128i vout4x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc4x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc4x0123456789AB4567CDEF, 1)); - const __m128i vout5x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc5x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc5x0123456789AB4567CDEF, 1)); - const __m128i vout6x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc6x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc6x0123456789AB4567CDEF, 1)); - const __m128i vout7x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc7x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc7x0123456789AB4567CDEF, 1)); + const __m128i vout0x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x012389AB4567CDEF, 1)); + const __m128i vout1x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc1x012389AB4567CDEF), _mm256_extracti128_si256(vacc1x012389AB4567CDEF, 1)); + const __m128i vout2x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc2x012389AB4567CDEF), _mm256_extracti128_si256(vacc2x012389AB4567CDEF, 1)); + const __m128i vout3x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc3x012389AB4567CDEF), _mm256_extracti128_si256(vacc3x012389AB4567CDEF, 1)); + const __m128i vout4x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc4x012389AB4567CDEF), _mm256_extracti128_si256(vacc4x012389AB4567CDEF, 1)); + const __m128i vout5x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc5x012389AB4567CDEF), _mm256_extracti128_si256(vacc5x012389AB4567CDEF, 1)); + const __m128i vout6x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc6x012389AB4567CDEF), _mm256_extracti128_si256(vacc6x012389AB4567CDEF, 1)); + const __m128i vout7x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc7x012389AB4567CDEF), _mm256_extracti128_si256(vacc7x012389AB4567CDEF, 1)); - __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi8(vout0x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi8(vout1x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi8(vout2x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout3x0123456789ABCDEF = _mm_shuffle_epi8(vout3x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout4x0123456789ABCDEF = _mm_shuffle_epi8(vout4x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout5x0123456789ABCDEF = _mm_shuffle_epi8(vout5x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout6x0123456789ABCDEF = _mm_shuffle_epi8(vout6x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout7x0123456789ABCDEF = _mm_shuffle_epi8(vout7x0123456789AB4567CDEF, vshuffle_control_mask); + __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi32(vout0x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi32(vout1x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi32(vout2x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout3x0123456789ABCDEF = _mm_shuffle_epi32(vout3x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout4x0123456789ABCDEF = _mm_shuffle_epi32(vout4x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout5x0123456789ABCDEF = _mm_shuffle_epi32(vout5x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout6x0123456789ABCDEF = _mm_shuffle_epi32(vout6x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout7x0123456789ABCDEF = _mm_shuffle_epi32(vout7x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); vout0x0123456789ABCDEF = _mm_max_epi8(vout0x0123456789ABCDEF, voutput_min); vout1x0123456789ABCDEF = _mm_max_epi8(vout1x0123456789ABCDEF, voutput_min); diff --git a/src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-8x16c4-minmax-avx512vnni.c b/src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-8x16c4-minmax-avx512vnni.c index 41f758b88ab..26596b2b5e8 100644 --- a/src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-8x16c4-minmax-avx512vnni.c +++ b/src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-8x16c4-minmax-avx512vnni.c @@ -75,7 +75,6 @@ void xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_8x16c4__avx512vnni( const __m512 voutput_max_less_zero_point = _mm512_load_ps(params->fp32_avx512vnni.output_max_less_zero_point); const __m256i voutput_zero_point = _mm256_load_si256((const __m256i*) params->fp32_avx512vnni.output_zero_point); const __m128i voutput_min = _mm_load_si128((const __m128i*) params->fp32_avx512vnni.output_min); - const __m128i vshuffle_control_mask = _mm_loadu_si128((const __m128i*) params->fp32_avx512vnni.shuffle_control_mask); do { __m512i vacc0x0123456789ABCDEF = _mm512_load_epi32(w); __m512i vacc1x0123456789ABCDEF = vacc0x0123456789ABCDEF; @@ -245,41 +244,41 @@ void xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_8x16c4__avx512vnni( vacc6x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled6x0123456789ABCDEF); vacc7x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled7x0123456789ABCDEF); - __m256i vacc0x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); - __m256i vacc1x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); - __m256i vacc2x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc2x0123456789ABCDEF, 1)); - __m256i vacc3x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc3x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc3x0123456789ABCDEF, 1)); - __m256i vacc4x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc4x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc4x0123456789ABCDEF, 1)); - __m256i vacc5x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc5x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc5x0123456789ABCDEF, 1)); - __m256i vacc6x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc6x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc6x0123456789ABCDEF, 1)); - __m256i vacc7x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc7x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc7x0123456789ABCDEF, 1)); + __m256i vacc0x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); + __m256i vacc1x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); + __m256i vacc2x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc2x0123456789ABCDEF, 1)); + __m256i vacc3x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc3x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc3x0123456789ABCDEF, 1)); + __m256i vacc4x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc4x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc4x0123456789ABCDEF, 1)); + __m256i vacc5x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc5x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc5x0123456789ABCDEF, 1)); + __m256i vacc6x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc6x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc6x0123456789ABCDEF, 1)); + __m256i vacc7x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc7x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc7x0123456789ABCDEF, 1)); - vacc0x0123456789AB4567CDEF = _mm256_adds_epi16(vacc0x0123456789AB4567CDEF, voutput_zero_point); - vacc1x0123456789AB4567CDEF = _mm256_adds_epi16(vacc1x0123456789AB4567CDEF, voutput_zero_point); - vacc2x0123456789AB4567CDEF = _mm256_adds_epi16(vacc2x0123456789AB4567CDEF, voutput_zero_point); - vacc3x0123456789AB4567CDEF = _mm256_adds_epi16(vacc3x0123456789AB4567CDEF, voutput_zero_point); - vacc4x0123456789AB4567CDEF = _mm256_adds_epi16(vacc4x0123456789AB4567CDEF, voutput_zero_point); - vacc5x0123456789AB4567CDEF = _mm256_adds_epi16(vacc5x0123456789AB4567CDEF, voutput_zero_point); - vacc6x0123456789AB4567CDEF = _mm256_adds_epi16(vacc6x0123456789AB4567CDEF, voutput_zero_point); - vacc7x0123456789AB4567CDEF = _mm256_adds_epi16(vacc7x0123456789AB4567CDEF, voutput_zero_point); + vacc0x012389AB4567CDEF = _mm256_adds_epi16(vacc0x012389AB4567CDEF, voutput_zero_point); + vacc1x012389AB4567CDEF = _mm256_adds_epi16(vacc1x012389AB4567CDEF, voutput_zero_point); + vacc2x012389AB4567CDEF = _mm256_adds_epi16(vacc2x012389AB4567CDEF, voutput_zero_point); + vacc3x012389AB4567CDEF = _mm256_adds_epi16(vacc3x012389AB4567CDEF, voutput_zero_point); + vacc4x012389AB4567CDEF = _mm256_adds_epi16(vacc4x012389AB4567CDEF, voutput_zero_point); + vacc5x012389AB4567CDEF = _mm256_adds_epi16(vacc5x012389AB4567CDEF, voutput_zero_point); + vacc6x012389AB4567CDEF = _mm256_adds_epi16(vacc6x012389AB4567CDEF, voutput_zero_point); + vacc7x012389AB4567CDEF = _mm256_adds_epi16(vacc7x012389AB4567CDEF, voutput_zero_point); - const __m128i vout0x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc0x0123456789AB4567CDEF, 1)); - const __m128i vout1x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc1x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc1x0123456789AB4567CDEF, 1)); - const __m128i vout2x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc2x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc2x0123456789AB4567CDEF, 1)); - const __m128i vout3x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc3x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc3x0123456789AB4567CDEF, 1)); - const __m128i vout4x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc4x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc4x0123456789AB4567CDEF, 1)); - const __m128i vout5x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc5x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc5x0123456789AB4567CDEF, 1)); - const __m128i vout6x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc6x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc6x0123456789AB4567CDEF, 1)); - const __m128i vout7x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc7x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc7x0123456789AB4567CDEF, 1)); + const __m128i vout0x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x012389AB4567CDEF, 1)); + const __m128i vout1x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc1x012389AB4567CDEF), _mm256_extracti128_si256(vacc1x012389AB4567CDEF, 1)); + const __m128i vout2x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc2x012389AB4567CDEF), _mm256_extracti128_si256(vacc2x012389AB4567CDEF, 1)); + const __m128i vout3x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc3x012389AB4567CDEF), _mm256_extracti128_si256(vacc3x012389AB4567CDEF, 1)); + const __m128i vout4x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc4x012389AB4567CDEF), _mm256_extracti128_si256(vacc4x012389AB4567CDEF, 1)); + const __m128i vout5x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc5x012389AB4567CDEF), _mm256_extracti128_si256(vacc5x012389AB4567CDEF, 1)); + const __m128i vout6x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc6x012389AB4567CDEF), _mm256_extracti128_si256(vacc6x012389AB4567CDEF, 1)); + const __m128i vout7x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc7x012389AB4567CDEF), _mm256_extracti128_si256(vacc7x012389AB4567CDEF, 1)); - __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi8(vout0x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi8(vout1x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi8(vout2x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout3x0123456789ABCDEF = _mm_shuffle_epi8(vout3x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout4x0123456789ABCDEF = _mm_shuffle_epi8(vout4x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout5x0123456789ABCDEF = _mm_shuffle_epi8(vout5x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout6x0123456789ABCDEF = _mm_shuffle_epi8(vout6x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout7x0123456789ABCDEF = _mm_shuffle_epi8(vout7x0123456789AB4567CDEF, vshuffle_control_mask); + __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi32(vout0x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi32(vout1x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi32(vout2x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout3x0123456789ABCDEF = _mm_shuffle_epi32(vout3x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout4x0123456789ABCDEF = _mm_shuffle_epi32(vout4x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout5x0123456789ABCDEF = _mm_shuffle_epi32(vout5x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout6x0123456789ABCDEF = _mm_shuffle_epi32(vout6x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout7x0123456789ABCDEF = _mm_shuffle_epi32(vout7x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); vout0x0123456789ABCDEF = _mm_max_epi8(vout0x0123456789ABCDEF, voutput_min); vout1x0123456789ABCDEF = _mm_max_epi8(vout1x0123456789ABCDEF, voutput_min); diff --git a/src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-8x16c8-minmax-avx512vnni-prfm.c b/src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-8x16c8-minmax-avx512vnni-prfm.c index 56d90dc5edc..3926b6ba455 100644 --- a/src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-8x16c8-minmax-avx512vnni-prfm.c +++ b/src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-8x16c8-minmax-avx512vnni-prfm.c @@ -76,7 +76,6 @@ void xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_8x16c8__avx512vnni_prfm( const __m512 voutput_max_less_zero_point = _mm512_load_ps(params->fp32_avx512vnni.output_max_less_zero_point); const __m256i voutput_zero_point = _mm256_load_si256((const __m256i*) params->fp32_avx512vnni.output_zero_point); const __m128i voutput_min = _mm_load_si128((const __m128i*) params->fp32_avx512vnni.output_min); - const __m128i vshuffle_control_mask = _mm_loadu_si128((const __m128i*) params->fp32_avx512vnni.shuffle_control_mask); do { __m512i vacc0x01234567 = _mm512_cvtepu32_epi64(_mm256_load_si256((const __m256i*) w)); __m512i vacc0x89ABCDEF = _mm512_cvtepu32_epi64(_mm256_load_si256((const __m256i*) ((const int32_t*) w + 8))); @@ -317,41 +316,41 @@ void xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_8x16c8__avx512vnni_prfm( vacc6x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled6x0123456789ABCDEF); vacc7x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled7x0123456789ABCDEF); - __m256i vacc0x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); - __m256i vacc1x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); - __m256i vacc2x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc2x0123456789ABCDEF, 1)); - __m256i vacc3x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc3x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc3x0123456789ABCDEF, 1)); - __m256i vacc4x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc4x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc4x0123456789ABCDEF, 1)); - __m256i vacc5x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc5x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc5x0123456789ABCDEF, 1)); - __m256i vacc6x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc6x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc6x0123456789ABCDEF, 1)); - __m256i vacc7x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc7x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc7x0123456789ABCDEF, 1)); + __m256i vacc0x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); + __m256i vacc1x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); + __m256i vacc2x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc2x0123456789ABCDEF, 1)); + __m256i vacc3x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc3x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc3x0123456789ABCDEF, 1)); + __m256i vacc4x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc4x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc4x0123456789ABCDEF, 1)); + __m256i vacc5x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc5x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc5x0123456789ABCDEF, 1)); + __m256i vacc6x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc6x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc6x0123456789ABCDEF, 1)); + __m256i vacc7x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc7x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc7x0123456789ABCDEF, 1)); - vacc0x0123456789AB4567CDEF = _mm256_adds_epi16(vacc0x0123456789AB4567CDEF, voutput_zero_point); - vacc1x0123456789AB4567CDEF = _mm256_adds_epi16(vacc1x0123456789AB4567CDEF, voutput_zero_point); - vacc2x0123456789AB4567CDEF = _mm256_adds_epi16(vacc2x0123456789AB4567CDEF, voutput_zero_point); - vacc3x0123456789AB4567CDEF = _mm256_adds_epi16(vacc3x0123456789AB4567CDEF, voutput_zero_point); - vacc4x0123456789AB4567CDEF = _mm256_adds_epi16(vacc4x0123456789AB4567CDEF, voutput_zero_point); - vacc5x0123456789AB4567CDEF = _mm256_adds_epi16(vacc5x0123456789AB4567CDEF, voutput_zero_point); - vacc6x0123456789AB4567CDEF = _mm256_adds_epi16(vacc6x0123456789AB4567CDEF, voutput_zero_point); - vacc7x0123456789AB4567CDEF = _mm256_adds_epi16(vacc7x0123456789AB4567CDEF, voutput_zero_point); + vacc0x012389AB4567CDEF = _mm256_adds_epi16(vacc0x012389AB4567CDEF, voutput_zero_point); + vacc1x012389AB4567CDEF = _mm256_adds_epi16(vacc1x012389AB4567CDEF, voutput_zero_point); + vacc2x012389AB4567CDEF = _mm256_adds_epi16(vacc2x012389AB4567CDEF, voutput_zero_point); + vacc3x012389AB4567CDEF = _mm256_adds_epi16(vacc3x012389AB4567CDEF, voutput_zero_point); + vacc4x012389AB4567CDEF = _mm256_adds_epi16(vacc4x012389AB4567CDEF, voutput_zero_point); + vacc5x012389AB4567CDEF = _mm256_adds_epi16(vacc5x012389AB4567CDEF, voutput_zero_point); + vacc6x012389AB4567CDEF = _mm256_adds_epi16(vacc6x012389AB4567CDEF, voutput_zero_point); + vacc7x012389AB4567CDEF = _mm256_adds_epi16(vacc7x012389AB4567CDEF, voutput_zero_point); - const __m128i vout0x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc0x0123456789AB4567CDEF, 1)); - const __m128i vout1x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc1x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc1x0123456789AB4567CDEF, 1)); - const __m128i vout2x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc2x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc2x0123456789AB4567CDEF, 1)); - const __m128i vout3x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc3x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc3x0123456789AB4567CDEF, 1)); - const __m128i vout4x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc4x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc4x0123456789AB4567CDEF, 1)); - const __m128i vout5x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc5x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc5x0123456789AB4567CDEF, 1)); - const __m128i vout6x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc6x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc6x0123456789AB4567CDEF, 1)); - const __m128i vout7x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc7x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc7x0123456789AB4567CDEF, 1)); + const __m128i vout0x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x012389AB4567CDEF, 1)); + const __m128i vout1x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc1x012389AB4567CDEF), _mm256_extracti128_si256(vacc1x012389AB4567CDEF, 1)); + const __m128i vout2x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc2x012389AB4567CDEF), _mm256_extracti128_si256(vacc2x012389AB4567CDEF, 1)); + const __m128i vout3x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc3x012389AB4567CDEF), _mm256_extracti128_si256(vacc3x012389AB4567CDEF, 1)); + const __m128i vout4x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc4x012389AB4567CDEF), _mm256_extracti128_si256(vacc4x012389AB4567CDEF, 1)); + const __m128i vout5x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc5x012389AB4567CDEF), _mm256_extracti128_si256(vacc5x012389AB4567CDEF, 1)); + const __m128i vout6x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc6x012389AB4567CDEF), _mm256_extracti128_si256(vacc6x012389AB4567CDEF, 1)); + const __m128i vout7x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc7x012389AB4567CDEF), _mm256_extracti128_si256(vacc7x012389AB4567CDEF, 1)); - __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi8(vout0x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi8(vout1x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi8(vout2x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout3x0123456789ABCDEF = _mm_shuffle_epi8(vout3x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout4x0123456789ABCDEF = _mm_shuffle_epi8(vout4x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout5x0123456789ABCDEF = _mm_shuffle_epi8(vout5x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout6x0123456789ABCDEF = _mm_shuffle_epi8(vout6x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout7x0123456789ABCDEF = _mm_shuffle_epi8(vout7x0123456789AB4567CDEF, vshuffle_control_mask); + __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi32(vout0x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi32(vout1x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi32(vout2x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout3x0123456789ABCDEF = _mm_shuffle_epi32(vout3x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout4x0123456789ABCDEF = _mm_shuffle_epi32(vout4x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout5x0123456789ABCDEF = _mm_shuffle_epi32(vout5x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout6x0123456789ABCDEF = _mm_shuffle_epi32(vout6x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout7x0123456789ABCDEF = _mm_shuffle_epi32(vout7x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); vout0x0123456789ABCDEF = _mm_max_epi8(vout0x0123456789ABCDEF, voutput_min); vout1x0123456789ABCDEF = _mm_max_epi8(vout1x0123456789ABCDEF, voutput_min); diff --git a/src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-8x16c8-minmax-avx512vnni.c b/src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-8x16c8-minmax-avx512vnni.c index 946d506dc32..f2e8029cf3e 100644 --- a/src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-8x16c8-minmax-avx512vnni.c +++ b/src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-8x16c8-minmax-avx512vnni.c @@ -75,7 +75,6 @@ void xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_8x16c8__avx512vnni( const __m512 voutput_max_less_zero_point = _mm512_load_ps(params->fp32_avx512vnni.output_max_less_zero_point); const __m256i voutput_zero_point = _mm256_load_si256((const __m256i*) params->fp32_avx512vnni.output_zero_point); const __m128i voutput_min = _mm_load_si128((const __m128i*) params->fp32_avx512vnni.output_min); - const __m128i vshuffle_control_mask = _mm_loadu_si128((const __m128i*) params->fp32_avx512vnni.shuffle_control_mask); do { __m512i vacc0x01234567 = _mm512_cvtepu32_epi64(_mm256_load_si256((const __m256i*) w)); __m512i vacc0x89ABCDEF = _mm512_cvtepu32_epi64(_mm256_load_si256((const __m256i*) ((const int32_t*) w + 8))); @@ -310,41 +309,41 @@ void xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_8x16c8__avx512vnni( vacc6x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled6x0123456789ABCDEF); vacc7x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled7x0123456789ABCDEF); - __m256i vacc0x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); - __m256i vacc1x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); - __m256i vacc2x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc2x0123456789ABCDEF, 1)); - __m256i vacc3x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc3x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc3x0123456789ABCDEF, 1)); - __m256i vacc4x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc4x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc4x0123456789ABCDEF, 1)); - __m256i vacc5x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc5x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc5x0123456789ABCDEF, 1)); - __m256i vacc6x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc6x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc6x0123456789ABCDEF, 1)); - __m256i vacc7x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc7x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc7x0123456789ABCDEF, 1)); + __m256i vacc0x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); + __m256i vacc1x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); + __m256i vacc2x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc2x0123456789ABCDEF, 1)); + __m256i vacc3x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc3x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc3x0123456789ABCDEF, 1)); + __m256i vacc4x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc4x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc4x0123456789ABCDEF, 1)); + __m256i vacc5x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc5x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc5x0123456789ABCDEF, 1)); + __m256i vacc6x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc6x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc6x0123456789ABCDEF, 1)); + __m256i vacc7x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc7x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc7x0123456789ABCDEF, 1)); - vacc0x0123456789AB4567CDEF = _mm256_adds_epi16(vacc0x0123456789AB4567CDEF, voutput_zero_point); - vacc1x0123456789AB4567CDEF = _mm256_adds_epi16(vacc1x0123456789AB4567CDEF, voutput_zero_point); - vacc2x0123456789AB4567CDEF = _mm256_adds_epi16(vacc2x0123456789AB4567CDEF, voutput_zero_point); - vacc3x0123456789AB4567CDEF = _mm256_adds_epi16(vacc3x0123456789AB4567CDEF, voutput_zero_point); - vacc4x0123456789AB4567CDEF = _mm256_adds_epi16(vacc4x0123456789AB4567CDEF, voutput_zero_point); - vacc5x0123456789AB4567CDEF = _mm256_adds_epi16(vacc5x0123456789AB4567CDEF, voutput_zero_point); - vacc6x0123456789AB4567CDEF = _mm256_adds_epi16(vacc6x0123456789AB4567CDEF, voutput_zero_point); - vacc7x0123456789AB4567CDEF = _mm256_adds_epi16(vacc7x0123456789AB4567CDEF, voutput_zero_point); + vacc0x012389AB4567CDEF = _mm256_adds_epi16(vacc0x012389AB4567CDEF, voutput_zero_point); + vacc1x012389AB4567CDEF = _mm256_adds_epi16(vacc1x012389AB4567CDEF, voutput_zero_point); + vacc2x012389AB4567CDEF = _mm256_adds_epi16(vacc2x012389AB4567CDEF, voutput_zero_point); + vacc3x012389AB4567CDEF = _mm256_adds_epi16(vacc3x012389AB4567CDEF, voutput_zero_point); + vacc4x012389AB4567CDEF = _mm256_adds_epi16(vacc4x012389AB4567CDEF, voutput_zero_point); + vacc5x012389AB4567CDEF = _mm256_adds_epi16(vacc5x012389AB4567CDEF, voutput_zero_point); + vacc6x012389AB4567CDEF = _mm256_adds_epi16(vacc6x012389AB4567CDEF, voutput_zero_point); + vacc7x012389AB4567CDEF = _mm256_adds_epi16(vacc7x012389AB4567CDEF, voutput_zero_point); - const __m128i vout0x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc0x0123456789AB4567CDEF, 1)); - const __m128i vout1x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc1x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc1x0123456789AB4567CDEF, 1)); - const __m128i vout2x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc2x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc2x0123456789AB4567CDEF, 1)); - const __m128i vout3x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc3x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc3x0123456789AB4567CDEF, 1)); - const __m128i vout4x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc4x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc4x0123456789AB4567CDEF, 1)); - const __m128i vout5x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc5x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc5x0123456789AB4567CDEF, 1)); - const __m128i vout6x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc6x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc6x0123456789AB4567CDEF, 1)); - const __m128i vout7x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc7x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc7x0123456789AB4567CDEF, 1)); + const __m128i vout0x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x012389AB4567CDEF, 1)); + const __m128i vout1x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc1x012389AB4567CDEF), _mm256_extracti128_si256(vacc1x012389AB4567CDEF, 1)); + const __m128i vout2x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc2x012389AB4567CDEF), _mm256_extracti128_si256(vacc2x012389AB4567CDEF, 1)); + const __m128i vout3x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc3x012389AB4567CDEF), _mm256_extracti128_si256(vacc3x012389AB4567CDEF, 1)); + const __m128i vout4x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc4x012389AB4567CDEF), _mm256_extracti128_si256(vacc4x012389AB4567CDEF, 1)); + const __m128i vout5x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc5x012389AB4567CDEF), _mm256_extracti128_si256(vacc5x012389AB4567CDEF, 1)); + const __m128i vout6x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc6x012389AB4567CDEF), _mm256_extracti128_si256(vacc6x012389AB4567CDEF, 1)); + const __m128i vout7x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc7x012389AB4567CDEF), _mm256_extracti128_si256(vacc7x012389AB4567CDEF, 1)); - __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi8(vout0x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi8(vout1x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi8(vout2x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout3x0123456789ABCDEF = _mm_shuffle_epi8(vout3x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout4x0123456789ABCDEF = _mm_shuffle_epi8(vout4x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout5x0123456789ABCDEF = _mm_shuffle_epi8(vout5x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout6x0123456789ABCDEF = _mm_shuffle_epi8(vout6x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout7x0123456789ABCDEF = _mm_shuffle_epi8(vout7x0123456789AB4567CDEF, vshuffle_control_mask); + __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi32(vout0x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi32(vout1x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi32(vout2x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout3x0123456789ABCDEF = _mm_shuffle_epi32(vout3x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout4x0123456789ABCDEF = _mm_shuffle_epi32(vout4x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout5x0123456789ABCDEF = _mm_shuffle_epi32(vout5x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout6x0123456789ABCDEF = _mm_shuffle_epi32(vout6x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout7x0123456789ABCDEF = _mm_shuffle_epi32(vout7x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); vout0x0123456789ABCDEF = _mm_max_epi8(vout0x0123456789ABCDEF, voutput_min); vout1x0123456789ABCDEF = _mm_max_epi8(vout1x0123456789ABCDEF, voutput_min); diff --git a/src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-8x16c8-minmax-fp32-avx512skx-prfm.c b/src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-8x16c8-minmax-fp32-avx512skx-prfm.c index 5b7a648318a..b6dbb47ebfc 100644 --- a/src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-8x16c8-minmax-fp32-avx512skx-prfm.c +++ b/src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-8x16c8-minmax-fp32-avx512skx-prfm.c @@ -288,41 +288,41 @@ void xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_8x16c8__avx512skx_prfm( vacc6x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled6x0123456789ABCDEF); vacc7x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled7x0123456789ABCDEF); - __m256i vacc0x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); - __m256i vacc1x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); - __m256i vacc2x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc2x0123456789ABCDEF, 1)); - __m256i vacc3x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc3x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc3x0123456789ABCDEF, 1)); - __m256i vacc4x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc4x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc4x0123456789ABCDEF, 1)); - __m256i vacc5x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc5x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc5x0123456789ABCDEF, 1)); - __m256i vacc6x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc6x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc6x0123456789ABCDEF, 1)); - __m256i vacc7x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc7x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc7x0123456789ABCDEF, 1)); + __m256i vacc0x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); + __m256i vacc1x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); + __m256i vacc2x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc2x0123456789ABCDEF, 1)); + __m256i vacc3x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc3x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc3x0123456789ABCDEF, 1)); + __m256i vacc4x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc4x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc4x0123456789ABCDEF, 1)); + __m256i vacc5x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc5x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc5x0123456789ABCDEF, 1)); + __m256i vacc6x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc6x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc6x0123456789ABCDEF, 1)); + __m256i vacc7x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc7x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc7x0123456789ABCDEF, 1)); - vacc0x0123456789AB4567CDEF = _mm256_adds_epi16(vacc0x0123456789AB4567CDEF, voutput_zero_point); - vacc1x0123456789AB4567CDEF = _mm256_adds_epi16(vacc1x0123456789AB4567CDEF, voutput_zero_point); - vacc2x0123456789AB4567CDEF = _mm256_adds_epi16(vacc2x0123456789AB4567CDEF, voutput_zero_point); - vacc3x0123456789AB4567CDEF = _mm256_adds_epi16(vacc3x0123456789AB4567CDEF, voutput_zero_point); - vacc4x0123456789AB4567CDEF = _mm256_adds_epi16(vacc4x0123456789AB4567CDEF, voutput_zero_point); - vacc5x0123456789AB4567CDEF = _mm256_adds_epi16(vacc5x0123456789AB4567CDEF, voutput_zero_point); - vacc6x0123456789AB4567CDEF = _mm256_adds_epi16(vacc6x0123456789AB4567CDEF, voutput_zero_point); - vacc7x0123456789AB4567CDEF = _mm256_adds_epi16(vacc7x0123456789AB4567CDEF, voutput_zero_point); + vacc0x012389AB4567CDEF = _mm256_adds_epi16(vacc0x012389AB4567CDEF, voutput_zero_point); + vacc1x012389AB4567CDEF = _mm256_adds_epi16(vacc1x012389AB4567CDEF, voutput_zero_point); + vacc2x012389AB4567CDEF = _mm256_adds_epi16(vacc2x012389AB4567CDEF, voutput_zero_point); + vacc3x012389AB4567CDEF = _mm256_adds_epi16(vacc3x012389AB4567CDEF, voutput_zero_point); + vacc4x012389AB4567CDEF = _mm256_adds_epi16(vacc4x012389AB4567CDEF, voutput_zero_point); + vacc5x012389AB4567CDEF = _mm256_adds_epi16(vacc5x012389AB4567CDEF, voutput_zero_point); + vacc6x012389AB4567CDEF = _mm256_adds_epi16(vacc6x012389AB4567CDEF, voutput_zero_point); + vacc7x012389AB4567CDEF = _mm256_adds_epi16(vacc7x012389AB4567CDEF, voutput_zero_point); - const __m128i vout0x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc0x0123456789AB4567CDEF, 1)); - const __m128i vout1x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc1x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc1x0123456789AB4567CDEF, 1)); - const __m128i vout2x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc2x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc2x0123456789AB4567CDEF, 1)); - const __m128i vout3x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc3x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc3x0123456789AB4567CDEF, 1)); - const __m128i vout4x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc4x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc4x0123456789AB4567CDEF, 1)); - const __m128i vout5x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc5x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc5x0123456789AB4567CDEF, 1)); - const __m128i vout6x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc6x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc6x0123456789AB4567CDEF, 1)); - const __m128i vout7x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc7x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc7x0123456789AB4567CDEF, 1)); + const __m128i vout0x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x012389AB4567CDEF, 1)); + const __m128i vout1x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc1x012389AB4567CDEF), _mm256_extracti128_si256(vacc1x012389AB4567CDEF, 1)); + const __m128i vout2x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc2x012389AB4567CDEF), _mm256_extracti128_si256(vacc2x012389AB4567CDEF, 1)); + const __m128i vout3x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc3x012389AB4567CDEF), _mm256_extracti128_si256(vacc3x012389AB4567CDEF, 1)); + const __m128i vout4x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc4x012389AB4567CDEF), _mm256_extracti128_si256(vacc4x012389AB4567CDEF, 1)); + const __m128i vout5x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc5x012389AB4567CDEF), _mm256_extracti128_si256(vacc5x012389AB4567CDEF, 1)); + const __m128i vout6x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc6x012389AB4567CDEF), _mm256_extracti128_si256(vacc6x012389AB4567CDEF, 1)); + const __m128i vout7x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc7x012389AB4567CDEF), _mm256_extracti128_si256(vacc7x012389AB4567CDEF, 1)); - __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi32(vout0x0123456789AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); - __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi32(vout1x0123456789AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); - __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi32(vout2x0123456789AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); - __m128i vout3x0123456789ABCDEF = _mm_shuffle_epi32(vout3x0123456789AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); - __m128i vout4x0123456789ABCDEF = _mm_shuffle_epi32(vout4x0123456789AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); - __m128i vout5x0123456789ABCDEF = _mm_shuffle_epi32(vout5x0123456789AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); - __m128i vout6x0123456789ABCDEF = _mm_shuffle_epi32(vout6x0123456789AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); - __m128i vout7x0123456789ABCDEF = _mm_shuffle_epi32(vout7x0123456789AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi32(vout0x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi32(vout1x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi32(vout2x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout3x0123456789ABCDEF = _mm_shuffle_epi32(vout3x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout4x0123456789ABCDEF = _mm_shuffle_epi32(vout4x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout5x0123456789ABCDEF = _mm_shuffle_epi32(vout5x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout6x0123456789ABCDEF = _mm_shuffle_epi32(vout6x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout7x0123456789ABCDEF = _mm_shuffle_epi32(vout7x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); vout0x0123456789ABCDEF = _mm_max_epi8(vout0x0123456789ABCDEF, voutput_min); vout1x0123456789ABCDEF = _mm_max_epi8(vout1x0123456789ABCDEF, voutput_min); diff --git a/src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-8x16c8-minmax-fp32-avx512skx.c b/src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-8x16c8-minmax-fp32-avx512skx.c index 4a9d127f7bb..3b410116c4c 100644 --- a/src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-8x16c8-minmax-fp32-avx512skx.c +++ b/src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-8x16c8-minmax-fp32-avx512skx.c @@ -285,41 +285,41 @@ void xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_8x16c8__avx512skx( vacc6x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled6x0123456789ABCDEF); vacc7x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled7x0123456789ABCDEF); - __m256i vacc0x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); - __m256i vacc1x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); - __m256i vacc2x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc2x0123456789ABCDEF, 1)); - __m256i vacc3x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc3x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc3x0123456789ABCDEF, 1)); - __m256i vacc4x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc4x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc4x0123456789ABCDEF, 1)); - __m256i vacc5x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc5x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc5x0123456789ABCDEF, 1)); - __m256i vacc6x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc6x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc6x0123456789ABCDEF, 1)); - __m256i vacc7x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc7x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc7x0123456789ABCDEF, 1)); + __m256i vacc0x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); + __m256i vacc1x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); + __m256i vacc2x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc2x0123456789ABCDEF, 1)); + __m256i vacc3x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc3x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc3x0123456789ABCDEF, 1)); + __m256i vacc4x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc4x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc4x0123456789ABCDEF, 1)); + __m256i vacc5x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc5x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc5x0123456789ABCDEF, 1)); + __m256i vacc6x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc6x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc6x0123456789ABCDEF, 1)); + __m256i vacc7x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc7x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc7x0123456789ABCDEF, 1)); - vacc0x0123456789AB4567CDEF = _mm256_adds_epi16(vacc0x0123456789AB4567CDEF, voutput_zero_point); - vacc1x0123456789AB4567CDEF = _mm256_adds_epi16(vacc1x0123456789AB4567CDEF, voutput_zero_point); - vacc2x0123456789AB4567CDEF = _mm256_adds_epi16(vacc2x0123456789AB4567CDEF, voutput_zero_point); - vacc3x0123456789AB4567CDEF = _mm256_adds_epi16(vacc3x0123456789AB4567CDEF, voutput_zero_point); - vacc4x0123456789AB4567CDEF = _mm256_adds_epi16(vacc4x0123456789AB4567CDEF, voutput_zero_point); - vacc5x0123456789AB4567CDEF = _mm256_adds_epi16(vacc5x0123456789AB4567CDEF, voutput_zero_point); - vacc6x0123456789AB4567CDEF = _mm256_adds_epi16(vacc6x0123456789AB4567CDEF, voutput_zero_point); - vacc7x0123456789AB4567CDEF = _mm256_adds_epi16(vacc7x0123456789AB4567CDEF, voutput_zero_point); + vacc0x012389AB4567CDEF = _mm256_adds_epi16(vacc0x012389AB4567CDEF, voutput_zero_point); + vacc1x012389AB4567CDEF = _mm256_adds_epi16(vacc1x012389AB4567CDEF, voutput_zero_point); + vacc2x012389AB4567CDEF = _mm256_adds_epi16(vacc2x012389AB4567CDEF, voutput_zero_point); + vacc3x012389AB4567CDEF = _mm256_adds_epi16(vacc3x012389AB4567CDEF, voutput_zero_point); + vacc4x012389AB4567CDEF = _mm256_adds_epi16(vacc4x012389AB4567CDEF, voutput_zero_point); + vacc5x012389AB4567CDEF = _mm256_adds_epi16(vacc5x012389AB4567CDEF, voutput_zero_point); + vacc6x012389AB4567CDEF = _mm256_adds_epi16(vacc6x012389AB4567CDEF, voutput_zero_point); + vacc7x012389AB4567CDEF = _mm256_adds_epi16(vacc7x012389AB4567CDEF, voutput_zero_point); - const __m128i vout0x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc0x0123456789AB4567CDEF, 1)); - const __m128i vout1x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc1x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc1x0123456789AB4567CDEF, 1)); - const __m128i vout2x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc2x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc2x0123456789AB4567CDEF, 1)); - const __m128i vout3x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc3x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc3x0123456789AB4567CDEF, 1)); - const __m128i vout4x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc4x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc4x0123456789AB4567CDEF, 1)); - const __m128i vout5x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc5x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc5x0123456789AB4567CDEF, 1)); - const __m128i vout6x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc6x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc6x0123456789AB4567CDEF, 1)); - const __m128i vout7x0123456789AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc7x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc7x0123456789AB4567CDEF, 1)); + const __m128i vout0x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x012389AB4567CDEF, 1)); + const __m128i vout1x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc1x012389AB4567CDEF), _mm256_extracti128_si256(vacc1x012389AB4567CDEF, 1)); + const __m128i vout2x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc2x012389AB4567CDEF), _mm256_extracti128_si256(vacc2x012389AB4567CDEF, 1)); + const __m128i vout3x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc3x012389AB4567CDEF), _mm256_extracti128_si256(vacc3x012389AB4567CDEF, 1)); + const __m128i vout4x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc4x012389AB4567CDEF), _mm256_extracti128_si256(vacc4x012389AB4567CDEF, 1)); + const __m128i vout5x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc5x012389AB4567CDEF), _mm256_extracti128_si256(vacc5x012389AB4567CDEF, 1)); + const __m128i vout6x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc6x012389AB4567CDEF), _mm256_extracti128_si256(vacc6x012389AB4567CDEF, 1)); + const __m128i vout7x012389AB4567CDEF = _mm_packs_epi16(_mm256_castsi256_si128(vacc7x012389AB4567CDEF), _mm256_extracti128_si256(vacc7x012389AB4567CDEF, 1)); - __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi32(vout0x0123456789AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); - __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi32(vout1x0123456789AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); - __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi32(vout2x0123456789AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); - __m128i vout3x0123456789ABCDEF = _mm_shuffle_epi32(vout3x0123456789AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); - __m128i vout4x0123456789ABCDEF = _mm_shuffle_epi32(vout4x0123456789AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); - __m128i vout5x0123456789ABCDEF = _mm_shuffle_epi32(vout5x0123456789AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); - __m128i vout6x0123456789ABCDEF = _mm_shuffle_epi32(vout6x0123456789AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); - __m128i vout7x0123456789ABCDEF = _mm_shuffle_epi32(vout7x0123456789AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi32(vout0x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi32(vout1x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi32(vout2x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout3x0123456789ABCDEF = _mm_shuffle_epi32(vout3x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout4x0123456789ABCDEF = _mm_shuffle_epi32(vout4x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout5x0123456789ABCDEF = _mm_shuffle_epi32(vout5x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout6x0123456789ABCDEF = _mm_shuffle_epi32(vout6x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout7x0123456789ABCDEF = _mm_shuffle_epi32(vout7x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); vout0x0123456789ABCDEF = _mm_max_epi8(vout0x0123456789ABCDEF, voutput_min); vout1x0123456789ABCDEF = _mm_max_epi8(vout1x0123456789ABCDEF, voutput_min); diff --git a/src/qu8-gemm/gen/qu8-gemm-1x16c8-minmax-fp32-avx512skx-prfm.c b/src/qu8-gemm/gen/qu8-gemm-1x16c8-minmax-fp32-avx512skx-prfm.c index b0c1230b718..42369122d5c 100644 --- a/src/qu8-gemm/gen/qu8-gemm-1x16c8-minmax-fp32-avx512skx-prfm.c +++ b/src/qu8-gemm/gen/qu8-gemm-1x16c8-minmax-fp32-avx512skx-prfm.c @@ -47,7 +47,6 @@ void xnn_qu8_gemm_minmax_fp32_ukernel_1x16c8__avx512skx_prfm( const __m256i voutput_zero_point = _mm256_load_si256((const __m256i*) params->fp32_avx512.output_zero_point); const __m128i voutput_min = _mm_load_si128((const __m128i*) params->fp32_avx512.output_min); const __m512i vb_zero_point = _mm512_load_si512(params->fp32_avx512.kernel_zero_point); - const __m128i vshuffle_control_mask = _mm_set_epi8(15, 14, 13, 12, 7, 6, 5, 4, 11, 10, 9, 8, 3, 2, 1, 0); do { __m512i vacc0x0123 = _mm512_maskz_expandloadu_epi32(vbias_mask, w); @@ -98,22 +97,20 @@ void xnn_qu8_gemm_minmax_fp32_ukernel_1x16c8__avx512skx_prfm( vacc0x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled0x0123456789ABCDEF); - __m256i vacc0x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); + __m256i vacc0x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); - vacc0x0123456789AB4567CDEF = _mm256_adds_epi16(vacc0x0123456789AB4567CDEF, voutput_zero_point); + vacc0x012389AB4567CDEF = _mm256_adds_epi16(vacc0x012389AB4567CDEF, voutput_zero_point); - const __m128i vout0x0123456789AB4567CDEF = _mm_packus_epi16(_mm256_castsi256_si128(vacc0x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc0x0123456789AB4567CDEF, 1)); + const __m128i vout0x012389AB4567CDEF = _mm_packus_epi16(_mm256_castsi256_si128(vacc0x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x012389AB4567CDEF, 1)); - __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi8(vout0x0123456789AB4567CDEF, vshuffle_control_mask); + __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi32(vout0x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); vout0x0123456789ABCDEF = _mm_max_epu8(vout0x0123456789ABCDEF, voutput_min); if (nc >= 16) { _mm_storeu_si128((__m128i*) c0, vout0x0123456789ABCDEF); - - a0 = (const uint8_t*) ((uintptr_t) a0 - kc); - c0 = (uint8_t*) ((uintptr_t) c0 + cn_stride); + a0 = (const uint8_t*) ((uintptr_t) a0 - kc); nc -= 16; } else { diff --git a/src/qu8-gemm/gen/qu8-gemm-1x16c8-minmax-fp32-avx512skx.c b/src/qu8-gemm/gen/qu8-gemm-1x16c8-minmax-fp32-avx512skx.c index 2190cfa92c5..408552ea273 100644 --- a/src/qu8-gemm/gen/qu8-gemm-1x16c8-minmax-fp32-avx512skx.c +++ b/src/qu8-gemm/gen/qu8-gemm-1x16c8-minmax-fp32-avx512skx.c @@ -46,7 +46,6 @@ void xnn_qu8_gemm_minmax_fp32_ukernel_1x16c8__avx512skx( const __m256i voutput_zero_point = _mm256_load_si256((const __m256i*) params->fp32_avx512.output_zero_point); const __m128i voutput_min = _mm_load_si128((const __m128i*) params->fp32_avx512.output_min); const __m512i vb_zero_point = _mm512_load_si512(params->fp32_avx512.kernel_zero_point); - const __m128i vshuffle_control_mask = _mm_set_epi8(15, 14, 13, 12, 7, 6, 5, 4, 11, 10, 9, 8, 3, 2, 1, 0); do { __m512i vacc0x0123 = _mm512_maskz_expandloadu_epi32(vbias_mask, w); @@ -95,22 +94,20 @@ void xnn_qu8_gemm_minmax_fp32_ukernel_1x16c8__avx512skx( vacc0x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled0x0123456789ABCDEF); - __m256i vacc0x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); + __m256i vacc0x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); - vacc0x0123456789AB4567CDEF = _mm256_adds_epi16(vacc0x0123456789AB4567CDEF, voutput_zero_point); + vacc0x012389AB4567CDEF = _mm256_adds_epi16(vacc0x012389AB4567CDEF, voutput_zero_point); - const __m128i vout0x0123456789AB4567CDEF = _mm_packus_epi16(_mm256_castsi256_si128(vacc0x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc0x0123456789AB4567CDEF, 1)); + const __m128i vout0x012389AB4567CDEF = _mm_packus_epi16(_mm256_castsi256_si128(vacc0x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x012389AB4567CDEF, 1)); - __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi8(vout0x0123456789AB4567CDEF, vshuffle_control_mask); + __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi32(vout0x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); vout0x0123456789ABCDEF = _mm_max_epu8(vout0x0123456789ABCDEF, voutput_min); if (nc >= 16) { _mm_storeu_si128((__m128i*) c0, vout0x0123456789ABCDEF); - - a0 = (const uint8_t*) ((uintptr_t) a0 - kc); - c0 = (uint8_t*) ((uintptr_t) c0 + cn_stride); + a0 = (const uint8_t*) ((uintptr_t) a0 - kc); nc -= 16; } else { diff --git a/src/qu8-gemm/gen/qu8-gemm-2x16c8-minmax-fp32-avx512skx-prfm.c b/src/qu8-gemm/gen/qu8-gemm-2x16c8-minmax-fp32-avx512skx-prfm.c index 78ff470cc46..d6844185192 100644 --- a/src/qu8-gemm/gen/qu8-gemm-2x16c8-minmax-fp32-avx512skx-prfm.c +++ b/src/qu8-gemm/gen/qu8-gemm-2x16c8-minmax-fp32-avx512skx-prfm.c @@ -53,7 +53,6 @@ void xnn_qu8_gemm_minmax_fp32_ukernel_2x16c8__avx512skx_prfm( const __m256i voutput_zero_point = _mm256_load_si256((const __m256i*) params->fp32_avx512.output_zero_point); const __m128i voutput_min = _mm_load_si128((const __m128i*) params->fp32_avx512.output_min); const __m512i vb_zero_point = _mm512_load_si512(params->fp32_avx512.kernel_zero_point); - const __m128i vshuffle_control_mask = _mm_set_epi8(15, 14, 13, 12, 7, 6, 5, 4, 11, 10, 9, 8, 3, 2, 1, 0); do { __m512i vacc0x0123 = _mm512_maskz_expandloadu_epi32(vbias_mask, w); @@ -122,30 +121,28 @@ void xnn_qu8_gemm_minmax_fp32_ukernel_2x16c8__avx512skx_prfm( vacc0x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled0x0123456789ABCDEF); vacc1x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled1x0123456789ABCDEF); - __m256i vacc0x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); - __m256i vacc1x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); + __m256i vacc0x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); + __m256i vacc1x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); - vacc0x0123456789AB4567CDEF = _mm256_adds_epi16(vacc0x0123456789AB4567CDEF, voutput_zero_point); - vacc1x0123456789AB4567CDEF = _mm256_adds_epi16(vacc1x0123456789AB4567CDEF, voutput_zero_point); + vacc0x012389AB4567CDEF = _mm256_adds_epi16(vacc0x012389AB4567CDEF, voutput_zero_point); + vacc1x012389AB4567CDEF = _mm256_adds_epi16(vacc1x012389AB4567CDEF, voutput_zero_point); - const __m128i vout0x0123456789AB4567CDEF = _mm_packus_epi16(_mm256_castsi256_si128(vacc0x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc0x0123456789AB4567CDEF, 1)); - const __m128i vout1x0123456789AB4567CDEF = _mm_packus_epi16(_mm256_castsi256_si128(vacc1x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc1x0123456789AB4567CDEF, 1)); + const __m128i vout0x012389AB4567CDEF = _mm_packus_epi16(_mm256_castsi256_si128(vacc0x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x012389AB4567CDEF, 1)); + const __m128i vout1x012389AB4567CDEF = _mm_packus_epi16(_mm256_castsi256_si128(vacc1x012389AB4567CDEF), _mm256_extracti128_si256(vacc1x012389AB4567CDEF, 1)); - __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi8(vout0x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi8(vout1x0123456789AB4567CDEF, vshuffle_control_mask); + __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi32(vout0x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi32(vout1x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); vout0x0123456789ABCDEF = _mm_max_epu8(vout0x0123456789ABCDEF, voutput_min); vout1x0123456789ABCDEF = _mm_max_epu8(vout1x0123456789ABCDEF, voutput_min); if (nc >= 16) { _mm_storeu_si128((__m128i*) c0, vout0x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c1, vout1x0123456789ABCDEF); - - a0 = (const uint8_t*) ((uintptr_t) a0 - kc); - a1 = (const uint8_t*) ((uintptr_t) a1 - kc); - c0 = (uint8_t*) ((uintptr_t) c0 + cn_stride); + a0 = (const uint8_t*) ((uintptr_t) a0 - kc); + _mm_storeu_si128((__m128i*) c1, vout1x0123456789ABCDEF); c1 = (uint8_t*) ((uintptr_t) c1 + cn_stride); + a1 = (const uint8_t*) ((uintptr_t) a1 - kc); nc -= 16; } else { diff --git a/src/qu8-gemm/gen/qu8-gemm-2x16c8-minmax-fp32-avx512skx.c b/src/qu8-gemm/gen/qu8-gemm-2x16c8-minmax-fp32-avx512skx.c index 7e5e548ba10..70f6717255d 100644 --- a/src/qu8-gemm/gen/qu8-gemm-2x16c8-minmax-fp32-avx512skx.c +++ b/src/qu8-gemm/gen/qu8-gemm-2x16c8-minmax-fp32-avx512skx.c @@ -52,7 +52,6 @@ void xnn_qu8_gemm_minmax_fp32_ukernel_2x16c8__avx512skx( const __m256i voutput_zero_point = _mm256_load_si256((const __m256i*) params->fp32_avx512.output_zero_point); const __m128i voutput_min = _mm_load_si128((const __m128i*) params->fp32_avx512.output_min); const __m512i vb_zero_point = _mm512_load_si512(params->fp32_avx512.kernel_zero_point); - const __m128i vshuffle_control_mask = _mm_set_epi8(15, 14, 13, 12, 7, 6, 5, 4, 11, 10, 9, 8, 3, 2, 1, 0); do { __m512i vacc0x0123 = _mm512_maskz_expandloadu_epi32(vbias_mask, w); @@ -119,30 +118,28 @@ void xnn_qu8_gemm_minmax_fp32_ukernel_2x16c8__avx512skx( vacc0x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled0x0123456789ABCDEF); vacc1x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled1x0123456789ABCDEF); - __m256i vacc0x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); - __m256i vacc1x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); + __m256i vacc0x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); + __m256i vacc1x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); - vacc0x0123456789AB4567CDEF = _mm256_adds_epi16(vacc0x0123456789AB4567CDEF, voutput_zero_point); - vacc1x0123456789AB4567CDEF = _mm256_adds_epi16(vacc1x0123456789AB4567CDEF, voutput_zero_point); + vacc0x012389AB4567CDEF = _mm256_adds_epi16(vacc0x012389AB4567CDEF, voutput_zero_point); + vacc1x012389AB4567CDEF = _mm256_adds_epi16(vacc1x012389AB4567CDEF, voutput_zero_point); - const __m128i vout0x0123456789AB4567CDEF = _mm_packus_epi16(_mm256_castsi256_si128(vacc0x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc0x0123456789AB4567CDEF, 1)); - const __m128i vout1x0123456789AB4567CDEF = _mm_packus_epi16(_mm256_castsi256_si128(vacc1x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc1x0123456789AB4567CDEF, 1)); + const __m128i vout0x012389AB4567CDEF = _mm_packus_epi16(_mm256_castsi256_si128(vacc0x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x012389AB4567CDEF, 1)); + const __m128i vout1x012389AB4567CDEF = _mm_packus_epi16(_mm256_castsi256_si128(vacc1x012389AB4567CDEF), _mm256_extracti128_si256(vacc1x012389AB4567CDEF, 1)); - __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi8(vout0x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi8(vout1x0123456789AB4567CDEF, vshuffle_control_mask); + __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi32(vout0x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi32(vout1x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); vout0x0123456789ABCDEF = _mm_max_epu8(vout0x0123456789ABCDEF, voutput_min); vout1x0123456789ABCDEF = _mm_max_epu8(vout1x0123456789ABCDEF, voutput_min); if (nc >= 16) { _mm_storeu_si128((__m128i*) c0, vout0x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c1, vout1x0123456789ABCDEF); - - a0 = (const uint8_t*) ((uintptr_t) a0 - kc); - a1 = (const uint8_t*) ((uintptr_t) a1 - kc); - c0 = (uint8_t*) ((uintptr_t) c0 + cn_stride); + a0 = (const uint8_t*) ((uintptr_t) a0 - kc); + _mm_storeu_si128((__m128i*) c1, vout1x0123456789ABCDEF); c1 = (uint8_t*) ((uintptr_t) c1 + cn_stride); + a1 = (const uint8_t*) ((uintptr_t) a1 - kc); nc -= 16; } else { diff --git a/src/qu8-gemm/gen/qu8-gemm-3x16c8-minmax-fp32-avx512skx-prfm.c b/src/qu8-gemm/gen/qu8-gemm-3x16c8-minmax-fp32-avx512skx-prfm.c index 5cf707c20fb..885ee959e7f 100644 --- a/src/qu8-gemm/gen/qu8-gemm-3x16c8-minmax-fp32-avx512skx-prfm.c +++ b/src/qu8-gemm/gen/qu8-gemm-3x16c8-minmax-fp32-avx512skx-prfm.c @@ -59,7 +59,6 @@ void xnn_qu8_gemm_minmax_fp32_ukernel_3x16c8__avx512skx_prfm( const __m256i voutput_zero_point = _mm256_load_si256((const __m256i*) params->fp32_avx512.output_zero_point); const __m128i voutput_min = _mm_load_si128((const __m128i*) params->fp32_avx512.output_min); const __m512i vb_zero_point = _mm512_load_si512(params->fp32_avx512.kernel_zero_point); - const __m128i vshuffle_control_mask = _mm_set_epi8(15, 14, 13, 12, 7, 6, 5, 4, 11, 10, 9, 8, 3, 2, 1, 0); do { __m512i vacc0x0123 = _mm512_maskz_expandloadu_epi32(vbias_mask, w); @@ -146,21 +145,21 @@ void xnn_qu8_gemm_minmax_fp32_ukernel_3x16c8__avx512skx_prfm( vacc1x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled1x0123456789ABCDEF); vacc2x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled2x0123456789ABCDEF); - __m256i vacc0x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); - __m256i vacc1x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); - __m256i vacc2x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc2x0123456789ABCDEF, 1)); + __m256i vacc0x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); + __m256i vacc1x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); + __m256i vacc2x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc2x0123456789ABCDEF, 1)); - vacc0x0123456789AB4567CDEF = _mm256_adds_epi16(vacc0x0123456789AB4567CDEF, voutput_zero_point); - vacc1x0123456789AB4567CDEF = _mm256_adds_epi16(vacc1x0123456789AB4567CDEF, voutput_zero_point); - vacc2x0123456789AB4567CDEF = _mm256_adds_epi16(vacc2x0123456789AB4567CDEF, voutput_zero_point); + vacc0x012389AB4567CDEF = _mm256_adds_epi16(vacc0x012389AB4567CDEF, voutput_zero_point); + vacc1x012389AB4567CDEF = _mm256_adds_epi16(vacc1x012389AB4567CDEF, voutput_zero_point); + vacc2x012389AB4567CDEF = _mm256_adds_epi16(vacc2x012389AB4567CDEF, voutput_zero_point); - const __m128i vout0x0123456789AB4567CDEF = _mm_packus_epi16(_mm256_castsi256_si128(vacc0x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc0x0123456789AB4567CDEF, 1)); - const __m128i vout1x0123456789AB4567CDEF = _mm_packus_epi16(_mm256_castsi256_si128(vacc1x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc1x0123456789AB4567CDEF, 1)); - const __m128i vout2x0123456789AB4567CDEF = _mm_packus_epi16(_mm256_castsi256_si128(vacc2x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc2x0123456789AB4567CDEF, 1)); + const __m128i vout0x012389AB4567CDEF = _mm_packus_epi16(_mm256_castsi256_si128(vacc0x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x012389AB4567CDEF, 1)); + const __m128i vout1x012389AB4567CDEF = _mm_packus_epi16(_mm256_castsi256_si128(vacc1x012389AB4567CDEF), _mm256_extracti128_si256(vacc1x012389AB4567CDEF, 1)); + const __m128i vout2x012389AB4567CDEF = _mm_packus_epi16(_mm256_castsi256_si128(vacc2x012389AB4567CDEF), _mm256_extracti128_si256(vacc2x012389AB4567CDEF, 1)); - __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi8(vout0x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi8(vout1x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi8(vout2x0123456789AB4567CDEF, vshuffle_control_mask); + __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi32(vout0x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi32(vout1x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi32(vout2x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); vout0x0123456789ABCDEF = _mm_max_epu8(vout0x0123456789ABCDEF, voutput_min); vout1x0123456789ABCDEF = _mm_max_epu8(vout1x0123456789ABCDEF, voutput_min); @@ -168,16 +167,14 @@ void xnn_qu8_gemm_minmax_fp32_ukernel_3x16c8__avx512skx_prfm( if (nc >= 16) { _mm_storeu_si128((__m128i*) c0, vout0x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c1, vout1x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c2, vout2x0123456789ABCDEF); - - a0 = (const uint8_t*) ((uintptr_t) a0 - kc); - a1 = (const uint8_t*) ((uintptr_t) a1 - kc); - a2 = (const uint8_t*) ((uintptr_t) a2 - kc); - c0 = (uint8_t*) ((uintptr_t) c0 + cn_stride); + a0 = (const uint8_t*) ((uintptr_t) a0 - kc); + _mm_storeu_si128((__m128i*) c1, vout1x0123456789ABCDEF); c1 = (uint8_t*) ((uintptr_t) c1 + cn_stride); + a1 = (const uint8_t*) ((uintptr_t) a1 - kc); + _mm_storeu_si128((__m128i*) c2, vout2x0123456789ABCDEF); c2 = (uint8_t*) ((uintptr_t) c2 + cn_stride); + a2 = (const uint8_t*) ((uintptr_t) a2 - kc); nc -= 16; } else { diff --git a/src/qu8-gemm/gen/qu8-gemm-3x16c8-minmax-fp32-avx512skx.c b/src/qu8-gemm/gen/qu8-gemm-3x16c8-minmax-fp32-avx512skx.c index 2513a2d9f68..864813ab22d 100644 --- a/src/qu8-gemm/gen/qu8-gemm-3x16c8-minmax-fp32-avx512skx.c +++ b/src/qu8-gemm/gen/qu8-gemm-3x16c8-minmax-fp32-avx512skx.c @@ -58,7 +58,6 @@ void xnn_qu8_gemm_minmax_fp32_ukernel_3x16c8__avx512skx( const __m256i voutput_zero_point = _mm256_load_si256((const __m256i*) params->fp32_avx512.output_zero_point); const __m128i voutput_min = _mm_load_si128((const __m128i*) params->fp32_avx512.output_min); const __m512i vb_zero_point = _mm512_load_si512(params->fp32_avx512.kernel_zero_point); - const __m128i vshuffle_control_mask = _mm_set_epi8(15, 14, 13, 12, 7, 6, 5, 4, 11, 10, 9, 8, 3, 2, 1, 0); do { __m512i vacc0x0123 = _mm512_maskz_expandloadu_epi32(vbias_mask, w); @@ -143,21 +142,21 @@ void xnn_qu8_gemm_minmax_fp32_ukernel_3x16c8__avx512skx( vacc1x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled1x0123456789ABCDEF); vacc2x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled2x0123456789ABCDEF); - __m256i vacc0x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); - __m256i vacc1x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); - __m256i vacc2x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc2x0123456789ABCDEF, 1)); + __m256i vacc0x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); + __m256i vacc1x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); + __m256i vacc2x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc2x0123456789ABCDEF, 1)); - vacc0x0123456789AB4567CDEF = _mm256_adds_epi16(vacc0x0123456789AB4567CDEF, voutput_zero_point); - vacc1x0123456789AB4567CDEF = _mm256_adds_epi16(vacc1x0123456789AB4567CDEF, voutput_zero_point); - vacc2x0123456789AB4567CDEF = _mm256_adds_epi16(vacc2x0123456789AB4567CDEF, voutput_zero_point); + vacc0x012389AB4567CDEF = _mm256_adds_epi16(vacc0x012389AB4567CDEF, voutput_zero_point); + vacc1x012389AB4567CDEF = _mm256_adds_epi16(vacc1x012389AB4567CDEF, voutput_zero_point); + vacc2x012389AB4567CDEF = _mm256_adds_epi16(vacc2x012389AB4567CDEF, voutput_zero_point); - const __m128i vout0x0123456789AB4567CDEF = _mm_packus_epi16(_mm256_castsi256_si128(vacc0x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc0x0123456789AB4567CDEF, 1)); - const __m128i vout1x0123456789AB4567CDEF = _mm_packus_epi16(_mm256_castsi256_si128(vacc1x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc1x0123456789AB4567CDEF, 1)); - const __m128i vout2x0123456789AB4567CDEF = _mm_packus_epi16(_mm256_castsi256_si128(vacc2x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc2x0123456789AB4567CDEF, 1)); + const __m128i vout0x012389AB4567CDEF = _mm_packus_epi16(_mm256_castsi256_si128(vacc0x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x012389AB4567CDEF, 1)); + const __m128i vout1x012389AB4567CDEF = _mm_packus_epi16(_mm256_castsi256_si128(vacc1x012389AB4567CDEF), _mm256_extracti128_si256(vacc1x012389AB4567CDEF, 1)); + const __m128i vout2x012389AB4567CDEF = _mm_packus_epi16(_mm256_castsi256_si128(vacc2x012389AB4567CDEF), _mm256_extracti128_si256(vacc2x012389AB4567CDEF, 1)); - __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi8(vout0x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi8(vout1x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi8(vout2x0123456789AB4567CDEF, vshuffle_control_mask); + __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi32(vout0x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi32(vout1x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi32(vout2x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); vout0x0123456789ABCDEF = _mm_max_epu8(vout0x0123456789ABCDEF, voutput_min); vout1x0123456789ABCDEF = _mm_max_epu8(vout1x0123456789ABCDEF, voutput_min); @@ -165,16 +164,14 @@ void xnn_qu8_gemm_minmax_fp32_ukernel_3x16c8__avx512skx( if (nc >= 16) { _mm_storeu_si128((__m128i*) c0, vout0x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c1, vout1x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c2, vout2x0123456789ABCDEF); - - a0 = (const uint8_t*) ((uintptr_t) a0 - kc); - a1 = (const uint8_t*) ((uintptr_t) a1 - kc); - a2 = (const uint8_t*) ((uintptr_t) a2 - kc); - c0 = (uint8_t*) ((uintptr_t) c0 + cn_stride); + a0 = (const uint8_t*) ((uintptr_t) a0 - kc); + _mm_storeu_si128((__m128i*) c1, vout1x0123456789ABCDEF); c1 = (uint8_t*) ((uintptr_t) c1 + cn_stride); + a1 = (const uint8_t*) ((uintptr_t) a1 - kc); + _mm_storeu_si128((__m128i*) c2, vout2x0123456789ABCDEF); c2 = (uint8_t*) ((uintptr_t) c2 + cn_stride); + a2 = (const uint8_t*) ((uintptr_t) a2 - kc); nc -= 16; } else { diff --git a/src/qu8-gemm/gen/qu8-gemm-4x16c8-minmax-fp32-avx512skx-prfm.c b/src/qu8-gemm/gen/qu8-gemm-4x16c8-minmax-fp32-avx512skx-prfm.c index 5470246585a..c9c5130dba1 100644 --- a/src/qu8-gemm/gen/qu8-gemm-4x16c8-minmax-fp32-avx512skx-prfm.c +++ b/src/qu8-gemm/gen/qu8-gemm-4x16c8-minmax-fp32-avx512skx-prfm.c @@ -65,7 +65,6 @@ void xnn_qu8_gemm_minmax_fp32_ukernel_4x16c8__avx512skx_prfm( const __m256i voutput_zero_point = _mm256_load_si256((const __m256i*) params->fp32_avx512.output_zero_point); const __m128i voutput_min = _mm_load_si128((const __m128i*) params->fp32_avx512.output_min); const __m512i vb_zero_point = _mm512_load_si512(params->fp32_avx512.kernel_zero_point); - const __m128i vshuffle_control_mask = _mm_set_epi8(15, 14, 13, 12, 7, 6, 5, 4, 11, 10, 9, 8, 3, 2, 1, 0); do { __m512i vacc0x0123 = _mm512_maskz_expandloadu_epi32(vbias_mask, w); @@ -170,25 +169,25 @@ void xnn_qu8_gemm_minmax_fp32_ukernel_4x16c8__avx512skx_prfm( vacc2x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled2x0123456789ABCDEF); vacc3x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled3x0123456789ABCDEF); - __m256i vacc0x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); - __m256i vacc1x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); - __m256i vacc2x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc2x0123456789ABCDEF, 1)); - __m256i vacc3x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc3x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc3x0123456789ABCDEF, 1)); + __m256i vacc0x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); + __m256i vacc1x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); + __m256i vacc2x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc2x0123456789ABCDEF, 1)); + __m256i vacc3x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc3x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc3x0123456789ABCDEF, 1)); - vacc0x0123456789AB4567CDEF = _mm256_adds_epi16(vacc0x0123456789AB4567CDEF, voutput_zero_point); - vacc1x0123456789AB4567CDEF = _mm256_adds_epi16(vacc1x0123456789AB4567CDEF, voutput_zero_point); - vacc2x0123456789AB4567CDEF = _mm256_adds_epi16(vacc2x0123456789AB4567CDEF, voutput_zero_point); - vacc3x0123456789AB4567CDEF = _mm256_adds_epi16(vacc3x0123456789AB4567CDEF, voutput_zero_point); + vacc0x012389AB4567CDEF = _mm256_adds_epi16(vacc0x012389AB4567CDEF, voutput_zero_point); + vacc1x012389AB4567CDEF = _mm256_adds_epi16(vacc1x012389AB4567CDEF, voutput_zero_point); + vacc2x012389AB4567CDEF = _mm256_adds_epi16(vacc2x012389AB4567CDEF, voutput_zero_point); + vacc3x012389AB4567CDEF = _mm256_adds_epi16(vacc3x012389AB4567CDEF, voutput_zero_point); - const __m128i vout0x0123456789AB4567CDEF = _mm_packus_epi16(_mm256_castsi256_si128(vacc0x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc0x0123456789AB4567CDEF, 1)); - const __m128i vout1x0123456789AB4567CDEF = _mm_packus_epi16(_mm256_castsi256_si128(vacc1x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc1x0123456789AB4567CDEF, 1)); - const __m128i vout2x0123456789AB4567CDEF = _mm_packus_epi16(_mm256_castsi256_si128(vacc2x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc2x0123456789AB4567CDEF, 1)); - const __m128i vout3x0123456789AB4567CDEF = _mm_packus_epi16(_mm256_castsi256_si128(vacc3x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc3x0123456789AB4567CDEF, 1)); + const __m128i vout0x012389AB4567CDEF = _mm_packus_epi16(_mm256_castsi256_si128(vacc0x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x012389AB4567CDEF, 1)); + const __m128i vout1x012389AB4567CDEF = _mm_packus_epi16(_mm256_castsi256_si128(vacc1x012389AB4567CDEF), _mm256_extracti128_si256(vacc1x012389AB4567CDEF, 1)); + const __m128i vout2x012389AB4567CDEF = _mm_packus_epi16(_mm256_castsi256_si128(vacc2x012389AB4567CDEF), _mm256_extracti128_si256(vacc2x012389AB4567CDEF, 1)); + const __m128i vout3x012389AB4567CDEF = _mm_packus_epi16(_mm256_castsi256_si128(vacc3x012389AB4567CDEF), _mm256_extracti128_si256(vacc3x012389AB4567CDEF, 1)); - __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi8(vout0x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi8(vout1x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi8(vout2x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout3x0123456789ABCDEF = _mm_shuffle_epi8(vout3x0123456789AB4567CDEF, vshuffle_control_mask); + __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi32(vout0x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi32(vout1x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi32(vout2x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout3x0123456789ABCDEF = _mm_shuffle_epi32(vout3x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); vout0x0123456789ABCDEF = _mm_max_epu8(vout0x0123456789ABCDEF, voutput_min); vout1x0123456789ABCDEF = _mm_max_epu8(vout1x0123456789ABCDEF, voutput_min); @@ -197,19 +196,17 @@ void xnn_qu8_gemm_minmax_fp32_ukernel_4x16c8__avx512skx_prfm( if (nc >= 16) { _mm_storeu_si128((__m128i*) c0, vout0x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c1, vout1x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c2, vout2x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c3, vout3x0123456789ABCDEF); - - a0 = (const uint8_t*) ((uintptr_t) a0 - kc); - a1 = (const uint8_t*) ((uintptr_t) a1 - kc); - a2 = (const uint8_t*) ((uintptr_t) a2 - kc); - a3 = (const uint8_t*) ((uintptr_t) a3 - kc); - c0 = (uint8_t*) ((uintptr_t) c0 + cn_stride); + a0 = (const uint8_t*) ((uintptr_t) a0 - kc); + _mm_storeu_si128((__m128i*) c1, vout1x0123456789ABCDEF); c1 = (uint8_t*) ((uintptr_t) c1 + cn_stride); + a1 = (const uint8_t*) ((uintptr_t) a1 - kc); + _mm_storeu_si128((__m128i*) c2, vout2x0123456789ABCDEF); c2 = (uint8_t*) ((uintptr_t) c2 + cn_stride); + a2 = (const uint8_t*) ((uintptr_t) a2 - kc); + _mm_storeu_si128((__m128i*) c3, vout3x0123456789ABCDEF); c3 = (uint8_t*) ((uintptr_t) c3 + cn_stride); + a3 = (const uint8_t*) ((uintptr_t) a3 - kc); nc -= 16; } else { diff --git a/src/qu8-gemm/gen/qu8-gemm-4x16c8-minmax-fp32-avx512skx.c b/src/qu8-gemm/gen/qu8-gemm-4x16c8-minmax-fp32-avx512skx.c index a9ebd5fa83d..2962132ba82 100644 --- a/src/qu8-gemm/gen/qu8-gemm-4x16c8-minmax-fp32-avx512skx.c +++ b/src/qu8-gemm/gen/qu8-gemm-4x16c8-minmax-fp32-avx512skx.c @@ -64,7 +64,6 @@ void xnn_qu8_gemm_minmax_fp32_ukernel_4x16c8__avx512skx( const __m256i voutput_zero_point = _mm256_load_si256((const __m256i*) params->fp32_avx512.output_zero_point); const __m128i voutput_min = _mm_load_si128((const __m128i*) params->fp32_avx512.output_min); const __m512i vb_zero_point = _mm512_load_si512(params->fp32_avx512.kernel_zero_point); - const __m128i vshuffle_control_mask = _mm_set_epi8(15, 14, 13, 12, 7, 6, 5, 4, 11, 10, 9, 8, 3, 2, 1, 0); do { __m512i vacc0x0123 = _mm512_maskz_expandloadu_epi32(vbias_mask, w); @@ -167,25 +166,25 @@ void xnn_qu8_gemm_minmax_fp32_ukernel_4x16c8__avx512skx( vacc2x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled2x0123456789ABCDEF); vacc3x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled3x0123456789ABCDEF); - __m256i vacc0x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); - __m256i vacc1x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); - __m256i vacc2x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc2x0123456789ABCDEF, 1)); - __m256i vacc3x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc3x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc3x0123456789ABCDEF, 1)); + __m256i vacc0x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); + __m256i vacc1x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); + __m256i vacc2x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc2x0123456789ABCDEF, 1)); + __m256i vacc3x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc3x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc3x0123456789ABCDEF, 1)); - vacc0x0123456789AB4567CDEF = _mm256_adds_epi16(vacc0x0123456789AB4567CDEF, voutput_zero_point); - vacc1x0123456789AB4567CDEF = _mm256_adds_epi16(vacc1x0123456789AB4567CDEF, voutput_zero_point); - vacc2x0123456789AB4567CDEF = _mm256_adds_epi16(vacc2x0123456789AB4567CDEF, voutput_zero_point); - vacc3x0123456789AB4567CDEF = _mm256_adds_epi16(vacc3x0123456789AB4567CDEF, voutput_zero_point); + vacc0x012389AB4567CDEF = _mm256_adds_epi16(vacc0x012389AB4567CDEF, voutput_zero_point); + vacc1x012389AB4567CDEF = _mm256_adds_epi16(vacc1x012389AB4567CDEF, voutput_zero_point); + vacc2x012389AB4567CDEF = _mm256_adds_epi16(vacc2x012389AB4567CDEF, voutput_zero_point); + vacc3x012389AB4567CDEF = _mm256_adds_epi16(vacc3x012389AB4567CDEF, voutput_zero_point); - const __m128i vout0x0123456789AB4567CDEF = _mm_packus_epi16(_mm256_castsi256_si128(vacc0x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc0x0123456789AB4567CDEF, 1)); - const __m128i vout1x0123456789AB4567CDEF = _mm_packus_epi16(_mm256_castsi256_si128(vacc1x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc1x0123456789AB4567CDEF, 1)); - const __m128i vout2x0123456789AB4567CDEF = _mm_packus_epi16(_mm256_castsi256_si128(vacc2x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc2x0123456789AB4567CDEF, 1)); - const __m128i vout3x0123456789AB4567CDEF = _mm_packus_epi16(_mm256_castsi256_si128(vacc3x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc3x0123456789AB4567CDEF, 1)); + const __m128i vout0x012389AB4567CDEF = _mm_packus_epi16(_mm256_castsi256_si128(vacc0x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x012389AB4567CDEF, 1)); + const __m128i vout1x012389AB4567CDEF = _mm_packus_epi16(_mm256_castsi256_si128(vacc1x012389AB4567CDEF), _mm256_extracti128_si256(vacc1x012389AB4567CDEF, 1)); + const __m128i vout2x012389AB4567CDEF = _mm_packus_epi16(_mm256_castsi256_si128(vacc2x012389AB4567CDEF), _mm256_extracti128_si256(vacc2x012389AB4567CDEF, 1)); + const __m128i vout3x012389AB4567CDEF = _mm_packus_epi16(_mm256_castsi256_si128(vacc3x012389AB4567CDEF), _mm256_extracti128_si256(vacc3x012389AB4567CDEF, 1)); - __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi8(vout0x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi8(vout1x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi8(vout2x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout3x0123456789ABCDEF = _mm_shuffle_epi8(vout3x0123456789AB4567CDEF, vshuffle_control_mask); + __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi32(vout0x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi32(vout1x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi32(vout2x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout3x0123456789ABCDEF = _mm_shuffle_epi32(vout3x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); vout0x0123456789ABCDEF = _mm_max_epu8(vout0x0123456789ABCDEF, voutput_min); vout1x0123456789ABCDEF = _mm_max_epu8(vout1x0123456789ABCDEF, voutput_min); @@ -194,19 +193,17 @@ void xnn_qu8_gemm_minmax_fp32_ukernel_4x16c8__avx512skx( if (nc >= 16) { _mm_storeu_si128((__m128i*) c0, vout0x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c1, vout1x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c2, vout2x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c3, vout3x0123456789ABCDEF); - - a0 = (const uint8_t*) ((uintptr_t) a0 - kc); - a1 = (const uint8_t*) ((uintptr_t) a1 - kc); - a2 = (const uint8_t*) ((uintptr_t) a2 - kc); - a3 = (const uint8_t*) ((uintptr_t) a3 - kc); - c0 = (uint8_t*) ((uintptr_t) c0 + cn_stride); + a0 = (const uint8_t*) ((uintptr_t) a0 - kc); + _mm_storeu_si128((__m128i*) c1, vout1x0123456789ABCDEF); c1 = (uint8_t*) ((uintptr_t) c1 + cn_stride); + a1 = (const uint8_t*) ((uintptr_t) a1 - kc); + _mm_storeu_si128((__m128i*) c2, vout2x0123456789ABCDEF); c2 = (uint8_t*) ((uintptr_t) c2 + cn_stride); + a2 = (const uint8_t*) ((uintptr_t) a2 - kc); + _mm_storeu_si128((__m128i*) c3, vout3x0123456789ABCDEF); c3 = (uint8_t*) ((uintptr_t) c3 + cn_stride); + a3 = (const uint8_t*) ((uintptr_t) a3 - kc); nc -= 16; } else { diff --git a/src/qu8-gemm/gen/qu8-gemm-5x16c8-minmax-fp32-avx512skx-prfm.c b/src/qu8-gemm/gen/qu8-gemm-5x16c8-minmax-fp32-avx512skx-prfm.c index fb34dcfbbf8..5abb1f7592c 100644 --- a/src/qu8-gemm/gen/qu8-gemm-5x16c8-minmax-fp32-avx512skx-prfm.c +++ b/src/qu8-gemm/gen/qu8-gemm-5x16c8-minmax-fp32-avx512skx-prfm.c @@ -71,7 +71,6 @@ void xnn_qu8_gemm_minmax_fp32_ukernel_5x16c8__avx512skx_prfm( const __m256i voutput_zero_point = _mm256_load_si256((const __m256i*) params->fp32_avx512.output_zero_point); const __m128i voutput_min = _mm_load_si128((const __m128i*) params->fp32_avx512.output_min); const __m512i vb_zero_point = _mm512_load_si512(params->fp32_avx512.kernel_zero_point); - const __m128i vshuffle_control_mask = _mm_set_epi8(15, 14, 13, 12, 7, 6, 5, 4, 11, 10, 9, 8, 3, 2, 1, 0); do { __m512i vacc0x0123 = _mm512_maskz_expandloadu_epi32(vbias_mask, w); @@ -194,29 +193,29 @@ void xnn_qu8_gemm_minmax_fp32_ukernel_5x16c8__avx512skx_prfm( vacc3x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled3x0123456789ABCDEF); vacc4x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled4x0123456789ABCDEF); - __m256i vacc0x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); - __m256i vacc1x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); - __m256i vacc2x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc2x0123456789ABCDEF, 1)); - __m256i vacc3x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc3x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc3x0123456789ABCDEF, 1)); - __m256i vacc4x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc4x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc4x0123456789ABCDEF, 1)); - - vacc0x0123456789AB4567CDEF = _mm256_adds_epi16(vacc0x0123456789AB4567CDEF, voutput_zero_point); - vacc1x0123456789AB4567CDEF = _mm256_adds_epi16(vacc1x0123456789AB4567CDEF, voutput_zero_point); - vacc2x0123456789AB4567CDEF = _mm256_adds_epi16(vacc2x0123456789AB4567CDEF, voutput_zero_point); - vacc3x0123456789AB4567CDEF = _mm256_adds_epi16(vacc3x0123456789AB4567CDEF, voutput_zero_point); - vacc4x0123456789AB4567CDEF = _mm256_adds_epi16(vacc4x0123456789AB4567CDEF, voutput_zero_point); - - const __m128i vout0x0123456789AB4567CDEF = _mm_packus_epi16(_mm256_castsi256_si128(vacc0x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc0x0123456789AB4567CDEF, 1)); - const __m128i vout1x0123456789AB4567CDEF = _mm_packus_epi16(_mm256_castsi256_si128(vacc1x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc1x0123456789AB4567CDEF, 1)); - const __m128i vout2x0123456789AB4567CDEF = _mm_packus_epi16(_mm256_castsi256_si128(vacc2x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc2x0123456789AB4567CDEF, 1)); - const __m128i vout3x0123456789AB4567CDEF = _mm_packus_epi16(_mm256_castsi256_si128(vacc3x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc3x0123456789AB4567CDEF, 1)); - const __m128i vout4x0123456789AB4567CDEF = _mm_packus_epi16(_mm256_castsi256_si128(vacc4x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc4x0123456789AB4567CDEF, 1)); - - __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi8(vout0x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi8(vout1x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi8(vout2x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout3x0123456789ABCDEF = _mm_shuffle_epi8(vout3x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout4x0123456789ABCDEF = _mm_shuffle_epi8(vout4x0123456789AB4567CDEF, vshuffle_control_mask); + __m256i vacc0x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); + __m256i vacc1x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); + __m256i vacc2x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc2x0123456789ABCDEF, 1)); + __m256i vacc3x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc3x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc3x0123456789ABCDEF, 1)); + __m256i vacc4x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc4x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc4x0123456789ABCDEF, 1)); + + vacc0x012389AB4567CDEF = _mm256_adds_epi16(vacc0x012389AB4567CDEF, voutput_zero_point); + vacc1x012389AB4567CDEF = _mm256_adds_epi16(vacc1x012389AB4567CDEF, voutput_zero_point); + vacc2x012389AB4567CDEF = _mm256_adds_epi16(vacc2x012389AB4567CDEF, voutput_zero_point); + vacc3x012389AB4567CDEF = _mm256_adds_epi16(vacc3x012389AB4567CDEF, voutput_zero_point); + vacc4x012389AB4567CDEF = _mm256_adds_epi16(vacc4x012389AB4567CDEF, voutput_zero_point); + + const __m128i vout0x012389AB4567CDEF = _mm_packus_epi16(_mm256_castsi256_si128(vacc0x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x012389AB4567CDEF, 1)); + const __m128i vout1x012389AB4567CDEF = _mm_packus_epi16(_mm256_castsi256_si128(vacc1x012389AB4567CDEF), _mm256_extracti128_si256(vacc1x012389AB4567CDEF, 1)); + const __m128i vout2x012389AB4567CDEF = _mm_packus_epi16(_mm256_castsi256_si128(vacc2x012389AB4567CDEF), _mm256_extracti128_si256(vacc2x012389AB4567CDEF, 1)); + const __m128i vout3x012389AB4567CDEF = _mm_packus_epi16(_mm256_castsi256_si128(vacc3x012389AB4567CDEF), _mm256_extracti128_si256(vacc3x012389AB4567CDEF, 1)); + const __m128i vout4x012389AB4567CDEF = _mm_packus_epi16(_mm256_castsi256_si128(vacc4x012389AB4567CDEF), _mm256_extracti128_si256(vacc4x012389AB4567CDEF, 1)); + + __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi32(vout0x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi32(vout1x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi32(vout2x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout3x0123456789ABCDEF = _mm_shuffle_epi32(vout3x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout4x0123456789ABCDEF = _mm_shuffle_epi32(vout4x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); vout0x0123456789ABCDEF = _mm_max_epu8(vout0x0123456789ABCDEF, voutput_min); vout1x0123456789ABCDEF = _mm_max_epu8(vout1x0123456789ABCDEF, voutput_min); @@ -226,22 +225,20 @@ void xnn_qu8_gemm_minmax_fp32_ukernel_5x16c8__avx512skx_prfm( if (nc >= 16) { _mm_storeu_si128((__m128i*) c0, vout0x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c1, vout1x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c2, vout2x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c3, vout3x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c4, vout4x0123456789ABCDEF); - - a0 = (const uint8_t*) ((uintptr_t) a0 - kc); - a1 = (const uint8_t*) ((uintptr_t) a1 - kc); - a2 = (const uint8_t*) ((uintptr_t) a2 - kc); - a3 = (const uint8_t*) ((uintptr_t) a3 - kc); - a4 = (const uint8_t*) ((uintptr_t) a4 - kc); - c0 = (uint8_t*) ((uintptr_t) c0 + cn_stride); + a0 = (const uint8_t*) ((uintptr_t) a0 - kc); + _mm_storeu_si128((__m128i*) c1, vout1x0123456789ABCDEF); c1 = (uint8_t*) ((uintptr_t) c1 + cn_stride); + a1 = (const uint8_t*) ((uintptr_t) a1 - kc); + _mm_storeu_si128((__m128i*) c2, vout2x0123456789ABCDEF); c2 = (uint8_t*) ((uintptr_t) c2 + cn_stride); + a2 = (const uint8_t*) ((uintptr_t) a2 - kc); + _mm_storeu_si128((__m128i*) c3, vout3x0123456789ABCDEF); c3 = (uint8_t*) ((uintptr_t) c3 + cn_stride); + a3 = (const uint8_t*) ((uintptr_t) a3 - kc); + _mm_storeu_si128((__m128i*) c4, vout4x0123456789ABCDEF); c4 = (uint8_t*) ((uintptr_t) c4 + cn_stride); + a4 = (const uint8_t*) ((uintptr_t) a4 - kc); nc -= 16; } else { diff --git a/src/qu8-gemm/gen/qu8-gemm-5x16c8-minmax-fp32-avx512skx.c b/src/qu8-gemm/gen/qu8-gemm-5x16c8-minmax-fp32-avx512skx.c index 31b975bd3d4..0955639d2b3 100644 --- a/src/qu8-gemm/gen/qu8-gemm-5x16c8-minmax-fp32-avx512skx.c +++ b/src/qu8-gemm/gen/qu8-gemm-5x16c8-minmax-fp32-avx512skx.c @@ -70,7 +70,6 @@ void xnn_qu8_gemm_minmax_fp32_ukernel_5x16c8__avx512skx( const __m256i voutput_zero_point = _mm256_load_si256((const __m256i*) params->fp32_avx512.output_zero_point); const __m128i voutput_min = _mm_load_si128((const __m128i*) params->fp32_avx512.output_min); const __m512i vb_zero_point = _mm512_load_si512(params->fp32_avx512.kernel_zero_point); - const __m128i vshuffle_control_mask = _mm_set_epi8(15, 14, 13, 12, 7, 6, 5, 4, 11, 10, 9, 8, 3, 2, 1, 0); do { __m512i vacc0x0123 = _mm512_maskz_expandloadu_epi32(vbias_mask, w); @@ -191,29 +190,29 @@ void xnn_qu8_gemm_minmax_fp32_ukernel_5x16c8__avx512skx( vacc3x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled3x0123456789ABCDEF); vacc4x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled4x0123456789ABCDEF); - __m256i vacc0x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); - __m256i vacc1x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); - __m256i vacc2x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc2x0123456789ABCDEF, 1)); - __m256i vacc3x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc3x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc3x0123456789ABCDEF, 1)); - __m256i vacc4x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc4x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc4x0123456789ABCDEF, 1)); - - vacc0x0123456789AB4567CDEF = _mm256_adds_epi16(vacc0x0123456789AB4567CDEF, voutput_zero_point); - vacc1x0123456789AB4567CDEF = _mm256_adds_epi16(vacc1x0123456789AB4567CDEF, voutput_zero_point); - vacc2x0123456789AB4567CDEF = _mm256_adds_epi16(vacc2x0123456789AB4567CDEF, voutput_zero_point); - vacc3x0123456789AB4567CDEF = _mm256_adds_epi16(vacc3x0123456789AB4567CDEF, voutput_zero_point); - vacc4x0123456789AB4567CDEF = _mm256_adds_epi16(vacc4x0123456789AB4567CDEF, voutput_zero_point); - - const __m128i vout0x0123456789AB4567CDEF = _mm_packus_epi16(_mm256_castsi256_si128(vacc0x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc0x0123456789AB4567CDEF, 1)); - const __m128i vout1x0123456789AB4567CDEF = _mm_packus_epi16(_mm256_castsi256_si128(vacc1x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc1x0123456789AB4567CDEF, 1)); - const __m128i vout2x0123456789AB4567CDEF = _mm_packus_epi16(_mm256_castsi256_si128(vacc2x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc2x0123456789AB4567CDEF, 1)); - const __m128i vout3x0123456789AB4567CDEF = _mm_packus_epi16(_mm256_castsi256_si128(vacc3x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc3x0123456789AB4567CDEF, 1)); - const __m128i vout4x0123456789AB4567CDEF = _mm_packus_epi16(_mm256_castsi256_si128(vacc4x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc4x0123456789AB4567CDEF, 1)); - - __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi8(vout0x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi8(vout1x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi8(vout2x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout3x0123456789ABCDEF = _mm_shuffle_epi8(vout3x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout4x0123456789ABCDEF = _mm_shuffle_epi8(vout4x0123456789AB4567CDEF, vshuffle_control_mask); + __m256i vacc0x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); + __m256i vacc1x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); + __m256i vacc2x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc2x0123456789ABCDEF, 1)); + __m256i vacc3x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc3x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc3x0123456789ABCDEF, 1)); + __m256i vacc4x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc4x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc4x0123456789ABCDEF, 1)); + + vacc0x012389AB4567CDEF = _mm256_adds_epi16(vacc0x012389AB4567CDEF, voutput_zero_point); + vacc1x012389AB4567CDEF = _mm256_adds_epi16(vacc1x012389AB4567CDEF, voutput_zero_point); + vacc2x012389AB4567CDEF = _mm256_adds_epi16(vacc2x012389AB4567CDEF, voutput_zero_point); + vacc3x012389AB4567CDEF = _mm256_adds_epi16(vacc3x012389AB4567CDEF, voutput_zero_point); + vacc4x012389AB4567CDEF = _mm256_adds_epi16(vacc4x012389AB4567CDEF, voutput_zero_point); + + const __m128i vout0x012389AB4567CDEF = _mm_packus_epi16(_mm256_castsi256_si128(vacc0x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x012389AB4567CDEF, 1)); + const __m128i vout1x012389AB4567CDEF = _mm_packus_epi16(_mm256_castsi256_si128(vacc1x012389AB4567CDEF), _mm256_extracti128_si256(vacc1x012389AB4567CDEF, 1)); + const __m128i vout2x012389AB4567CDEF = _mm_packus_epi16(_mm256_castsi256_si128(vacc2x012389AB4567CDEF), _mm256_extracti128_si256(vacc2x012389AB4567CDEF, 1)); + const __m128i vout3x012389AB4567CDEF = _mm_packus_epi16(_mm256_castsi256_si128(vacc3x012389AB4567CDEF), _mm256_extracti128_si256(vacc3x012389AB4567CDEF, 1)); + const __m128i vout4x012389AB4567CDEF = _mm_packus_epi16(_mm256_castsi256_si128(vacc4x012389AB4567CDEF), _mm256_extracti128_si256(vacc4x012389AB4567CDEF, 1)); + + __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi32(vout0x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi32(vout1x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi32(vout2x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout3x0123456789ABCDEF = _mm_shuffle_epi32(vout3x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout4x0123456789ABCDEF = _mm_shuffle_epi32(vout4x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); vout0x0123456789ABCDEF = _mm_max_epu8(vout0x0123456789ABCDEF, voutput_min); vout1x0123456789ABCDEF = _mm_max_epu8(vout1x0123456789ABCDEF, voutput_min); @@ -223,22 +222,20 @@ void xnn_qu8_gemm_minmax_fp32_ukernel_5x16c8__avx512skx( if (nc >= 16) { _mm_storeu_si128((__m128i*) c0, vout0x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c1, vout1x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c2, vout2x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c3, vout3x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c4, vout4x0123456789ABCDEF); - - a0 = (const uint8_t*) ((uintptr_t) a0 - kc); - a1 = (const uint8_t*) ((uintptr_t) a1 - kc); - a2 = (const uint8_t*) ((uintptr_t) a2 - kc); - a3 = (const uint8_t*) ((uintptr_t) a3 - kc); - a4 = (const uint8_t*) ((uintptr_t) a4 - kc); - c0 = (uint8_t*) ((uintptr_t) c0 + cn_stride); + a0 = (const uint8_t*) ((uintptr_t) a0 - kc); + _mm_storeu_si128((__m128i*) c1, vout1x0123456789ABCDEF); c1 = (uint8_t*) ((uintptr_t) c1 + cn_stride); + a1 = (const uint8_t*) ((uintptr_t) a1 - kc); + _mm_storeu_si128((__m128i*) c2, vout2x0123456789ABCDEF); c2 = (uint8_t*) ((uintptr_t) c2 + cn_stride); + a2 = (const uint8_t*) ((uintptr_t) a2 - kc); + _mm_storeu_si128((__m128i*) c3, vout3x0123456789ABCDEF); c3 = (uint8_t*) ((uintptr_t) c3 + cn_stride); + a3 = (const uint8_t*) ((uintptr_t) a3 - kc); + _mm_storeu_si128((__m128i*) c4, vout4x0123456789ABCDEF); c4 = (uint8_t*) ((uintptr_t) c4 + cn_stride); + a4 = (const uint8_t*) ((uintptr_t) a4 - kc); nc -= 16; } else { diff --git a/src/qu8-gemm/gen/qu8-gemm-6x16c8-minmax-fp32-avx512skx-prfm.c b/src/qu8-gemm/gen/qu8-gemm-6x16c8-minmax-fp32-avx512skx-prfm.c index 1a8e82700b3..c832f516c00 100644 --- a/src/qu8-gemm/gen/qu8-gemm-6x16c8-minmax-fp32-avx512skx-prfm.c +++ b/src/qu8-gemm/gen/qu8-gemm-6x16c8-minmax-fp32-avx512skx-prfm.c @@ -77,7 +77,6 @@ void xnn_qu8_gemm_minmax_fp32_ukernel_6x16c8__avx512skx_prfm( const __m256i voutput_zero_point = _mm256_load_si256((const __m256i*) params->fp32_avx512.output_zero_point); const __m128i voutput_min = _mm_load_si128((const __m128i*) params->fp32_avx512.output_min); const __m512i vb_zero_point = _mm512_load_si512(params->fp32_avx512.kernel_zero_point); - const __m128i vshuffle_control_mask = _mm_set_epi8(15, 14, 13, 12, 7, 6, 5, 4, 11, 10, 9, 8, 3, 2, 1, 0); do { __m512i vacc0x0123 = _mm512_maskz_expandloadu_epi32(vbias_mask, w); @@ -218,33 +217,33 @@ void xnn_qu8_gemm_minmax_fp32_ukernel_6x16c8__avx512skx_prfm( vacc4x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled4x0123456789ABCDEF); vacc5x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled5x0123456789ABCDEF); - __m256i vacc0x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); - __m256i vacc1x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); - __m256i vacc2x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc2x0123456789ABCDEF, 1)); - __m256i vacc3x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc3x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc3x0123456789ABCDEF, 1)); - __m256i vacc4x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc4x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc4x0123456789ABCDEF, 1)); - __m256i vacc5x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc5x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc5x0123456789ABCDEF, 1)); - - vacc0x0123456789AB4567CDEF = _mm256_adds_epi16(vacc0x0123456789AB4567CDEF, voutput_zero_point); - vacc1x0123456789AB4567CDEF = _mm256_adds_epi16(vacc1x0123456789AB4567CDEF, voutput_zero_point); - vacc2x0123456789AB4567CDEF = _mm256_adds_epi16(vacc2x0123456789AB4567CDEF, voutput_zero_point); - vacc3x0123456789AB4567CDEF = _mm256_adds_epi16(vacc3x0123456789AB4567CDEF, voutput_zero_point); - vacc4x0123456789AB4567CDEF = _mm256_adds_epi16(vacc4x0123456789AB4567CDEF, voutput_zero_point); - vacc5x0123456789AB4567CDEF = _mm256_adds_epi16(vacc5x0123456789AB4567CDEF, voutput_zero_point); - - const __m128i vout0x0123456789AB4567CDEF = _mm_packus_epi16(_mm256_castsi256_si128(vacc0x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc0x0123456789AB4567CDEF, 1)); - const __m128i vout1x0123456789AB4567CDEF = _mm_packus_epi16(_mm256_castsi256_si128(vacc1x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc1x0123456789AB4567CDEF, 1)); - const __m128i vout2x0123456789AB4567CDEF = _mm_packus_epi16(_mm256_castsi256_si128(vacc2x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc2x0123456789AB4567CDEF, 1)); - const __m128i vout3x0123456789AB4567CDEF = _mm_packus_epi16(_mm256_castsi256_si128(vacc3x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc3x0123456789AB4567CDEF, 1)); - const __m128i vout4x0123456789AB4567CDEF = _mm_packus_epi16(_mm256_castsi256_si128(vacc4x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc4x0123456789AB4567CDEF, 1)); - const __m128i vout5x0123456789AB4567CDEF = _mm_packus_epi16(_mm256_castsi256_si128(vacc5x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc5x0123456789AB4567CDEF, 1)); - - __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi8(vout0x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi8(vout1x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi8(vout2x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout3x0123456789ABCDEF = _mm_shuffle_epi8(vout3x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout4x0123456789ABCDEF = _mm_shuffle_epi8(vout4x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout5x0123456789ABCDEF = _mm_shuffle_epi8(vout5x0123456789AB4567CDEF, vshuffle_control_mask); + __m256i vacc0x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); + __m256i vacc1x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); + __m256i vacc2x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc2x0123456789ABCDEF, 1)); + __m256i vacc3x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc3x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc3x0123456789ABCDEF, 1)); + __m256i vacc4x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc4x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc4x0123456789ABCDEF, 1)); + __m256i vacc5x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc5x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc5x0123456789ABCDEF, 1)); + + vacc0x012389AB4567CDEF = _mm256_adds_epi16(vacc0x012389AB4567CDEF, voutput_zero_point); + vacc1x012389AB4567CDEF = _mm256_adds_epi16(vacc1x012389AB4567CDEF, voutput_zero_point); + vacc2x012389AB4567CDEF = _mm256_adds_epi16(vacc2x012389AB4567CDEF, voutput_zero_point); + vacc3x012389AB4567CDEF = _mm256_adds_epi16(vacc3x012389AB4567CDEF, voutput_zero_point); + vacc4x012389AB4567CDEF = _mm256_adds_epi16(vacc4x012389AB4567CDEF, voutput_zero_point); + vacc5x012389AB4567CDEF = _mm256_adds_epi16(vacc5x012389AB4567CDEF, voutput_zero_point); + + const __m128i vout0x012389AB4567CDEF = _mm_packus_epi16(_mm256_castsi256_si128(vacc0x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x012389AB4567CDEF, 1)); + const __m128i vout1x012389AB4567CDEF = _mm_packus_epi16(_mm256_castsi256_si128(vacc1x012389AB4567CDEF), _mm256_extracti128_si256(vacc1x012389AB4567CDEF, 1)); + const __m128i vout2x012389AB4567CDEF = _mm_packus_epi16(_mm256_castsi256_si128(vacc2x012389AB4567CDEF), _mm256_extracti128_si256(vacc2x012389AB4567CDEF, 1)); + const __m128i vout3x012389AB4567CDEF = _mm_packus_epi16(_mm256_castsi256_si128(vacc3x012389AB4567CDEF), _mm256_extracti128_si256(vacc3x012389AB4567CDEF, 1)); + const __m128i vout4x012389AB4567CDEF = _mm_packus_epi16(_mm256_castsi256_si128(vacc4x012389AB4567CDEF), _mm256_extracti128_si256(vacc4x012389AB4567CDEF, 1)); + const __m128i vout5x012389AB4567CDEF = _mm_packus_epi16(_mm256_castsi256_si128(vacc5x012389AB4567CDEF), _mm256_extracti128_si256(vacc5x012389AB4567CDEF, 1)); + + __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi32(vout0x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi32(vout1x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi32(vout2x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout3x0123456789ABCDEF = _mm_shuffle_epi32(vout3x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout4x0123456789ABCDEF = _mm_shuffle_epi32(vout4x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout5x0123456789ABCDEF = _mm_shuffle_epi32(vout5x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); vout0x0123456789ABCDEF = _mm_max_epu8(vout0x0123456789ABCDEF, voutput_min); vout1x0123456789ABCDEF = _mm_max_epu8(vout1x0123456789ABCDEF, voutput_min); @@ -255,25 +254,23 @@ void xnn_qu8_gemm_minmax_fp32_ukernel_6x16c8__avx512skx_prfm( if (nc >= 16) { _mm_storeu_si128((__m128i*) c0, vout0x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c1, vout1x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c2, vout2x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c3, vout3x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c4, vout4x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c5, vout5x0123456789ABCDEF); - - a0 = (const uint8_t*) ((uintptr_t) a0 - kc); - a1 = (const uint8_t*) ((uintptr_t) a1 - kc); - a2 = (const uint8_t*) ((uintptr_t) a2 - kc); - a3 = (const uint8_t*) ((uintptr_t) a3 - kc); - a4 = (const uint8_t*) ((uintptr_t) a4 - kc); - a5 = (const uint8_t*) ((uintptr_t) a5 - kc); - c0 = (uint8_t*) ((uintptr_t) c0 + cn_stride); + a0 = (const uint8_t*) ((uintptr_t) a0 - kc); + _mm_storeu_si128((__m128i*) c1, vout1x0123456789ABCDEF); c1 = (uint8_t*) ((uintptr_t) c1 + cn_stride); + a1 = (const uint8_t*) ((uintptr_t) a1 - kc); + _mm_storeu_si128((__m128i*) c2, vout2x0123456789ABCDEF); c2 = (uint8_t*) ((uintptr_t) c2 + cn_stride); + a2 = (const uint8_t*) ((uintptr_t) a2 - kc); + _mm_storeu_si128((__m128i*) c3, vout3x0123456789ABCDEF); c3 = (uint8_t*) ((uintptr_t) c3 + cn_stride); + a3 = (const uint8_t*) ((uintptr_t) a3 - kc); + _mm_storeu_si128((__m128i*) c4, vout4x0123456789ABCDEF); c4 = (uint8_t*) ((uintptr_t) c4 + cn_stride); + a4 = (const uint8_t*) ((uintptr_t) a4 - kc); + _mm_storeu_si128((__m128i*) c5, vout5x0123456789ABCDEF); c5 = (uint8_t*) ((uintptr_t) c5 + cn_stride); + a5 = (const uint8_t*) ((uintptr_t) a5 - kc); nc -= 16; } else { diff --git a/src/qu8-gemm/gen/qu8-gemm-6x16c8-minmax-fp32-avx512skx.c b/src/qu8-gemm/gen/qu8-gemm-6x16c8-minmax-fp32-avx512skx.c index f2bae766b75..a50338a20cf 100644 --- a/src/qu8-gemm/gen/qu8-gemm-6x16c8-minmax-fp32-avx512skx.c +++ b/src/qu8-gemm/gen/qu8-gemm-6x16c8-minmax-fp32-avx512skx.c @@ -76,7 +76,6 @@ void xnn_qu8_gemm_minmax_fp32_ukernel_6x16c8__avx512skx( const __m256i voutput_zero_point = _mm256_load_si256((const __m256i*) params->fp32_avx512.output_zero_point); const __m128i voutput_min = _mm_load_si128((const __m128i*) params->fp32_avx512.output_min); const __m512i vb_zero_point = _mm512_load_si512(params->fp32_avx512.kernel_zero_point); - const __m128i vshuffle_control_mask = _mm_set_epi8(15, 14, 13, 12, 7, 6, 5, 4, 11, 10, 9, 8, 3, 2, 1, 0); do { __m512i vacc0x0123 = _mm512_maskz_expandloadu_epi32(vbias_mask, w); @@ -215,33 +214,33 @@ void xnn_qu8_gemm_minmax_fp32_ukernel_6x16c8__avx512skx( vacc4x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled4x0123456789ABCDEF); vacc5x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled5x0123456789ABCDEF); - __m256i vacc0x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); - __m256i vacc1x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); - __m256i vacc2x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc2x0123456789ABCDEF, 1)); - __m256i vacc3x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc3x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc3x0123456789ABCDEF, 1)); - __m256i vacc4x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc4x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc4x0123456789ABCDEF, 1)); - __m256i vacc5x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc5x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc5x0123456789ABCDEF, 1)); - - vacc0x0123456789AB4567CDEF = _mm256_adds_epi16(vacc0x0123456789AB4567CDEF, voutput_zero_point); - vacc1x0123456789AB4567CDEF = _mm256_adds_epi16(vacc1x0123456789AB4567CDEF, voutput_zero_point); - vacc2x0123456789AB4567CDEF = _mm256_adds_epi16(vacc2x0123456789AB4567CDEF, voutput_zero_point); - vacc3x0123456789AB4567CDEF = _mm256_adds_epi16(vacc3x0123456789AB4567CDEF, voutput_zero_point); - vacc4x0123456789AB4567CDEF = _mm256_adds_epi16(vacc4x0123456789AB4567CDEF, voutput_zero_point); - vacc5x0123456789AB4567CDEF = _mm256_adds_epi16(vacc5x0123456789AB4567CDEF, voutput_zero_point); - - const __m128i vout0x0123456789AB4567CDEF = _mm_packus_epi16(_mm256_castsi256_si128(vacc0x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc0x0123456789AB4567CDEF, 1)); - const __m128i vout1x0123456789AB4567CDEF = _mm_packus_epi16(_mm256_castsi256_si128(vacc1x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc1x0123456789AB4567CDEF, 1)); - const __m128i vout2x0123456789AB4567CDEF = _mm_packus_epi16(_mm256_castsi256_si128(vacc2x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc2x0123456789AB4567CDEF, 1)); - const __m128i vout3x0123456789AB4567CDEF = _mm_packus_epi16(_mm256_castsi256_si128(vacc3x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc3x0123456789AB4567CDEF, 1)); - const __m128i vout4x0123456789AB4567CDEF = _mm_packus_epi16(_mm256_castsi256_si128(vacc4x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc4x0123456789AB4567CDEF, 1)); - const __m128i vout5x0123456789AB4567CDEF = _mm_packus_epi16(_mm256_castsi256_si128(vacc5x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc5x0123456789AB4567CDEF, 1)); - - __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi8(vout0x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi8(vout1x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi8(vout2x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout3x0123456789ABCDEF = _mm_shuffle_epi8(vout3x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout4x0123456789ABCDEF = _mm_shuffle_epi8(vout4x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout5x0123456789ABCDEF = _mm_shuffle_epi8(vout5x0123456789AB4567CDEF, vshuffle_control_mask); + __m256i vacc0x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); + __m256i vacc1x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); + __m256i vacc2x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc2x0123456789ABCDEF, 1)); + __m256i vacc3x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc3x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc3x0123456789ABCDEF, 1)); + __m256i vacc4x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc4x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc4x0123456789ABCDEF, 1)); + __m256i vacc5x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc5x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc5x0123456789ABCDEF, 1)); + + vacc0x012389AB4567CDEF = _mm256_adds_epi16(vacc0x012389AB4567CDEF, voutput_zero_point); + vacc1x012389AB4567CDEF = _mm256_adds_epi16(vacc1x012389AB4567CDEF, voutput_zero_point); + vacc2x012389AB4567CDEF = _mm256_adds_epi16(vacc2x012389AB4567CDEF, voutput_zero_point); + vacc3x012389AB4567CDEF = _mm256_adds_epi16(vacc3x012389AB4567CDEF, voutput_zero_point); + vacc4x012389AB4567CDEF = _mm256_adds_epi16(vacc4x012389AB4567CDEF, voutput_zero_point); + vacc5x012389AB4567CDEF = _mm256_adds_epi16(vacc5x012389AB4567CDEF, voutput_zero_point); + + const __m128i vout0x012389AB4567CDEF = _mm_packus_epi16(_mm256_castsi256_si128(vacc0x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x012389AB4567CDEF, 1)); + const __m128i vout1x012389AB4567CDEF = _mm_packus_epi16(_mm256_castsi256_si128(vacc1x012389AB4567CDEF), _mm256_extracti128_si256(vacc1x012389AB4567CDEF, 1)); + const __m128i vout2x012389AB4567CDEF = _mm_packus_epi16(_mm256_castsi256_si128(vacc2x012389AB4567CDEF), _mm256_extracti128_si256(vacc2x012389AB4567CDEF, 1)); + const __m128i vout3x012389AB4567CDEF = _mm_packus_epi16(_mm256_castsi256_si128(vacc3x012389AB4567CDEF), _mm256_extracti128_si256(vacc3x012389AB4567CDEF, 1)); + const __m128i vout4x012389AB4567CDEF = _mm_packus_epi16(_mm256_castsi256_si128(vacc4x012389AB4567CDEF), _mm256_extracti128_si256(vacc4x012389AB4567CDEF, 1)); + const __m128i vout5x012389AB4567CDEF = _mm_packus_epi16(_mm256_castsi256_si128(vacc5x012389AB4567CDEF), _mm256_extracti128_si256(vacc5x012389AB4567CDEF, 1)); + + __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi32(vout0x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi32(vout1x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi32(vout2x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout3x0123456789ABCDEF = _mm_shuffle_epi32(vout3x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout4x0123456789ABCDEF = _mm_shuffle_epi32(vout4x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout5x0123456789ABCDEF = _mm_shuffle_epi32(vout5x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); vout0x0123456789ABCDEF = _mm_max_epu8(vout0x0123456789ABCDEF, voutput_min); vout1x0123456789ABCDEF = _mm_max_epu8(vout1x0123456789ABCDEF, voutput_min); @@ -252,25 +251,23 @@ void xnn_qu8_gemm_minmax_fp32_ukernel_6x16c8__avx512skx( if (nc >= 16) { _mm_storeu_si128((__m128i*) c0, vout0x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c1, vout1x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c2, vout2x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c3, vout3x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c4, vout4x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c5, vout5x0123456789ABCDEF); - - a0 = (const uint8_t*) ((uintptr_t) a0 - kc); - a1 = (const uint8_t*) ((uintptr_t) a1 - kc); - a2 = (const uint8_t*) ((uintptr_t) a2 - kc); - a3 = (const uint8_t*) ((uintptr_t) a3 - kc); - a4 = (const uint8_t*) ((uintptr_t) a4 - kc); - a5 = (const uint8_t*) ((uintptr_t) a5 - kc); - c0 = (uint8_t*) ((uintptr_t) c0 + cn_stride); + a0 = (const uint8_t*) ((uintptr_t) a0 - kc); + _mm_storeu_si128((__m128i*) c1, vout1x0123456789ABCDEF); c1 = (uint8_t*) ((uintptr_t) c1 + cn_stride); + a1 = (const uint8_t*) ((uintptr_t) a1 - kc); + _mm_storeu_si128((__m128i*) c2, vout2x0123456789ABCDEF); c2 = (uint8_t*) ((uintptr_t) c2 + cn_stride); + a2 = (const uint8_t*) ((uintptr_t) a2 - kc); + _mm_storeu_si128((__m128i*) c3, vout3x0123456789ABCDEF); c3 = (uint8_t*) ((uintptr_t) c3 + cn_stride); + a3 = (const uint8_t*) ((uintptr_t) a3 - kc); + _mm_storeu_si128((__m128i*) c4, vout4x0123456789ABCDEF); c4 = (uint8_t*) ((uintptr_t) c4 + cn_stride); + a4 = (const uint8_t*) ((uintptr_t) a4 - kc); + _mm_storeu_si128((__m128i*) c5, vout5x0123456789ABCDEF); c5 = (uint8_t*) ((uintptr_t) c5 + cn_stride); + a5 = (const uint8_t*) ((uintptr_t) a5 - kc); nc -= 16; } else { diff --git a/src/qu8-gemm/gen/qu8-gemm-7x16c8-minmax-fp32-avx512skx-prfm.c b/src/qu8-gemm/gen/qu8-gemm-7x16c8-minmax-fp32-avx512skx-prfm.c index 047c87e2ad8..1ebf300d80b 100644 --- a/src/qu8-gemm/gen/qu8-gemm-7x16c8-minmax-fp32-avx512skx-prfm.c +++ b/src/qu8-gemm/gen/qu8-gemm-7x16c8-minmax-fp32-avx512skx-prfm.c @@ -83,7 +83,6 @@ void xnn_qu8_gemm_minmax_fp32_ukernel_7x16c8__avx512skx_prfm( const __m256i voutput_zero_point = _mm256_load_si256((const __m256i*) params->fp32_avx512.output_zero_point); const __m128i voutput_min = _mm_load_si128((const __m128i*) params->fp32_avx512.output_min); const __m512i vb_zero_point = _mm512_load_si512(params->fp32_avx512.kernel_zero_point); - const __m128i vshuffle_control_mask = _mm_set_epi8(15, 14, 13, 12, 7, 6, 5, 4, 11, 10, 9, 8, 3, 2, 1, 0); do { __m512i vacc0x0123 = _mm512_maskz_expandloadu_epi32(vbias_mask, w); @@ -242,37 +241,37 @@ void xnn_qu8_gemm_minmax_fp32_ukernel_7x16c8__avx512skx_prfm( vacc5x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled5x0123456789ABCDEF); vacc6x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled6x0123456789ABCDEF); - __m256i vacc0x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); - __m256i vacc1x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); - __m256i vacc2x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc2x0123456789ABCDEF, 1)); - __m256i vacc3x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc3x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc3x0123456789ABCDEF, 1)); - __m256i vacc4x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc4x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc4x0123456789ABCDEF, 1)); - __m256i vacc5x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc5x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc5x0123456789ABCDEF, 1)); - __m256i vacc6x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc6x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc6x0123456789ABCDEF, 1)); - - vacc0x0123456789AB4567CDEF = _mm256_adds_epi16(vacc0x0123456789AB4567CDEF, voutput_zero_point); - vacc1x0123456789AB4567CDEF = _mm256_adds_epi16(vacc1x0123456789AB4567CDEF, voutput_zero_point); - vacc2x0123456789AB4567CDEF = _mm256_adds_epi16(vacc2x0123456789AB4567CDEF, voutput_zero_point); - vacc3x0123456789AB4567CDEF = _mm256_adds_epi16(vacc3x0123456789AB4567CDEF, voutput_zero_point); - vacc4x0123456789AB4567CDEF = _mm256_adds_epi16(vacc4x0123456789AB4567CDEF, voutput_zero_point); - vacc5x0123456789AB4567CDEF = _mm256_adds_epi16(vacc5x0123456789AB4567CDEF, voutput_zero_point); - vacc6x0123456789AB4567CDEF = _mm256_adds_epi16(vacc6x0123456789AB4567CDEF, voutput_zero_point); - - const __m128i vout0x0123456789AB4567CDEF = _mm_packus_epi16(_mm256_castsi256_si128(vacc0x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc0x0123456789AB4567CDEF, 1)); - const __m128i vout1x0123456789AB4567CDEF = _mm_packus_epi16(_mm256_castsi256_si128(vacc1x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc1x0123456789AB4567CDEF, 1)); - const __m128i vout2x0123456789AB4567CDEF = _mm_packus_epi16(_mm256_castsi256_si128(vacc2x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc2x0123456789AB4567CDEF, 1)); - const __m128i vout3x0123456789AB4567CDEF = _mm_packus_epi16(_mm256_castsi256_si128(vacc3x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc3x0123456789AB4567CDEF, 1)); - const __m128i vout4x0123456789AB4567CDEF = _mm_packus_epi16(_mm256_castsi256_si128(vacc4x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc4x0123456789AB4567CDEF, 1)); - const __m128i vout5x0123456789AB4567CDEF = _mm_packus_epi16(_mm256_castsi256_si128(vacc5x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc5x0123456789AB4567CDEF, 1)); - const __m128i vout6x0123456789AB4567CDEF = _mm_packus_epi16(_mm256_castsi256_si128(vacc6x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc6x0123456789AB4567CDEF, 1)); - - __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi8(vout0x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi8(vout1x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi8(vout2x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout3x0123456789ABCDEF = _mm_shuffle_epi8(vout3x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout4x0123456789ABCDEF = _mm_shuffle_epi8(vout4x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout5x0123456789ABCDEF = _mm_shuffle_epi8(vout5x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout6x0123456789ABCDEF = _mm_shuffle_epi8(vout6x0123456789AB4567CDEF, vshuffle_control_mask); + __m256i vacc0x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); + __m256i vacc1x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); + __m256i vacc2x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc2x0123456789ABCDEF, 1)); + __m256i vacc3x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc3x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc3x0123456789ABCDEF, 1)); + __m256i vacc4x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc4x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc4x0123456789ABCDEF, 1)); + __m256i vacc5x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc5x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc5x0123456789ABCDEF, 1)); + __m256i vacc6x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc6x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc6x0123456789ABCDEF, 1)); + + vacc0x012389AB4567CDEF = _mm256_adds_epi16(vacc0x012389AB4567CDEF, voutput_zero_point); + vacc1x012389AB4567CDEF = _mm256_adds_epi16(vacc1x012389AB4567CDEF, voutput_zero_point); + vacc2x012389AB4567CDEF = _mm256_adds_epi16(vacc2x012389AB4567CDEF, voutput_zero_point); + vacc3x012389AB4567CDEF = _mm256_adds_epi16(vacc3x012389AB4567CDEF, voutput_zero_point); + vacc4x012389AB4567CDEF = _mm256_adds_epi16(vacc4x012389AB4567CDEF, voutput_zero_point); + vacc5x012389AB4567CDEF = _mm256_adds_epi16(vacc5x012389AB4567CDEF, voutput_zero_point); + vacc6x012389AB4567CDEF = _mm256_adds_epi16(vacc6x012389AB4567CDEF, voutput_zero_point); + + const __m128i vout0x012389AB4567CDEF = _mm_packus_epi16(_mm256_castsi256_si128(vacc0x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x012389AB4567CDEF, 1)); + const __m128i vout1x012389AB4567CDEF = _mm_packus_epi16(_mm256_castsi256_si128(vacc1x012389AB4567CDEF), _mm256_extracti128_si256(vacc1x012389AB4567CDEF, 1)); + const __m128i vout2x012389AB4567CDEF = _mm_packus_epi16(_mm256_castsi256_si128(vacc2x012389AB4567CDEF), _mm256_extracti128_si256(vacc2x012389AB4567CDEF, 1)); + const __m128i vout3x012389AB4567CDEF = _mm_packus_epi16(_mm256_castsi256_si128(vacc3x012389AB4567CDEF), _mm256_extracti128_si256(vacc3x012389AB4567CDEF, 1)); + const __m128i vout4x012389AB4567CDEF = _mm_packus_epi16(_mm256_castsi256_si128(vacc4x012389AB4567CDEF), _mm256_extracti128_si256(vacc4x012389AB4567CDEF, 1)); + const __m128i vout5x012389AB4567CDEF = _mm_packus_epi16(_mm256_castsi256_si128(vacc5x012389AB4567CDEF), _mm256_extracti128_si256(vacc5x012389AB4567CDEF, 1)); + const __m128i vout6x012389AB4567CDEF = _mm_packus_epi16(_mm256_castsi256_si128(vacc6x012389AB4567CDEF), _mm256_extracti128_si256(vacc6x012389AB4567CDEF, 1)); + + __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi32(vout0x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi32(vout1x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi32(vout2x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout3x0123456789ABCDEF = _mm_shuffle_epi32(vout3x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout4x0123456789ABCDEF = _mm_shuffle_epi32(vout4x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout5x0123456789ABCDEF = _mm_shuffle_epi32(vout5x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout6x0123456789ABCDEF = _mm_shuffle_epi32(vout6x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); vout0x0123456789ABCDEF = _mm_max_epu8(vout0x0123456789ABCDEF, voutput_min); vout1x0123456789ABCDEF = _mm_max_epu8(vout1x0123456789ABCDEF, voutput_min); @@ -284,28 +283,26 @@ void xnn_qu8_gemm_minmax_fp32_ukernel_7x16c8__avx512skx_prfm( if (nc >= 16) { _mm_storeu_si128((__m128i*) c0, vout0x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c1, vout1x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c2, vout2x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c3, vout3x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c4, vout4x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c5, vout5x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c6, vout6x0123456789ABCDEF); - - a0 = (const uint8_t*) ((uintptr_t) a0 - kc); - a1 = (const uint8_t*) ((uintptr_t) a1 - kc); - a2 = (const uint8_t*) ((uintptr_t) a2 - kc); - a3 = (const uint8_t*) ((uintptr_t) a3 - kc); - a4 = (const uint8_t*) ((uintptr_t) a4 - kc); - a5 = (const uint8_t*) ((uintptr_t) a5 - kc); - a6 = (const uint8_t*) ((uintptr_t) a6 - kc); - c0 = (uint8_t*) ((uintptr_t) c0 + cn_stride); + a0 = (const uint8_t*) ((uintptr_t) a0 - kc); + _mm_storeu_si128((__m128i*) c1, vout1x0123456789ABCDEF); c1 = (uint8_t*) ((uintptr_t) c1 + cn_stride); + a1 = (const uint8_t*) ((uintptr_t) a1 - kc); + _mm_storeu_si128((__m128i*) c2, vout2x0123456789ABCDEF); c2 = (uint8_t*) ((uintptr_t) c2 + cn_stride); + a2 = (const uint8_t*) ((uintptr_t) a2 - kc); + _mm_storeu_si128((__m128i*) c3, vout3x0123456789ABCDEF); c3 = (uint8_t*) ((uintptr_t) c3 + cn_stride); + a3 = (const uint8_t*) ((uintptr_t) a3 - kc); + _mm_storeu_si128((__m128i*) c4, vout4x0123456789ABCDEF); c4 = (uint8_t*) ((uintptr_t) c4 + cn_stride); + a4 = (const uint8_t*) ((uintptr_t) a4 - kc); + _mm_storeu_si128((__m128i*) c5, vout5x0123456789ABCDEF); c5 = (uint8_t*) ((uintptr_t) c5 + cn_stride); + a5 = (const uint8_t*) ((uintptr_t) a5 - kc); + _mm_storeu_si128((__m128i*) c6, vout6x0123456789ABCDEF); c6 = (uint8_t*) ((uintptr_t) c6 + cn_stride); + a6 = (const uint8_t*) ((uintptr_t) a6 - kc); nc -= 16; } else { diff --git a/src/qu8-gemm/gen/qu8-gemm-7x16c8-minmax-fp32-avx512skx.c b/src/qu8-gemm/gen/qu8-gemm-7x16c8-minmax-fp32-avx512skx.c index a8e3fbdbd19..71222c59a1c 100644 --- a/src/qu8-gemm/gen/qu8-gemm-7x16c8-minmax-fp32-avx512skx.c +++ b/src/qu8-gemm/gen/qu8-gemm-7x16c8-minmax-fp32-avx512skx.c @@ -82,7 +82,6 @@ void xnn_qu8_gemm_minmax_fp32_ukernel_7x16c8__avx512skx( const __m256i voutput_zero_point = _mm256_load_si256((const __m256i*) params->fp32_avx512.output_zero_point); const __m128i voutput_min = _mm_load_si128((const __m128i*) params->fp32_avx512.output_min); const __m512i vb_zero_point = _mm512_load_si512(params->fp32_avx512.kernel_zero_point); - const __m128i vshuffle_control_mask = _mm_set_epi8(15, 14, 13, 12, 7, 6, 5, 4, 11, 10, 9, 8, 3, 2, 1, 0); do { __m512i vacc0x0123 = _mm512_maskz_expandloadu_epi32(vbias_mask, w); @@ -239,37 +238,37 @@ void xnn_qu8_gemm_minmax_fp32_ukernel_7x16c8__avx512skx( vacc5x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled5x0123456789ABCDEF); vacc6x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled6x0123456789ABCDEF); - __m256i vacc0x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); - __m256i vacc1x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); - __m256i vacc2x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc2x0123456789ABCDEF, 1)); - __m256i vacc3x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc3x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc3x0123456789ABCDEF, 1)); - __m256i vacc4x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc4x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc4x0123456789ABCDEF, 1)); - __m256i vacc5x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc5x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc5x0123456789ABCDEF, 1)); - __m256i vacc6x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc6x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc6x0123456789ABCDEF, 1)); - - vacc0x0123456789AB4567CDEF = _mm256_adds_epi16(vacc0x0123456789AB4567CDEF, voutput_zero_point); - vacc1x0123456789AB4567CDEF = _mm256_adds_epi16(vacc1x0123456789AB4567CDEF, voutput_zero_point); - vacc2x0123456789AB4567CDEF = _mm256_adds_epi16(vacc2x0123456789AB4567CDEF, voutput_zero_point); - vacc3x0123456789AB4567CDEF = _mm256_adds_epi16(vacc3x0123456789AB4567CDEF, voutput_zero_point); - vacc4x0123456789AB4567CDEF = _mm256_adds_epi16(vacc4x0123456789AB4567CDEF, voutput_zero_point); - vacc5x0123456789AB4567CDEF = _mm256_adds_epi16(vacc5x0123456789AB4567CDEF, voutput_zero_point); - vacc6x0123456789AB4567CDEF = _mm256_adds_epi16(vacc6x0123456789AB4567CDEF, voutput_zero_point); - - const __m128i vout0x0123456789AB4567CDEF = _mm_packus_epi16(_mm256_castsi256_si128(vacc0x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc0x0123456789AB4567CDEF, 1)); - const __m128i vout1x0123456789AB4567CDEF = _mm_packus_epi16(_mm256_castsi256_si128(vacc1x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc1x0123456789AB4567CDEF, 1)); - const __m128i vout2x0123456789AB4567CDEF = _mm_packus_epi16(_mm256_castsi256_si128(vacc2x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc2x0123456789AB4567CDEF, 1)); - const __m128i vout3x0123456789AB4567CDEF = _mm_packus_epi16(_mm256_castsi256_si128(vacc3x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc3x0123456789AB4567CDEF, 1)); - const __m128i vout4x0123456789AB4567CDEF = _mm_packus_epi16(_mm256_castsi256_si128(vacc4x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc4x0123456789AB4567CDEF, 1)); - const __m128i vout5x0123456789AB4567CDEF = _mm_packus_epi16(_mm256_castsi256_si128(vacc5x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc5x0123456789AB4567CDEF, 1)); - const __m128i vout6x0123456789AB4567CDEF = _mm_packus_epi16(_mm256_castsi256_si128(vacc6x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc6x0123456789AB4567CDEF, 1)); - - __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi8(vout0x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi8(vout1x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi8(vout2x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout3x0123456789ABCDEF = _mm_shuffle_epi8(vout3x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout4x0123456789ABCDEF = _mm_shuffle_epi8(vout4x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout5x0123456789ABCDEF = _mm_shuffle_epi8(vout5x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout6x0123456789ABCDEF = _mm_shuffle_epi8(vout6x0123456789AB4567CDEF, vshuffle_control_mask); + __m256i vacc0x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); + __m256i vacc1x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); + __m256i vacc2x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc2x0123456789ABCDEF, 1)); + __m256i vacc3x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc3x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc3x0123456789ABCDEF, 1)); + __m256i vacc4x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc4x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc4x0123456789ABCDEF, 1)); + __m256i vacc5x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc5x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc5x0123456789ABCDEF, 1)); + __m256i vacc6x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc6x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc6x0123456789ABCDEF, 1)); + + vacc0x012389AB4567CDEF = _mm256_adds_epi16(vacc0x012389AB4567CDEF, voutput_zero_point); + vacc1x012389AB4567CDEF = _mm256_adds_epi16(vacc1x012389AB4567CDEF, voutput_zero_point); + vacc2x012389AB4567CDEF = _mm256_adds_epi16(vacc2x012389AB4567CDEF, voutput_zero_point); + vacc3x012389AB4567CDEF = _mm256_adds_epi16(vacc3x012389AB4567CDEF, voutput_zero_point); + vacc4x012389AB4567CDEF = _mm256_adds_epi16(vacc4x012389AB4567CDEF, voutput_zero_point); + vacc5x012389AB4567CDEF = _mm256_adds_epi16(vacc5x012389AB4567CDEF, voutput_zero_point); + vacc6x012389AB4567CDEF = _mm256_adds_epi16(vacc6x012389AB4567CDEF, voutput_zero_point); + + const __m128i vout0x012389AB4567CDEF = _mm_packus_epi16(_mm256_castsi256_si128(vacc0x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x012389AB4567CDEF, 1)); + const __m128i vout1x012389AB4567CDEF = _mm_packus_epi16(_mm256_castsi256_si128(vacc1x012389AB4567CDEF), _mm256_extracti128_si256(vacc1x012389AB4567CDEF, 1)); + const __m128i vout2x012389AB4567CDEF = _mm_packus_epi16(_mm256_castsi256_si128(vacc2x012389AB4567CDEF), _mm256_extracti128_si256(vacc2x012389AB4567CDEF, 1)); + const __m128i vout3x012389AB4567CDEF = _mm_packus_epi16(_mm256_castsi256_si128(vacc3x012389AB4567CDEF), _mm256_extracti128_si256(vacc3x012389AB4567CDEF, 1)); + const __m128i vout4x012389AB4567CDEF = _mm_packus_epi16(_mm256_castsi256_si128(vacc4x012389AB4567CDEF), _mm256_extracti128_si256(vacc4x012389AB4567CDEF, 1)); + const __m128i vout5x012389AB4567CDEF = _mm_packus_epi16(_mm256_castsi256_si128(vacc5x012389AB4567CDEF), _mm256_extracti128_si256(vacc5x012389AB4567CDEF, 1)); + const __m128i vout6x012389AB4567CDEF = _mm_packus_epi16(_mm256_castsi256_si128(vacc6x012389AB4567CDEF), _mm256_extracti128_si256(vacc6x012389AB4567CDEF, 1)); + + __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi32(vout0x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi32(vout1x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi32(vout2x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout3x0123456789ABCDEF = _mm_shuffle_epi32(vout3x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout4x0123456789ABCDEF = _mm_shuffle_epi32(vout4x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout5x0123456789ABCDEF = _mm_shuffle_epi32(vout5x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout6x0123456789ABCDEF = _mm_shuffle_epi32(vout6x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); vout0x0123456789ABCDEF = _mm_max_epu8(vout0x0123456789ABCDEF, voutput_min); vout1x0123456789ABCDEF = _mm_max_epu8(vout1x0123456789ABCDEF, voutput_min); @@ -281,28 +280,26 @@ void xnn_qu8_gemm_minmax_fp32_ukernel_7x16c8__avx512skx( if (nc >= 16) { _mm_storeu_si128((__m128i*) c0, vout0x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c1, vout1x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c2, vout2x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c3, vout3x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c4, vout4x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c5, vout5x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c6, vout6x0123456789ABCDEF); - - a0 = (const uint8_t*) ((uintptr_t) a0 - kc); - a1 = (const uint8_t*) ((uintptr_t) a1 - kc); - a2 = (const uint8_t*) ((uintptr_t) a2 - kc); - a3 = (const uint8_t*) ((uintptr_t) a3 - kc); - a4 = (const uint8_t*) ((uintptr_t) a4 - kc); - a5 = (const uint8_t*) ((uintptr_t) a5 - kc); - a6 = (const uint8_t*) ((uintptr_t) a6 - kc); - c0 = (uint8_t*) ((uintptr_t) c0 + cn_stride); + a0 = (const uint8_t*) ((uintptr_t) a0 - kc); + _mm_storeu_si128((__m128i*) c1, vout1x0123456789ABCDEF); c1 = (uint8_t*) ((uintptr_t) c1 + cn_stride); + a1 = (const uint8_t*) ((uintptr_t) a1 - kc); + _mm_storeu_si128((__m128i*) c2, vout2x0123456789ABCDEF); c2 = (uint8_t*) ((uintptr_t) c2 + cn_stride); + a2 = (const uint8_t*) ((uintptr_t) a2 - kc); + _mm_storeu_si128((__m128i*) c3, vout3x0123456789ABCDEF); c3 = (uint8_t*) ((uintptr_t) c3 + cn_stride); + a3 = (const uint8_t*) ((uintptr_t) a3 - kc); + _mm_storeu_si128((__m128i*) c4, vout4x0123456789ABCDEF); c4 = (uint8_t*) ((uintptr_t) c4 + cn_stride); + a4 = (const uint8_t*) ((uintptr_t) a4 - kc); + _mm_storeu_si128((__m128i*) c5, vout5x0123456789ABCDEF); c5 = (uint8_t*) ((uintptr_t) c5 + cn_stride); + a5 = (const uint8_t*) ((uintptr_t) a5 - kc); + _mm_storeu_si128((__m128i*) c6, vout6x0123456789ABCDEF); c6 = (uint8_t*) ((uintptr_t) c6 + cn_stride); + a6 = (const uint8_t*) ((uintptr_t) a6 - kc); nc -= 16; } else { diff --git a/src/qu8-gemm/gen/qu8-gemm-8x16c8-minmax-fp32-avx512skx-prfm.c b/src/qu8-gemm/gen/qu8-gemm-8x16c8-minmax-fp32-avx512skx-prfm.c index e0dadcbc128..8f04efe54c4 100644 --- a/src/qu8-gemm/gen/qu8-gemm-8x16c8-minmax-fp32-avx512skx-prfm.c +++ b/src/qu8-gemm/gen/qu8-gemm-8x16c8-minmax-fp32-avx512skx-prfm.c @@ -89,7 +89,6 @@ void xnn_qu8_gemm_minmax_fp32_ukernel_8x16c8__avx512skx_prfm( const __m256i voutput_zero_point = _mm256_load_si256((const __m256i*) params->fp32_avx512.output_zero_point); const __m128i voutput_min = _mm_load_si128((const __m128i*) params->fp32_avx512.output_min); const __m512i vb_zero_point = _mm512_load_si512(params->fp32_avx512.kernel_zero_point); - const __m128i vshuffle_control_mask = _mm_set_epi8(15, 14, 13, 12, 7, 6, 5, 4, 11, 10, 9, 8, 3, 2, 1, 0); do { __m512i vacc0x0123 = _mm512_maskz_expandloadu_epi32(vbias_mask, w); @@ -266,41 +265,41 @@ void xnn_qu8_gemm_minmax_fp32_ukernel_8x16c8__avx512skx_prfm( vacc6x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled6x0123456789ABCDEF); vacc7x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled7x0123456789ABCDEF); - __m256i vacc0x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); - __m256i vacc1x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); - __m256i vacc2x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc2x0123456789ABCDEF, 1)); - __m256i vacc3x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc3x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc3x0123456789ABCDEF, 1)); - __m256i vacc4x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc4x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc4x0123456789ABCDEF, 1)); - __m256i vacc5x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc5x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc5x0123456789ABCDEF, 1)); - __m256i vacc6x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc6x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc6x0123456789ABCDEF, 1)); - __m256i vacc7x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc7x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc7x0123456789ABCDEF, 1)); - - vacc0x0123456789AB4567CDEF = _mm256_adds_epi16(vacc0x0123456789AB4567CDEF, voutput_zero_point); - vacc1x0123456789AB4567CDEF = _mm256_adds_epi16(vacc1x0123456789AB4567CDEF, voutput_zero_point); - vacc2x0123456789AB4567CDEF = _mm256_adds_epi16(vacc2x0123456789AB4567CDEF, voutput_zero_point); - vacc3x0123456789AB4567CDEF = _mm256_adds_epi16(vacc3x0123456789AB4567CDEF, voutput_zero_point); - vacc4x0123456789AB4567CDEF = _mm256_adds_epi16(vacc4x0123456789AB4567CDEF, voutput_zero_point); - vacc5x0123456789AB4567CDEF = _mm256_adds_epi16(vacc5x0123456789AB4567CDEF, voutput_zero_point); - vacc6x0123456789AB4567CDEF = _mm256_adds_epi16(vacc6x0123456789AB4567CDEF, voutput_zero_point); - vacc7x0123456789AB4567CDEF = _mm256_adds_epi16(vacc7x0123456789AB4567CDEF, voutput_zero_point); - - const __m128i vout0x0123456789AB4567CDEF = _mm_packus_epi16(_mm256_castsi256_si128(vacc0x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc0x0123456789AB4567CDEF, 1)); - const __m128i vout1x0123456789AB4567CDEF = _mm_packus_epi16(_mm256_castsi256_si128(vacc1x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc1x0123456789AB4567CDEF, 1)); - const __m128i vout2x0123456789AB4567CDEF = _mm_packus_epi16(_mm256_castsi256_si128(vacc2x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc2x0123456789AB4567CDEF, 1)); - const __m128i vout3x0123456789AB4567CDEF = _mm_packus_epi16(_mm256_castsi256_si128(vacc3x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc3x0123456789AB4567CDEF, 1)); - const __m128i vout4x0123456789AB4567CDEF = _mm_packus_epi16(_mm256_castsi256_si128(vacc4x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc4x0123456789AB4567CDEF, 1)); - const __m128i vout5x0123456789AB4567CDEF = _mm_packus_epi16(_mm256_castsi256_si128(vacc5x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc5x0123456789AB4567CDEF, 1)); - const __m128i vout6x0123456789AB4567CDEF = _mm_packus_epi16(_mm256_castsi256_si128(vacc6x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc6x0123456789AB4567CDEF, 1)); - const __m128i vout7x0123456789AB4567CDEF = _mm_packus_epi16(_mm256_castsi256_si128(vacc7x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc7x0123456789AB4567CDEF, 1)); - - __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi8(vout0x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi8(vout1x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi8(vout2x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout3x0123456789ABCDEF = _mm_shuffle_epi8(vout3x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout4x0123456789ABCDEF = _mm_shuffle_epi8(vout4x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout5x0123456789ABCDEF = _mm_shuffle_epi8(vout5x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout6x0123456789ABCDEF = _mm_shuffle_epi8(vout6x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout7x0123456789ABCDEF = _mm_shuffle_epi8(vout7x0123456789AB4567CDEF, vshuffle_control_mask); + __m256i vacc0x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); + __m256i vacc1x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); + __m256i vacc2x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc2x0123456789ABCDEF, 1)); + __m256i vacc3x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc3x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc3x0123456789ABCDEF, 1)); + __m256i vacc4x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc4x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc4x0123456789ABCDEF, 1)); + __m256i vacc5x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc5x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc5x0123456789ABCDEF, 1)); + __m256i vacc6x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc6x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc6x0123456789ABCDEF, 1)); + __m256i vacc7x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc7x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc7x0123456789ABCDEF, 1)); + + vacc0x012389AB4567CDEF = _mm256_adds_epi16(vacc0x012389AB4567CDEF, voutput_zero_point); + vacc1x012389AB4567CDEF = _mm256_adds_epi16(vacc1x012389AB4567CDEF, voutput_zero_point); + vacc2x012389AB4567CDEF = _mm256_adds_epi16(vacc2x012389AB4567CDEF, voutput_zero_point); + vacc3x012389AB4567CDEF = _mm256_adds_epi16(vacc3x012389AB4567CDEF, voutput_zero_point); + vacc4x012389AB4567CDEF = _mm256_adds_epi16(vacc4x012389AB4567CDEF, voutput_zero_point); + vacc5x012389AB4567CDEF = _mm256_adds_epi16(vacc5x012389AB4567CDEF, voutput_zero_point); + vacc6x012389AB4567CDEF = _mm256_adds_epi16(vacc6x012389AB4567CDEF, voutput_zero_point); + vacc7x012389AB4567CDEF = _mm256_adds_epi16(vacc7x012389AB4567CDEF, voutput_zero_point); + + const __m128i vout0x012389AB4567CDEF = _mm_packus_epi16(_mm256_castsi256_si128(vacc0x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x012389AB4567CDEF, 1)); + const __m128i vout1x012389AB4567CDEF = _mm_packus_epi16(_mm256_castsi256_si128(vacc1x012389AB4567CDEF), _mm256_extracti128_si256(vacc1x012389AB4567CDEF, 1)); + const __m128i vout2x012389AB4567CDEF = _mm_packus_epi16(_mm256_castsi256_si128(vacc2x012389AB4567CDEF), _mm256_extracti128_si256(vacc2x012389AB4567CDEF, 1)); + const __m128i vout3x012389AB4567CDEF = _mm_packus_epi16(_mm256_castsi256_si128(vacc3x012389AB4567CDEF), _mm256_extracti128_si256(vacc3x012389AB4567CDEF, 1)); + const __m128i vout4x012389AB4567CDEF = _mm_packus_epi16(_mm256_castsi256_si128(vacc4x012389AB4567CDEF), _mm256_extracti128_si256(vacc4x012389AB4567CDEF, 1)); + const __m128i vout5x012389AB4567CDEF = _mm_packus_epi16(_mm256_castsi256_si128(vacc5x012389AB4567CDEF), _mm256_extracti128_si256(vacc5x012389AB4567CDEF, 1)); + const __m128i vout6x012389AB4567CDEF = _mm_packus_epi16(_mm256_castsi256_si128(vacc6x012389AB4567CDEF), _mm256_extracti128_si256(vacc6x012389AB4567CDEF, 1)); + const __m128i vout7x012389AB4567CDEF = _mm_packus_epi16(_mm256_castsi256_si128(vacc7x012389AB4567CDEF), _mm256_extracti128_si256(vacc7x012389AB4567CDEF, 1)); + + __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi32(vout0x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi32(vout1x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi32(vout2x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout3x0123456789ABCDEF = _mm_shuffle_epi32(vout3x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout4x0123456789ABCDEF = _mm_shuffle_epi32(vout4x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout5x0123456789ABCDEF = _mm_shuffle_epi32(vout5x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout6x0123456789ABCDEF = _mm_shuffle_epi32(vout6x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout7x0123456789ABCDEF = _mm_shuffle_epi32(vout7x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); vout0x0123456789ABCDEF = _mm_max_epu8(vout0x0123456789ABCDEF, voutput_min); vout1x0123456789ABCDEF = _mm_max_epu8(vout1x0123456789ABCDEF, voutput_min); @@ -313,31 +312,29 @@ void xnn_qu8_gemm_minmax_fp32_ukernel_8x16c8__avx512skx_prfm( if (nc >= 16) { _mm_storeu_si128((__m128i*) c0, vout0x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c1, vout1x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c2, vout2x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c3, vout3x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c4, vout4x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c5, vout5x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c6, vout6x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c7, vout7x0123456789ABCDEF); - - a0 = (const uint8_t*) ((uintptr_t) a0 - kc); - a1 = (const uint8_t*) ((uintptr_t) a1 - kc); - a2 = (const uint8_t*) ((uintptr_t) a2 - kc); - a3 = (const uint8_t*) ((uintptr_t) a3 - kc); - a4 = (const uint8_t*) ((uintptr_t) a4 - kc); - a5 = (const uint8_t*) ((uintptr_t) a5 - kc); - a6 = (const uint8_t*) ((uintptr_t) a6 - kc); - a7 = (const uint8_t*) ((uintptr_t) a7 - kc); - c0 = (uint8_t*) ((uintptr_t) c0 + cn_stride); + a0 = (const uint8_t*) ((uintptr_t) a0 - kc); + _mm_storeu_si128((__m128i*) c1, vout1x0123456789ABCDEF); c1 = (uint8_t*) ((uintptr_t) c1 + cn_stride); + a1 = (const uint8_t*) ((uintptr_t) a1 - kc); + _mm_storeu_si128((__m128i*) c2, vout2x0123456789ABCDEF); c2 = (uint8_t*) ((uintptr_t) c2 + cn_stride); + a2 = (const uint8_t*) ((uintptr_t) a2 - kc); + _mm_storeu_si128((__m128i*) c3, vout3x0123456789ABCDEF); c3 = (uint8_t*) ((uintptr_t) c3 + cn_stride); + a3 = (const uint8_t*) ((uintptr_t) a3 - kc); + _mm_storeu_si128((__m128i*) c4, vout4x0123456789ABCDEF); c4 = (uint8_t*) ((uintptr_t) c4 + cn_stride); + a4 = (const uint8_t*) ((uintptr_t) a4 - kc); + _mm_storeu_si128((__m128i*) c5, vout5x0123456789ABCDEF); c5 = (uint8_t*) ((uintptr_t) c5 + cn_stride); + a5 = (const uint8_t*) ((uintptr_t) a5 - kc); + _mm_storeu_si128((__m128i*) c6, vout6x0123456789ABCDEF); c6 = (uint8_t*) ((uintptr_t) c6 + cn_stride); + a6 = (const uint8_t*) ((uintptr_t) a6 - kc); + _mm_storeu_si128((__m128i*) c7, vout7x0123456789ABCDEF); c7 = (uint8_t*) ((uintptr_t) c7 + cn_stride); + a7 = (const uint8_t*) ((uintptr_t) a7 - kc); nc -= 16; } else { diff --git a/src/qu8-gemm/gen/qu8-gemm-8x16c8-minmax-fp32-avx512skx.c b/src/qu8-gemm/gen/qu8-gemm-8x16c8-minmax-fp32-avx512skx.c index f313369d051..7eb4bef1ee8 100644 --- a/src/qu8-gemm/gen/qu8-gemm-8x16c8-minmax-fp32-avx512skx.c +++ b/src/qu8-gemm/gen/qu8-gemm-8x16c8-minmax-fp32-avx512skx.c @@ -88,7 +88,6 @@ void xnn_qu8_gemm_minmax_fp32_ukernel_8x16c8__avx512skx( const __m256i voutput_zero_point = _mm256_load_si256((const __m256i*) params->fp32_avx512.output_zero_point); const __m128i voutput_min = _mm_load_si128((const __m128i*) params->fp32_avx512.output_min); const __m512i vb_zero_point = _mm512_load_si512(params->fp32_avx512.kernel_zero_point); - const __m128i vshuffle_control_mask = _mm_set_epi8(15, 14, 13, 12, 7, 6, 5, 4, 11, 10, 9, 8, 3, 2, 1, 0); do { __m512i vacc0x0123 = _mm512_maskz_expandloadu_epi32(vbias_mask, w); @@ -263,41 +262,41 @@ void xnn_qu8_gemm_minmax_fp32_ukernel_8x16c8__avx512skx( vacc6x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled6x0123456789ABCDEF); vacc7x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled7x0123456789ABCDEF); - __m256i vacc0x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); - __m256i vacc1x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); - __m256i vacc2x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc2x0123456789ABCDEF, 1)); - __m256i vacc3x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc3x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc3x0123456789ABCDEF, 1)); - __m256i vacc4x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc4x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc4x0123456789ABCDEF, 1)); - __m256i vacc5x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc5x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc5x0123456789ABCDEF, 1)); - __m256i vacc6x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc6x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc6x0123456789ABCDEF, 1)); - __m256i vacc7x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc7x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc7x0123456789ABCDEF, 1)); - - vacc0x0123456789AB4567CDEF = _mm256_adds_epi16(vacc0x0123456789AB4567CDEF, voutput_zero_point); - vacc1x0123456789AB4567CDEF = _mm256_adds_epi16(vacc1x0123456789AB4567CDEF, voutput_zero_point); - vacc2x0123456789AB4567CDEF = _mm256_adds_epi16(vacc2x0123456789AB4567CDEF, voutput_zero_point); - vacc3x0123456789AB4567CDEF = _mm256_adds_epi16(vacc3x0123456789AB4567CDEF, voutput_zero_point); - vacc4x0123456789AB4567CDEF = _mm256_adds_epi16(vacc4x0123456789AB4567CDEF, voutput_zero_point); - vacc5x0123456789AB4567CDEF = _mm256_adds_epi16(vacc5x0123456789AB4567CDEF, voutput_zero_point); - vacc6x0123456789AB4567CDEF = _mm256_adds_epi16(vacc6x0123456789AB4567CDEF, voutput_zero_point); - vacc7x0123456789AB4567CDEF = _mm256_adds_epi16(vacc7x0123456789AB4567CDEF, voutput_zero_point); - - const __m128i vout0x0123456789AB4567CDEF = _mm_packus_epi16(_mm256_castsi256_si128(vacc0x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc0x0123456789AB4567CDEF, 1)); - const __m128i vout1x0123456789AB4567CDEF = _mm_packus_epi16(_mm256_castsi256_si128(vacc1x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc1x0123456789AB4567CDEF, 1)); - const __m128i vout2x0123456789AB4567CDEF = _mm_packus_epi16(_mm256_castsi256_si128(vacc2x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc2x0123456789AB4567CDEF, 1)); - const __m128i vout3x0123456789AB4567CDEF = _mm_packus_epi16(_mm256_castsi256_si128(vacc3x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc3x0123456789AB4567CDEF, 1)); - const __m128i vout4x0123456789AB4567CDEF = _mm_packus_epi16(_mm256_castsi256_si128(vacc4x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc4x0123456789AB4567CDEF, 1)); - const __m128i vout5x0123456789AB4567CDEF = _mm_packus_epi16(_mm256_castsi256_si128(vacc5x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc5x0123456789AB4567CDEF, 1)); - const __m128i vout6x0123456789AB4567CDEF = _mm_packus_epi16(_mm256_castsi256_si128(vacc6x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc6x0123456789AB4567CDEF, 1)); - const __m128i vout7x0123456789AB4567CDEF = _mm_packus_epi16(_mm256_castsi256_si128(vacc7x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc7x0123456789AB4567CDEF, 1)); - - __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi8(vout0x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi8(vout1x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi8(vout2x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout3x0123456789ABCDEF = _mm_shuffle_epi8(vout3x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout4x0123456789ABCDEF = _mm_shuffle_epi8(vout4x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout5x0123456789ABCDEF = _mm_shuffle_epi8(vout5x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout6x0123456789ABCDEF = _mm_shuffle_epi8(vout6x0123456789AB4567CDEF, vshuffle_control_mask); - __m128i vout7x0123456789ABCDEF = _mm_shuffle_epi8(vout7x0123456789AB4567CDEF, vshuffle_control_mask); + __m256i vacc0x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); + __m256i vacc1x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); + __m256i vacc2x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc2x0123456789ABCDEF, 1)); + __m256i vacc3x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc3x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc3x0123456789ABCDEF, 1)); + __m256i vacc4x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc4x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc4x0123456789ABCDEF, 1)); + __m256i vacc5x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc5x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc5x0123456789ABCDEF, 1)); + __m256i vacc6x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc6x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc6x0123456789ABCDEF, 1)); + __m256i vacc7x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc7x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc7x0123456789ABCDEF, 1)); + + vacc0x012389AB4567CDEF = _mm256_adds_epi16(vacc0x012389AB4567CDEF, voutput_zero_point); + vacc1x012389AB4567CDEF = _mm256_adds_epi16(vacc1x012389AB4567CDEF, voutput_zero_point); + vacc2x012389AB4567CDEF = _mm256_adds_epi16(vacc2x012389AB4567CDEF, voutput_zero_point); + vacc3x012389AB4567CDEF = _mm256_adds_epi16(vacc3x012389AB4567CDEF, voutput_zero_point); + vacc4x012389AB4567CDEF = _mm256_adds_epi16(vacc4x012389AB4567CDEF, voutput_zero_point); + vacc5x012389AB4567CDEF = _mm256_adds_epi16(vacc5x012389AB4567CDEF, voutput_zero_point); + vacc6x012389AB4567CDEF = _mm256_adds_epi16(vacc6x012389AB4567CDEF, voutput_zero_point); + vacc7x012389AB4567CDEF = _mm256_adds_epi16(vacc7x012389AB4567CDEF, voutput_zero_point); + + const __m128i vout0x012389AB4567CDEF = _mm_packus_epi16(_mm256_castsi256_si128(vacc0x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x012389AB4567CDEF, 1)); + const __m128i vout1x012389AB4567CDEF = _mm_packus_epi16(_mm256_castsi256_si128(vacc1x012389AB4567CDEF), _mm256_extracti128_si256(vacc1x012389AB4567CDEF, 1)); + const __m128i vout2x012389AB4567CDEF = _mm_packus_epi16(_mm256_castsi256_si128(vacc2x012389AB4567CDEF), _mm256_extracti128_si256(vacc2x012389AB4567CDEF, 1)); + const __m128i vout3x012389AB4567CDEF = _mm_packus_epi16(_mm256_castsi256_si128(vacc3x012389AB4567CDEF), _mm256_extracti128_si256(vacc3x012389AB4567CDEF, 1)); + const __m128i vout4x012389AB4567CDEF = _mm_packus_epi16(_mm256_castsi256_si128(vacc4x012389AB4567CDEF), _mm256_extracti128_si256(vacc4x012389AB4567CDEF, 1)); + const __m128i vout5x012389AB4567CDEF = _mm_packus_epi16(_mm256_castsi256_si128(vacc5x012389AB4567CDEF), _mm256_extracti128_si256(vacc5x012389AB4567CDEF, 1)); + const __m128i vout6x012389AB4567CDEF = _mm_packus_epi16(_mm256_castsi256_si128(vacc6x012389AB4567CDEF), _mm256_extracti128_si256(vacc6x012389AB4567CDEF, 1)); + const __m128i vout7x012389AB4567CDEF = _mm_packus_epi16(_mm256_castsi256_si128(vacc7x012389AB4567CDEF), _mm256_extracti128_si256(vacc7x012389AB4567CDEF, 1)); + + __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi32(vout0x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi32(vout1x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi32(vout2x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout3x0123456789ABCDEF = _mm_shuffle_epi32(vout3x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout4x0123456789ABCDEF = _mm_shuffle_epi32(vout4x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout5x0123456789ABCDEF = _mm_shuffle_epi32(vout5x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout6x0123456789ABCDEF = _mm_shuffle_epi32(vout6x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout7x0123456789ABCDEF = _mm_shuffle_epi32(vout7x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); vout0x0123456789ABCDEF = _mm_max_epu8(vout0x0123456789ABCDEF, voutput_min); vout1x0123456789ABCDEF = _mm_max_epu8(vout1x0123456789ABCDEF, voutput_min); @@ -310,31 +309,29 @@ void xnn_qu8_gemm_minmax_fp32_ukernel_8x16c8__avx512skx( if (nc >= 16) { _mm_storeu_si128((__m128i*) c0, vout0x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c1, vout1x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c2, vout2x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c3, vout3x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c4, vout4x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c5, vout5x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c6, vout6x0123456789ABCDEF); - _mm_storeu_si128((__m128i*) c7, vout7x0123456789ABCDEF); - - a0 = (const uint8_t*) ((uintptr_t) a0 - kc); - a1 = (const uint8_t*) ((uintptr_t) a1 - kc); - a2 = (const uint8_t*) ((uintptr_t) a2 - kc); - a3 = (const uint8_t*) ((uintptr_t) a3 - kc); - a4 = (const uint8_t*) ((uintptr_t) a4 - kc); - a5 = (const uint8_t*) ((uintptr_t) a5 - kc); - a6 = (const uint8_t*) ((uintptr_t) a6 - kc); - a7 = (const uint8_t*) ((uintptr_t) a7 - kc); - c0 = (uint8_t*) ((uintptr_t) c0 + cn_stride); + a0 = (const uint8_t*) ((uintptr_t) a0 - kc); + _mm_storeu_si128((__m128i*) c1, vout1x0123456789ABCDEF); c1 = (uint8_t*) ((uintptr_t) c1 + cn_stride); + a1 = (const uint8_t*) ((uintptr_t) a1 - kc); + _mm_storeu_si128((__m128i*) c2, vout2x0123456789ABCDEF); c2 = (uint8_t*) ((uintptr_t) c2 + cn_stride); + a2 = (const uint8_t*) ((uintptr_t) a2 - kc); + _mm_storeu_si128((__m128i*) c3, vout3x0123456789ABCDEF); c3 = (uint8_t*) ((uintptr_t) c3 + cn_stride); + a3 = (const uint8_t*) ((uintptr_t) a3 - kc); + _mm_storeu_si128((__m128i*) c4, vout4x0123456789ABCDEF); c4 = (uint8_t*) ((uintptr_t) c4 + cn_stride); + a4 = (const uint8_t*) ((uintptr_t) a4 - kc); + _mm_storeu_si128((__m128i*) c5, vout5x0123456789ABCDEF); c5 = (uint8_t*) ((uintptr_t) c5 + cn_stride); + a5 = (const uint8_t*) ((uintptr_t) a5 - kc); + _mm_storeu_si128((__m128i*) c6, vout6x0123456789ABCDEF); c6 = (uint8_t*) ((uintptr_t) c6 + cn_stride); + a6 = (const uint8_t*) ((uintptr_t) a6 - kc); + _mm_storeu_si128((__m128i*) c7, vout7x0123456789ABCDEF); c7 = (uint8_t*) ((uintptr_t) c7 + cn_stride); + a7 = (const uint8_t*) ((uintptr_t) a7 - kc); nc -= 16; } else { diff --git a/src/qu8-igemm/gen/qu8-igemm-1x16c8-minmax-fp32-avx512skx-prfm.c b/src/qu8-igemm/gen/qu8-igemm-1x16c8-minmax-fp32-avx512skx-prfm.c index ca618bc56a7..1979d39f90d 100644 --- a/src/qu8-igemm/gen/qu8-igemm-1x16c8-minmax-fp32-avx512skx-prfm.c +++ b/src/qu8-igemm/gen/qu8-igemm-1x16c8-minmax-fp32-avx512skx-prfm.c @@ -106,13 +106,13 @@ void xnn_qu8_igemm_minmax_fp32_ukernel_1x16c8__avx512skx_prfm( vacc0x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled0x0123456789ABCDEF); - __m256i vacc0x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); + __m256i vacc0x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); - vacc0x0123456789AB4567CDEF = _mm256_adds_epi16(vacc0x0123456789AB4567CDEF, voutput_zero_point); + vacc0x012389AB4567CDEF = _mm256_adds_epi16(vacc0x012389AB4567CDEF, voutput_zero_point); - const __m128i vout0x0123456789AB4567CDEF = _mm_packus_epi16(_mm256_castsi256_si128(vacc0x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc0x0123456789AB4567CDEF, 1)); + const __m128i vout0x012389AB4567CDEF = _mm_packus_epi16(_mm256_castsi256_si128(vacc0x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x012389AB4567CDEF, 1)); - __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi32(vout0x0123456789AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi32(vout0x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); vout0x0123456789ABCDEF = _mm_max_epu8(vout0x0123456789ABCDEF, voutput_min); diff --git a/src/qu8-igemm/gen/qu8-igemm-1x16c8-minmax-fp32-avx512skx.c b/src/qu8-igemm/gen/qu8-igemm-1x16c8-minmax-fp32-avx512skx.c index 071e5b53222..3099c61ba7f 100644 --- a/src/qu8-igemm/gen/qu8-igemm-1x16c8-minmax-fp32-avx512skx.c +++ b/src/qu8-igemm/gen/qu8-igemm-1x16c8-minmax-fp32-avx512skx.c @@ -103,13 +103,13 @@ void xnn_qu8_igemm_minmax_fp32_ukernel_1x16c8__avx512skx( vacc0x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled0x0123456789ABCDEF); - __m256i vacc0x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); + __m256i vacc0x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); - vacc0x0123456789AB4567CDEF = _mm256_adds_epi16(vacc0x0123456789AB4567CDEF, voutput_zero_point); + vacc0x012389AB4567CDEF = _mm256_adds_epi16(vacc0x012389AB4567CDEF, voutput_zero_point); - const __m128i vout0x0123456789AB4567CDEF = _mm_packus_epi16(_mm256_castsi256_si128(vacc0x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc0x0123456789AB4567CDEF, 1)); + const __m128i vout0x012389AB4567CDEF = _mm_packus_epi16(_mm256_castsi256_si128(vacc0x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x012389AB4567CDEF, 1)); - __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi32(vout0x0123456789AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi32(vout0x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); vout0x0123456789ABCDEF = _mm_max_epu8(vout0x0123456789ABCDEF, voutput_min); diff --git a/src/qu8-igemm/gen/qu8-igemm-2x16c8-minmax-fp32-avx512skx-prfm.c b/src/qu8-igemm/gen/qu8-igemm-2x16c8-minmax-fp32-avx512skx-prfm.c index 47aea7a6390..0df47c54a43 100644 --- a/src/qu8-igemm/gen/qu8-igemm-2x16c8-minmax-fp32-avx512skx-prfm.c +++ b/src/qu8-igemm/gen/qu8-igemm-2x16c8-minmax-fp32-avx512skx-prfm.c @@ -132,17 +132,17 @@ void xnn_qu8_igemm_minmax_fp32_ukernel_2x16c8__avx512skx_prfm( vacc0x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled0x0123456789ABCDEF); vacc1x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled1x0123456789ABCDEF); - __m256i vacc0x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); - __m256i vacc1x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); + __m256i vacc0x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); + __m256i vacc1x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); - vacc0x0123456789AB4567CDEF = _mm256_adds_epi16(vacc0x0123456789AB4567CDEF, voutput_zero_point); - vacc1x0123456789AB4567CDEF = _mm256_adds_epi16(vacc1x0123456789AB4567CDEF, voutput_zero_point); + vacc0x012389AB4567CDEF = _mm256_adds_epi16(vacc0x012389AB4567CDEF, voutput_zero_point); + vacc1x012389AB4567CDEF = _mm256_adds_epi16(vacc1x012389AB4567CDEF, voutput_zero_point); - const __m128i vout0x0123456789AB4567CDEF = _mm_packus_epi16(_mm256_castsi256_si128(vacc0x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc0x0123456789AB4567CDEF, 1)); - const __m128i vout1x0123456789AB4567CDEF = _mm_packus_epi16(_mm256_castsi256_si128(vacc1x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc1x0123456789AB4567CDEF, 1)); + const __m128i vout0x012389AB4567CDEF = _mm_packus_epi16(_mm256_castsi256_si128(vacc0x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x012389AB4567CDEF, 1)); + const __m128i vout1x012389AB4567CDEF = _mm_packus_epi16(_mm256_castsi256_si128(vacc1x012389AB4567CDEF), _mm256_extracti128_si256(vacc1x012389AB4567CDEF, 1)); - __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi32(vout0x0123456789AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); - __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi32(vout1x0123456789AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi32(vout0x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi32(vout1x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); vout0x0123456789ABCDEF = _mm_max_epu8(vout0x0123456789ABCDEF, voutput_min); vout1x0123456789ABCDEF = _mm_max_epu8(vout1x0123456789ABCDEF, voutput_min); diff --git a/src/qu8-igemm/gen/qu8-igemm-2x16c8-minmax-fp32-avx512skx.c b/src/qu8-igemm/gen/qu8-igemm-2x16c8-minmax-fp32-avx512skx.c index 907580b62b1..f1e1aaf1651 100644 --- a/src/qu8-igemm/gen/qu8-igemm-2x16c8-minmax-fp32-avx512skx.c +++ b/src/qu8-igemm/gen/qu8-igemm-2x16c8-minmax-fp32-avx512skx.c @@ -129,17 +129,17 @@ void xnn_qu8_igemm_minmax_fp32_ukernel_2x16c8__avx512skx( vacc0x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled0x0123456789ABCDEF); vacc1x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled1x0123456789ABCDEF); - __m256i vacc0x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); - __m256i vacc1x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); + __m256i vacc0x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); + __m256i vacc1x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); - vacc0x0123456789AB4567CDEF = _mm256_adds_epi16(vacc0x0123456789AB4567CDEF, voutput_zero_point); - vacc1x0123456789AB4567CDEF = _mm256_adds_epi16(vacc1x0123456789AB4567CDEF, voutput_zero_point); + vacc0x012389AB4567CDEF = _mm256_adds_epi16(vacc0x012389AB4567CDEF, voutput_zero_point); + vacc1x012389AB4567CDEF = _mm256_adds_epi16(vacc1x012389AB4567CDEF, voutput_zero_point); - const __m128i vout0x0123456789AB4567CDEF = _mm_packus_epi16(_mm256_castsi256_si128(vacc0x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc0x0123456789AB4567CDEF, 1)); - const __m128i vout1x0123456789AB4567CDEF = _mm_packus_epi16(_mm256_castsi256_si128(vacc1x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc1x0123456789AB4567CDEF, 1)); + const __m128i vout0x012389AB4567CDEF = _mm_packus_epi16(_mm256_castsi256_si128(vacc0x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x012389AB4567CDEF, 1)); + const __m128i vout1x012389AB4567CDEF = _mm_packus_epi16(_mm256_castsi256_si128(vacc1x012389AB4567CDEF), _mm256_extracti128_si256(vacc1x012389AB4567CDEF, 1)); - __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi32(vout0x0123456789AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); - __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi32(vout1x0123456789AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi32(vout0x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi32(vout1x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); vout0x0123456789ABCDEF = _mm_max_epu8(vout0x0123456789ABCDEF, voutput_min); vout1x0123456789ABCDEF = _mm_max_epu8(vout1x0123456789ABCDEF, voutput_min); diff --git a/src/qu8-igemm/gen/qu8-igemm-3x16c8-minmax-fp32-avx512skx-prfm.c b/src/qu8-igemm/gen/qu8-igemm-3x16c8-minmax-fp32-avx512skx-prfm.c index 23f44db0c6e..918733e9627 100644 --- a/src/qu8-igemm/gen/qu8-igemm-3x16c8-minmax-fp32-avx512skx-prfm.c +++ b/src/qu8-igemm/gen/qu8-igemm-3x16c8-minmax-fp32-avx512skx-prfm.c @@ -158,21 +158,21 @@ void xnn_qu8_igemm_minmax_fp32_ukernel_3x16c8__avx512skx_prfm( vacc1x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled1x0123456789ABCDEF); vacc2x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled2x0123456789ABCDEF); - __m256i vacc0x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); - __m256i vacc1x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); - __m256i vacc2x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc2x0123456789ABCDEF, 1)); + __m256i vacc0x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); + __m256i vacc1x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); + __m256i vacc2x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc2x0123456789ABCDEF, 1)); - vacc0x0123456789AB4567CDEF = _mm256_adds_epi16(vacc0x0123456789AB4567CDEF, voutput_zero_point); - vacc1x0123456789AB4567CDEF = _mm256_adds_epi16(vacc1x0123456789AB4567CDEF, voutput_zero_point); - vacc2x0123456789AB4567CDEF = _mm256_adds_epi16(vacc2x0123456789AB4567CDEF, voutput_zero_point); + vacc0x012389AB4567CDEF = _mm256_adds_epi16(vacc0x012389AB4567CDEF, voutput_zero_point); + vacc1x012389AB4567CDEF = _mm256_adds_epi16(vacc1x012389AB4567CDEF, voutput_zero_point); + vacc2x012389AB4567CDEF = _mm256_adds_epi16(vacc2x012389AB4567CDEF, voutput_zero_point); - const __m128i vout0x0123456789AB4567CDEF = _mm_packus_epi16(_mm256_castsi256_si128(vacc0x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc0x0123456789AB4567CDEF, 1)); - const __m128i vout1x0123456789AB4567CDEF = _mm_packus_epi16(_mm256_castsi256_si128(vacc1x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc1x0123456789AB4567CDEF, 1)); - const __m128i vout2x0123456789AB4567CDEF = _mm_packus_epi16(_mm256_castsi256_si128(vacc2x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc2x0123456789AB4567CDEF, 1)); + const __m128i vout0x012389AB4567CDEF = _mm_packus_epi16(_mm256_castsi256_si128(vacc0x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x012389AB4567CDEF, 1)); + const __m128i vout1x012389AB4567CDEF = _mm_packus_epi16(_mm256_castsi256_si128(vacc1x012389AB4567CDEF), _mm256_extracti128_si256(vacc1x012389AB4567CDEF, 1)); + const __m128i vout2x012389AB4567CDEF = _mm_packus_epi16(_mm256_castsi256_si128(vacc2x012389AB4567CDEF), _mm256_extracti128_si256(vacc2x012389AB4567CDEF, 1)); - __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi32(vout0x0123456789AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); - __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi32(vout1x0123456789AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); - __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi32(vout2x0123456789AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi32(vout0x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi32(vout1x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi32(vout2x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); vout0x0123456789ABCDEF = _mm_max_epu8(vout0x0123456789ABCDEF, voutput_min); vout1x0123456789ABCDEF = _mm_max_epu8(vout1x0123456789ABCDEF, voutput_min); diff --git a/src/qu8-igemm/gen/qu8-igemm-3x16c8-minmax-fp32-avx512skx.c b/src/qu8-igemm/gen/qu8-igemm-3x16c8-minmax-fp32-avx512skx.c index 29a45d5dac3..291c91bfce9 100644 --- a/src/qu8-igemm/gen/qu8-igemm-3x16c8-minmax-fp32-avx512skx.c +++ b/src/qu8-igemm/gen/qu8-igemm-3x16c8-minmax-fp32-avx512skx.c @@ -155,21 +155,21 @@ void xnn_qu8_igemm_minmax_fp32_ukernel_3x16c8__avx512skx( vacc1x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled1x0123456789ABCDEF); vacc2x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled2x0123456789ABCDEF); - __m256i vacc0x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); - __m256i vacc1x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); - __m256i vacc2x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc2x0123456789ABCDEF, 1)); + __m256i vacc0x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); + __m256i vacc1x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); + __m256i vacc2x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc2x0123456789ABCDEF, 1)); - vacc0x0123456789AB4567CDEF = _mm256_adds_epi16(vacc0x0123456789AB4567CDEF, voutput_zero_point); - vacc1x0123456789AB4567CDEF = _mm256_adds_epi16(vacc1x0123456789AB4567CDEF, voutput_zero_point); - vacc2x0123456789AB4567CDEF = _mm256_adds_epi16(vacc2x0123456789AB4567CDEF, voutput_zero_point); + vacc0x012389AB4567CDEF = _mm256_adds_epi16(vacc0x012389AB4567CDEF, voutput_zero_point); + vacc1x012389AB4567CDEF = _mm256_adds_epi16(vacc1x012389AB4567CDEF, voutput_zero_point); + vacc2x012389AB4567CDEF = _mm256_adds_epi16(vacc2x012389AB4567CDEF, voutput_zero_point); - const __m128i vout0x0123456789AB4567CDEF = _mm_packus_epi16(_mm256_castsi256_si128(vacc0x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc0x0123456789AB4567CDEF, 1)); - const __m128i vout1x0123456789AB4567CDEF = _mm_packus_epi16(_mm256_castsi256_si128(vacc1x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc1x0123456789AB4567CDEF, 1)); - const __m128i vout2x0123456789AB4567CDEF = _mm_packus_epi16(_mm256_castsi256_si128(vacc2x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc2x0123456789AB4567CDEF, 1)); + const __m128i vout0x012389AB4567CDEF = _mm_packus_epi16(_mm256_castsi256_si128(vacc0x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x012389AB4567CDEF, 1)); + const __m128i vout1x012389AB4567CDEF = _mm_packus_epi16(_mm256_castsi256_si128(vacc1x012389AB4567CDEF), _mm256_extracti128_si256(vacc1x012389AB4567CDEF, 1)); + const __m128i vout2x012389AB4567CDEF = _mm_packus_epi16(_mm256_castsi256_si128(vacc2x012389AB4567CDEF), _mm256_extracti128_si256(vacc2x012389AB4567CDEF, 1)); - __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi32(vout0x0123456789AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); - __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi32(vout1x0123456789AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); - __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi32(vout2x0123456789AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi32(vout0x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi32(vout1x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi32(vout2x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); vout0x0123456789ABCDEF = _mm_max_epu8(vout0x0123456789ABCDEF, voutput_min); vout1x0123456789ABCDEF = _mm_max_epu8(vout1x0123456789ABCDEF, voutput_min); diff --git a/src/qu8-igemm/gen/qu8-igemm-4x16c8-minmax-fp32-avx512skx-prfm.c b/src/qu8-igemm/gen/qu8-igemm-4x16c8-minmax-fp32-avx512skx-prfm.c index 80b97360b20..e3351b54e15 100644 --- a/src/qu8-igemm/gen/qu8-igemm-4x16c8-minmax-fp32-avx512skx-prfm.c +++ b/src/qu8-igemm/gen/qu8-igemm-4x16c8-minmax-fp32-avx512skx-prfm.c @@ -184,25 +184,25 @@ void xnn_qu8_igemm_minmax_fp32_ukernel_4x16c8__avx512skx_prfm( vacc2x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled2x0123456789ABCDEF); vacc3x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled3x0123456789ABCDEF); - __m256i vacc0x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); - __m256i vacc1x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); - __m256i vacc2x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc2x0123456789ABCDEF, 1)); - __m256i vacc3x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc3x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc3x0123456789ABCDEF, 1)); - - vacc0x0123456789AB4567CDEF = _mm256_adds_epi16(vacc0x0123456789AB4567CDEF, voutput_zero_point); - vacc1x0123456789AB4567CDEF = _mm256_adds_epi16(vacc1x0123456789AB4567CDEF, voutput_zero_point); - vacc2x0123456789AB4567CDEF = _mm256_adds_epi16(vacc2x0123456789AB4567CDEF, voutput_zero_point); - vacc3x0123456789AB4567CDEF = _mm256_adds_epi16(vacc3x0123456789AB4567CDEF, voutput_zero_point); - - const __m128i vout0x0123456789AB4567CDEF = _mm_packus_epi16(_mm256_castsi256_si128(vacc0x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc0x0123456789AB4567CDEF, 1)); - const __m128i vout1x0123456789AB4567CDEF = _mm_packus_epi16(_mm256_castsi256_si128(vacc1x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc1x0123456789AB4567CDEF, 1)); - const __m128i vout2x0123456789AB4567CDEF = _mm_packus_epi16(_mm256_castsi256_si128(vacc2x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc2x0123456789AB4567CDEF, 1)); - const __m128i vout3x0123456789AB4567CDEF = _mm_packus_epi16(_mm256_castsi256_si128(vacc3x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc3x0123456789AB4567CDEF, 1)); - - __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi32(vout0x0123456789AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); - __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi32(vout1x0123456789AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); - __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi32(vout2x0123456789AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); - __m128i vout3x0123456789ABCDEF = _mm_shuffle_epi32(vout3x0123456789AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m256i vacc0x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); + __m256i vacc1x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); + __m256i vacc2x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc2x0123456789ABCDEF, 1)); + __m256i vacc3x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc3x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc3x0123456789ABCDEF, 1)); + + vacc0x012389AB4567CDEF = _mm256_adds_epi16(vacc0x012389AB4567CDEF, voutput_zero_point); + vacc1x012389AB4567CDEF = _mm256_adds_epi16(vacc1x012389AB4567CDEF, voutput_zero_point); + vacc2x012389AB4567CDEF = _mm256_adds_epi16(vacc2x012389AB4567CDEF, voutput_zero_point); + vacc3x012389AB4567CDEF = _mm256_adds_epi16(vacc3x012389AB4567CDEF, voutput_zero_point); + + const __m128i vout0x012389AB4567CDEF = _mm_packus_epi16(_mm256_castsi256_si128(vacc0x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x012389AB4567CDEF, 1)); + const __m128i vout1x012389AB4567CDEF = _mm_packus_epi16(_mm256_castsi256_si128(vacc1x012389AB4567CDEF), _mm256_extracti128_si256(vacc1x012389AB4567CDEF, 1)); + const __m128i vout2x012389AB4567CDEF = _mm_packus_epi16(_mm256_castsi256_si128(vacc2x012389AB4567CDEF), _mm256_extracti128_si256(vacc2x012389AB4567CDEF, 1)); + const __m128i vout3x012389AB4567CDEF = _mm_packus_epi16(_mm256_castsi256_si128(vacc3x012389AB4567CDEF), _mm256_extracti128_si256(vacc3x012389AB4567CDEF, 1)); + + __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi32(vout0x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi32(vout1x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi32(vout2x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout3x0123456789ABCDEF = _mm_shuffle_epi32(vout3x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); vout0x0123456789ABCDEF = _mm_max_epu8(vout0x0123456789ABCDEF, voutput_min); vout1x0123456789ABCDEF = _mm_max_epu8(vout1x0123456789ABCDEF, voutput_min); diff --git a/src/qu8-igemm/gen/qu8-igemm-4x16c8-minmax-fp32-avx512skx.c b/src/qu8-igemm/gen/qu8-igemm-4x16c8-minmax-fp32-avx512skx.c index 0fbe2c0c00c..c822d995cfe 100644 --- a/src/qu8-igemm/gen/qu8-igemm-4x16c8-minmax-fp32-avx512skx.c +++ b/src/qu8-igemm/gen/qu8-igemm-4x16c8-minmax-fp32-avx512skx.c @@ -181,25 +181,25 @@ void xnn_qu8_igemm_minmax_fp32_ukernel_4x16c8__avx512skx( vacc2x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled2x0123456789ABCDEF); vacc3x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled3x0123456789ABCDEF); - __m256i vacc0x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); - __m256i vacc1x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); - __m256i vacc2x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc2x0123456789ABCDEF, 1)); - __m256i vacc3x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc3x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc3x0123456789ABCDEF, 1)); - - vacc0x0123456789AB4567CDEF = _mm256_adds_epi16(vacc0x0123456789AB4567CDEF, voutput_zero_point); - vacc1x0123456789AB4567CDEF = _mm256_adds_epi16(vacc1x0123456789AB4567CDEF, voutput_zero_point); - vacc2x0123456789AB4567CDEF = _mm256_adds_epi16(vacc2x0123456789AB4567CDEF, voutput_zero_point); - vacc3x0123456789AB4567CDEF = _mm256_adds_epi16(vacc3x0123456789AB4567CDEF, voutput_zero_point); - - const __m128i vout0x0123456789AB4567CDEF = _mm_packus_epi16(_mm256_castsi256_si128(vacc0x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc0x0123456789AB4567CDEF, 1)); - const __m128i vout1x0123456789AB4567CDEF = _mm_packus_epi16(_mm256_castsi256_si128(vacc1x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc1x0123456789AB4567CDEF, 1)); - const __m128i vout2x0123456789AB4567CDEF = _mm_packus_epi16(_mm256_castsi256_si128(vacc2x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc2x0123456789AB4567CDEF, 1)); - const __m128i vout3x0123456789AB4567CDEF = _mm_packus_epi16(_mm256_castsi256_si128(vacc3x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc3x0123456789AB4567CDEF, 1)); - - __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi32(vout0x0123456789AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); - __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi32(vout1x0123456789AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); - __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi32(vout2x0123456789AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); - __m128i vout3x0123456789ABCDEF = _mm_shuffle_epi32(vout3x0123456789AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m256i vacc0x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); + __m256i vacc1x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); + __m256i vacc2x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc2x0123456789ABCDEF, 1)); + __m256i vacc3x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc3x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc3x0123456789ABCDEF, 1)); + + vacc0x012389AB4567CDEF = _mm256_adds_epi16(vacc0x012389AB4567CDEF, voutput_zero_point); + vacc1x012389AB4567CDEF = _mm256_adds_epi16(vacc1x012389AB4567CDEF, voutput_zero_point); + vacc2x012389AB4567CDEF = _mm256_adds_epi16(vacc2x012389AB4567CDEF, voutput_zero_point); + vacc3x012389AB4567CDEF = _mm256_adds_epi16(vacc3x012389AB4567CDEF, voutput_zero_point); + + const __m128i vout0x012389AB4567CDEF = _mm_packus_epi16(_mm256_castsi256_si128(vacc0x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x012389AB4567CDEF, 1)); + const __m128i vout1x012389AB4567CDEF = _mm_packus_epi16(_mm256_castsi256_si128(vacc1x012389AB4567CDEF), _mm256_extracti128_si256(vacc1x012389AB4567CDEF, 1)); + const __m128i vout2x012389AB4567CDEF = _mm_packus_epi16(_mm256_castsi256_si128(vacc2x012389AB4567CDEF), _mm256_extracti128_si256(vacc2x012389AB4567CDEF, 1)); + const __m128i vout3x012389AB4567CDEF = _mm_packus_epi16(_mm256_castsi256_si128(vacc3x012389AB4567CDEF), _mm256_extracti128_si256(vacc3x012389AB4567CDEF, 1)); + + __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi32(vout0x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi32(vout1x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi32(vout2x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout3x0123456789ABCDEF = _mm_shuffle_epi32(vout3x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); vout0x0123456789ABCDEF = _mm_max_epu8(vout0x0123456789ABCDEF, voutput_min); vout1x0123456789ABCDEF = _mm_max_epu8(vout1x0123456789ABCDEF, voutput_min); diff --git a/src/qu8-igemm/gen/qu8-igemm-5x16c8-minmax-fp32-avx512skx-prfm.c b/src/qu8-igemm/gen/qu8-igemm-5x16c8-minmax-fp32-avx512skx-prfm.c index 141b4da5465..4fc5cea9623 100644 --- a/src/qu8-igemm/gen/qu8-igemm-5x16c8-minmax-fp32-avx512skx-prfm.c +++ b/src/qu8-igemm/gen/qu8-igemm-5x16c8-minmax-fp32-avx512skx-prfm.c @@ -210,29 +210,29 @@ void xnn_qu8_igemm_minmax_fp32_ukernel_5x16c8__avx512skx_prfm( vacc3x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled3x0123456789ABCDEF); vacc4x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled4x0123456789ABCDEF); - __m256i vacc0x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); - __m256i vacc1x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); - __m256i vacc2x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc2x0123456789ABCDEF, 1)); - __m256i vacc3x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc3x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc3x0123456789ABCDEF, 1)); - __m256i vacc4x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc4x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc4x0123456789ABCDEF, 1)); + __m256i vacc0x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); + __m256i vacc1x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); + __m256i vacc2x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc2x0123456789ABCDEF, 1)); + __m256i vacc3x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc3x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc3x0123456789ABCDEF, 1)); + __m256i vacc4x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc4x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc4x0123456789ABCDEF, 1)); - vacc0x0123456789AB4567CDEF = _mm256_adds_epi16(vacc0x0123456789AB4567CDEF, voutput_zero_point); - vacc1x0123456789AB4567CDEF = _mm256_adds_epi16(vacc1x0123456789AB4567CDEF, voutput_zero_point); - vacc2x0123456789AB4567CDEF = _mm256_adds_epi16(vacc2x0123456789AB4567CDEF, voutput_zero_point); - vacc3x0123456789AB4567CDEF = _mm256_adds_epi16(vacc3x0123456789AB4567CDEF, voutput_zero_point); - vacc4x0123456789AB4567CDEF = _mm256_adds_epi16(vacc4x0123456789AB4567CDEF, voutput_zero_point); + vacc0x012389AB4567CDEF = _mm256_adds_epi16(vacc0x012389AB4567CDEF, voutput_zero_point); + vacc1x012389AB4567CDEF = _mm256_adds_epi16(vacc1x012389AB4567CDEF, voutput_zero_point); + vacc2x012389AB4567CDEF = _mm256_adds_epi16(vacc2x012389AB4567CDEF, voutput_zero_point); + vacc3x012389AB4567CDEF = _mm256_adds_epi16(vacc3x012389AB4567CDEF, voutput_zero_point); + vacc4x012389AB4567CDEF = _mm256_adds_epi16(vacc4x012389AB4567CDEF, voutput_zero_point); - const __m128i vout0x0123456789AB4567CDEF = _mm_packus_epi16(_mm256_castsi256_si128(vacc0x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc0x0123456789AB4567CDEF, 1)); - const __m128i vout1x0123456789AB4567CDEF = _mm_packus_epi16(_mm256_castsi256_si128(vacc1x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc1x0123456789AB4567CDEF, 1)); - const __m128i vout2x0123456789AB4567CDEF = _mm_packus_epi16(_mm256_castsi256_si128(vacc2x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc2x0123456789AB4567CDEF, 1)); - const __m128i vout3x0123456789AB4567CDEF = _mm_packus_epi16(_mm256_castsi256_si128(vacc3x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc3x0123456789AB4567CDEF, 1)); - const __m128i vout4x0123456789AB4567CDEF = _mm_packus_epi16(_mm256_castsi256_si128(vacc4x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc4x0123456789AB4567CDEF, 1)); + const __m128i vout0x012389AB4567CDEF = _mm_packus_epi16(_mm256_castsi256_si128(vacc0x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x012389AB4567CDEF, 1)); + const __m128i vout1x012389AB4567CDEF = _mm_packus_epi16(_mm256_castsi256_si128(vacc1x012389AB4567CDEF), _mm256_extracti128_si256(vacc1x012389AB4567CDEF, 1)); + const __m128i vout2x012389AB4567CDEF = _mm_packus_epi16(_mm256_castsi256_si128(vacc2x012389AB4567CDEF), _mm256_extracti128_si256(vacc2x012389AB4567CDEF, 1)); + const __m128i vout3x012389AB4567CDEF = _mm_packus_epi16(_mm256_castsi256_si128(vacc3x012389AB4567CDEF), _mm256_extracti128_si256(vacc3x012389AB4567CDEF, 1)); + const __m128i vout4x012389AB4567CDEF = _mm_packus_epi16(_mm256_castsi256_si128(vacc4x012389AB4567CDEF), _mm256_extracti128_si256(vacc4x012389AB4567CDEF, 1)); - __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi32(vout0x0123456789AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); - __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi32(vout1x0123456789AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); - __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi32(vout2x0123456789AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); - __m128i vout3x0123456789ABCDEF = _mm_shuffle_epi32(vout3x0123456789AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); - __m128i vout4x0123456789ABCDEF = _mm_shuffle_epi32(vout4x0123456789AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi32(vout0x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi32(vout1x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi32(vout2x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout3x0123456789ABCDEF = _mm_shuffle_epi32(vout3x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout4x0123456789ABCDEF = _mm_shuffle_epi32(vout4x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); vout0x0123456789ABCDEF = _mm_max_epu8(vout0x0123456789ABCDEF, voutput_min); vout1x0123456789ABCDEF = _mm_max_epu8(vout1x0123456789ABCDEF, voutput_min); diff --git a/src/qu8-igemm/gen/qu8-igemm-5x16c8-minmax-fp32-avx512skx.c b/src/qu8-igemm/gen/qu8-igemm-5x16c8-minmax-fp32-avx512skx.c index 67f0a8c168d..05006e81ec5 100644 --- a/src/qu8-igemm/gen/qu8-igemm-5x16c8-minmax-fp32-avx512skx.c +++ b/src/qu8-igemm/gen/qu8-igemm-5x16c8-minmax-fp32-avx512skx.c @@ -207,29 +207,29 @@ void xnn_qu8_igemm_minmax_fp32_ukernel_5x16c8__avx512skx( vacc3x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled3x0123456789ABCDEF); vacc4x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled4x0123456789ABCDEF); - __m256i vacc0x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); - __m256i vacc1x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); - __m256i vacc2x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc2x0123456789ABCDEF, 1)); - __m256i vacc3x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc3x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc3x0123456789ABCDEF, 1)); - __m256i vacc4x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc4x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc4x0123456789ABCDEF, 1)); + __m256i vacc0x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); + __m256i vacc1x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); + __m256i vacc2x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc2x0123456789ABCDEF, 1)); + __m256i vacc3x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc3x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc3x0123456789ABCDEF, 1)); + __m256i vacc4x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc4x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc4x0123456789ABCDEF, 1)); - vacc0x0123456789AB4567CDEF = _mm256_adds_epi16(vacc0x0123456789AB4567CDEF, voutput_zero_point); - vacc1x0123456789AB4567CDEF = _mm256_adds_epi16(vacc1x0123456789AB4567CDEF, voutput_zero_point); - vacc2x0123456789AB4567CDEF = _mm256_adds_epi16(vacc2x0123456789AB4567CDEF, voutput_zero_point); - vacc3x0123456789AB4567CDEF = _mm256_adds_epi16(vacc3x0123456789AB4567CDEF, voutput_zero_point); - vacc4x0123456789AB4567CDEF = _mm256_adds_epi16(vacc4x0123456789AB4567CDEF, voutput_zero_point); + vacc0x012389AB4567CDEF = _mm256_adds_epi16(vacc0x012389AB4567CDEF, voutput_zero_point); + vacc1x012389AB4567CDEF = _mm256_adds_epi16(vacc1x012389AB4567CDEF, voutput_zero_point); + vacc2x012389AB4567CDEF = _mm256_adds_epi16(vacc2x012389AB4567CDEF, voutput_zero_point); + vacc3x012389AB4567CDEF = _mm256_adds_epi16(vacc3x012389AB4567CDEF, voutput_zero_point); + vacc4x012389AB4567CDEF = _mm256_adds_epi16(vacc4x012389AB4567CDEF, voutput_zero_point); - const __m128i vout0x0123456789AB4567CDEF = _mm_packus_epi16(_mm256_castsi256_si128(vacc0x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc0x0123456789AB4567CDEF, 1)); - const __m128i vout1x0123456789AB4567CDEF = _mm_packus_epi16(_mm256_castsi256_si128(vacc1x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc1x0123456789AB4567CDEF, 1)); - const __m128i vout2x0123456789AB4567CDEF = _mm_packus_epi16(_mm256_castsi256_si128(vacc2x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc2x0123456789AB4567CDEF, 1)); - const __m128i vout3x0123456789AB4567CDEF = _mm_packus_epi16(_mm256_castsi256_si128(vacc3x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc3x0123456789AB4567CDEF, 1)); - const __m128i vout4x0123456789AB4567CDEF = _mm_packus_epi16(_mm256_castsi256_si128(vacc4x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc4x0123456789AB4567CDEF, 1)); + const __m128i vout0x012389AB4567CDEF = _mm_packus_epi16(_mm256_castsi256_si128(vacc0x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x012389AB4567CDEF, 1)); + const __m128i vout1x012389AB4567CDEF = _mm_packus_epi16(_mm256_castsi256_si128(vacc1x012389AB4567CDEF), _mm256_extracti128_si256(vacc1x012389AB4567CDEF, 1)); + const __m128i vout2x012389AB4567CDEF = _mm_packus_epi16(_mm256_castsi256_si128(vacc2x012389AB4567CDEF), _mm256_extracti128_si256(vacc2x012389AB4567CDEF, 1)); + const __m128i vout3x012389AB4567CDEF = _mm_packus_epi16(_mm256_castsi256_si128(vacc3x012389AB4567CDEF), _mm256_extracti128_si256(vacc3x012389AB4567CDEF, 1)); + const __m128i vout4x012389AB4567CDEF = _mm_packus_epi16(_mm256_castsi256_si128(vacc4x012389AB4567CDEF), _mm256_extracti128_si256(vacc4x012389AB4567CDEF, 1)); - __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi32(vout0x0123456789AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); - __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi32(vout1x0123456789AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); - __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi32(vout2x0123456789AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); - __m128i vout3x0123456789ABCDEF = _mm_shuffle_epi32(vout3x0123456789AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); - __m128i vout4x0123456789ABCDEF = _mm_shuffle_epi32(vout4x0123456789AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi32(vout0x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi32(vout1x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi32(vout2x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout3x0123456789ABCDEF = _mm_shuffle_epi32(vout3x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout4x0123456789ABCDEF = _mm_shuffle_epi32(vout4x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); vout0x0123456789ABCDEF = _mm_max_epu8(vout0x0123456789ABCDEF, voutput_min); vout1x0123456789ABCDEF = _mm_max_epu8(vout1x0123456789ABCDEF, voutput_min); diff --git a/src/qu8-igemm/gen/qu8-igemm-6x16c8-minmax-fp32-avx512skx-prfm.c b/src/qu8-igemm/gen/qu8-igemm-6x16c8-minmax-fp32-avx512skx-prfm.c index 840684ad7a9..3a5bbe60fa8 100644 --- a/src/qu8-igemm/gen/qu8-igemm-6x16c8-minmax-fp32-avx512skx-prfm.c +++ b/src/qu8-igemm/gen/qu8-igemm-6x16c8-minmax-fp32-avx512skx-prfm.c @@ -236,33 +236,33 @@ void xnn_qu8_igemm_minmax_fp32_ukernel_6x16c8__avx512skx_prfm( vacc4x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled4x0123456789ABCDEF); vacc5x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled5x0123456789ABCDEF); - __m256i vacc0x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); - __m256i vacc1x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); - __m256i vacc2x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc2x0123456789ABCDEF, 1)); - __m256i vacc3x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc3x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc3x0123456789ABCDEF, 1)); - __m256i vacc4x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc4x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc4x0123456789ABCDEF, 1)); - __m256i vacc5x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc5x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc5x0123456789ABCDEF, 1)); + __m256i vacc0x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); + __m256i vacc1x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); + __m256i vacc2x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc2x0123456789ABCDEF, 1)); + __m256i vacc3x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc3x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc3x0123456789ABCDEF, 1)); + __m256i vacc4x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc4x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc4x0123456789ABCDEF, 1)); + __m256i vacc5x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc5x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc5x0123456789ABCDEF, 1)); - vacc0x0123456789AB4567CDEF = _mm256_adds_epi16(vacc0x0123456789AB4567CDEF, voutput_zero_point); - vacc1x0123456789AB4567CDEF = _mm256_adds_epi16(vacc1x0123456789AB4567CDEF, voutput_zero_point); - vacc2x0123456789AB4567CDEF = _mm256_adds_epi16(vacc2x0123456789AB4567CDEF, voutput_zero_point); - vacc3x0123456789AB4567CDEF = _mm256_adds_epi16(vacc3x0123456789AB4567CDEF, voutput_zero_point); - vacc4x0123456789AB4567CDEF = _mm256_adds_epi16(vacc4x0123456789AB4567CDEF, voutput_zero_point); - vacc5x0123456789AB4567CDEF = _mm256_adds_epi16(vacc5x0123456789AB4567CDEF, voutput_zero_point); + vacc0x012389AB4567CDEF = _mm256_adds_epi16(vacc0x012389AB4567CDEF, voutput_zero_point); + vacc1x012389AB4567CDEF = _mm256_adds_epi16(vacc1x012389AB4567CDEF, voutput_zero_point); + vacc2x012389AB4567CDEF = _mm256_adds_epi16(vacc2x012389AB4567CDEF, voutput_zero_point); + vacc3x012389AB4567CDEF = _mm256_adds_epi16(vacc3x012389AB4567CDEF, voutput_zero_point); + vacc4x012389AB4567CDEF = _mm256_adds_epi16(vacc4x012389AB4567CDEF, voutput_zero_point); + vacc5x012389AB4567CDEF = _mm256_adds_epi16(vacc5x012389AB4567CDEF, voutput_zero_point); - const __m128i vout0x0123456789AB4567CDEF = _mm_packus_epi16(_mm256_castsi256_si128(vacc0x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc0x0123456789AB4567CDEF, 1)); - const __m128i vout1x0123456789AB4567CDEF = _mm_packus_epi16(_mm256_castsi256_si128(vacc1x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc1x0123456789AB4567CDEF, 1)); - const __m128i vout2x0123456789AB4567CDEF = _mm_packus_epi16(_mm256_castsi256_si128(vacc2x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc2x0123456789AB4567CDEF, 1)); - const __m128i vout3x0123456789AB4567CDEF = _mm_packus_epi16(_mm256_castsi256_si128(vacc3x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc3x0123456789AB4567CDEF, 1)); - const __m128i vout4x0123456789AB4567CDEF = _mm_packus_epi16(_mm256_castsi256_si128(vacc4x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc4x0123456789AB4567CDEF, 1)); - const __m128i vout5x0123456789AB4567CDEF = _mm_packus_epi16(_mm256_castsi256_si128(vacc5x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc5x0123456789AB4567CDEF, 1)); + const __m128i vout0x012389AB4567CDEF = _mm_packus_epi16(_mm256_castsi256_si128(vacc0x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x012389AB4567CDEF, 1)); + const __m128i vout1x012389AB4567CDEF = _mm_packus_epi16(_mm256_castsi256_si128(vacc1x012389AB4567CDEF), _mm256_extracti128_si256(vacc1x012389AB4567CDEF, 1)); + const __m128i vout2x012389AB4567CDEF = _mm_packus_epi16(_mm256_castsi256_si128(vacc2x012389AB4567CDEF), _mm256_extracti128_si256(vacc2x012389AB4567CDEF, 1)); + const __m128i vout3x012389AB4567CDEF = _mm_packus_epi16(_mm256_castsi256_si128(vacc3x012389AB4567CDEF), _mm256_extracti128_si256(vacc3x012389AB4567CDEF, 1)); + const __m128i vout4x012389AB4567CDEF = _mm_packus_epi16(_mm256_castsi256_si128(vacc4x012389AB4567CDEF), _mm256_extracti128_si256(vacc4x012389AB4567CDEF, 1)); + const __m128i vout5x012389AB4567CDEF = _mm_packus_epi16(_mm256_castsi256_si128(vacc5x012389AB4567CDEF), _mm256_extracti128_si256(vacc5x012389AB4567CDEF, 1)); - __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi32(vout0x0123456789AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); - __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi32(vout1x0123456789AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); - __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi32(vout2x0123456789AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); - __m128i vout3x0123456789ABCDEF = _mm_shuffle_epi32(vout3x0123456789AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); - __m128i vout4x0123456789ABCDEF = _mm_shuffle_epi32(vout4x0123456789AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); - __m128i vout5x0123456789ABCDEF = _mm_shuffle_epi32(vout5x0123456789AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi32(vout0x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi32(vout1x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi32(vout2x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout3x0123456789ABCDEF = _mm_shuffle_epi32(vout3x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout4x0123456789ABCDEF = _mm_shuffle_epi32(vout4x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout5x0123456789ABCDEF = _mm_shuffle_epi32(vout5x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); vout0x0123456789ABCDEF = _mm_max_epu8(vout0x0123456789ABCDEF, voutput_min); vout1x0123456789ABCDEF = _mm_max_epu8(vout1x0123456789ABCDEF, voutput_min); diff --git a/src/qu8-igemm/gen/qu8-igemm-6x16c8-minmax-fp32-avx512skx.c b/src/qu8-igemm/gen/qu8-igemm-6x16c8-minmax-fp32-avx512skx.c index 6227fd7e467..44670347fd1 100644 --- a/src/qu8-igemm/gen/qu8-igemm-6x16c8-minmax-fp32-avx512skx.c +++ b/src/qu8-igemm/gen/qu8-igemm-6x16c8-minmax-fp32-avx512skx.c @@ -233,33 +233,33 @@ void xnn_qu8_igemm_minmax_fp32_ukernel_6x16c8__avx512skx( vacc4x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled4x0123456789ABCDEF); vacc5x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled5x0123456789ABCDEF); - __m256i vacc0x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); - __m256i vacc1x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); - __m256i vacc2x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc2x0123456789ABCDEF, 1)); - __m256i vacc3x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc3x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc3x0123456789ABCDEF, 1)); - __m256i vacc4x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc4x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc4x0123456789ABCDEF, 1)); - __m256i vacc5x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc5x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc5x0123456789ABCDEF, 1)); + __m256i vacc0x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); + __m256i vacc1x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); + __m256i vacc2x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc2x0123456789ABCDEF, 1)); + __m256i vacc3x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc3x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc3x0123456789ABCDEF, 1)); + __m256i vacc4x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc4x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc4x0123456789ABCDEF, 1)); + __m256i vacc5x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc5x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc5x0123456789ABCDEF, 1)); - vacc0x0123456789AB4567CDEF = _mm256_adds_epi16(vacc0x0123456789AB4567CDEF, voutput_zero_point); - vacc1x0123456789AB4567CDEF = _mm256_adds_epi16(vacc1x0123456789AB4567CDEF, voutput_zero_point); - vacc2x0123456789AB4567CDEF = _mm256_adds_epi16(vacc2x0123456789AB4567CDEF, voutput_zero_point); - vacc3x0123456789AB4567CDEF = _mm256_adds_epi16(vacc3x0123456789AB4567CDEF, voutput_zero_point); - vacc4x0123456789AB4567CDEF = _mm256_adds_epi16(vacc4x0123456789AB4567CDEF, voutput_zero_point); - vacc5x0123456789AB4567CDEF = _mm256_adds_epi16(vacc5x0123456789AB4567CDEF, voutput_zero_point); + vacc0x012389AB4567CDEF = _mm256_adds_epi16(vacc0x012389AB4567CDEF, voutput_zero_point); + vacc1x012389AB4567CDEF = _mm256_adds_epi16(vacc1x012389AB4567CDEF, voutput_zero_point); + vacc2x012389AB4567CDEF = _mm256_adds_epi16(vacc2x012389AB4567CDEF, voutput_zero_point); + vacc3x012389AB4567CDEF = _mm256_adds_epi16(vacc3x012389AB4567CDEF, voutput_zero_point); + vacc4x012389AB4567CDEF = _mm256_adds_epi16(vacc4x012389AB4567CDEF, voutput_zero_point); + vacc5x012389AB4567CDEF = _mm256_adds_epi16(vacc5x012389AB4567CDEF, voutput_zero_point); - const __m128i vout0x0123456789AB4567CDEF = _mm_packus_epi16(_mm256_castsi256_si128(vacc0x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc0x0123456789AB4567CDEF, 1)); - const __m128i vout1x0123456789AB4567CDEF = _mm_packus_epi16(_mm256_castsi256_si128(vacc1x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc1x0123456789AB4567CDEF, 1)); - const __m128i vout2x0123456789AB4567CDEF = _mm_packus_epi16(_mm256_castsi256_si128(vacc2x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc2x0123456789AB4567CDEF, 1)); - const __m128i vout3x0123456789AB4567CDEF = _mm_packus_epi16(_mm256_castsi256_si128(vacc3x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc3x0123456789AB4567CDEF, 1)); - const __m128i vout4x0123456789AB4567CDEF = _mm_packus_epi16(_mm256_castsi256_si128(vacc4x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc4x0123456789AB4567CDEF, 1)); - const __m128i vout5x0123456789AB4567CDEF = _mm_packus_epi16(_mm256_castsi256_si128(vacc5x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc5x0123456789AB4567CDEF, 1)); + const __m128i vout0x012389AB4567CDEF = _mm_packus_epi16(_mm256_castsi256_si128(vacc0x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x012389AB4567CDEF, 1)); + const __m128i vout1x012389AB4567CDEF = _mm_packus_epi16(_mm256_castsi256_si128(vacc1x012389AB4567CDEF), _mm256_extracti128_si256(vacc1x012389AB4567CDEF, 1)); + const __m128i vout2x012389AB4567CDEF = _mm_packus_epi16(_mm256_castsi256_si128(vacc2x012389AB4567CDEF), _mm256_extracti128_si256(vacc2x012389AB4567CDEF, 1)); + const __m128i vout3x012389AB4567CDEF = _mm_packus_epi16(_mm256_castsi256_si128(vacc3x012389AB4567CDEF), _mm256_extracti128_si256(vacc3x012389AB4567CDEF, 1)); + const __m128i vout4x012389AB4567CDEF = _mm_packus_epi16(_mm256_castsi256_si128(vacc4x012389AB4567CDEF), _mm256_extracti128_si256(vacc4x012389AB4567CDEF, 1)); + const __m128i vout5x012389AB4567CDEF = _mm_packus_epi16(_mm256_castsi256_si128(vacc5x012389AB4567CDEF), _mm256_extracti128_si256(vacc5x012389AB4567CDEF, 1)); - __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi32(vout0x0123456789AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); - __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi32(vout1x0123456789AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); - __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi32(vout2x0123456789AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); - __m128i vout3x0123456789ABCDEF = _mm_shuffle_epi32(vout3x0123456789AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); - __m128i vout4x0123456789ABCDEF = _mm_shuffle_epi32(vout4x0123456789AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); - __m128i vout5x0123456789ABCDEF = _mm_shuffle_epi32(vout5x0123456789AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi32(vout0x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi32(vout1x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi32(vout2x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout3x0123456789ABCDEF = _mm_shuffle_epi32(vout3x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout4x0123456789ABCDEF = _mm_shuffle_epi32(vout4x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout5x0123456789ABCDEF = _mm_shuffle_epi32(vout5x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); vout0x0123456789ABCDEF = _mm_max_epu8(vout0x0123456789ABCDEF, voutput_min); vout1x0123456789ABCDEF = _mm_max_epu8(vout1x0123456789ABCDEF, voutput_min); diff --git a/src/qu8-igemm/gen/qu8-igemm-7x16c8-minmax-fp32-avx512skx-prfm.c b/src/qu8-igemm/gen/qu8-igemm-7x16c8-minmax-fp32-avx512skx-prfm.c index 4c6449aa522..5a9c30d63f0 100644 --- a/src/qu8-igemm/gen/qu8-igemm-7x16c8-minmax-fp32-avx512skx-prfm.c +++ b/src/qu8-igemm/gen/qu8-igemm-7x16c8-minmax-fp32-avx512skx-prfm.c @@ -262,37 +262,37 @@ void xnn_qu8_igemm_minmax_fp32_ukernel_7x16c8__avx512skx_prfm( vacc5x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled5x0123456789ABCDEF); vacc6x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled6x0123456789ABCDEF); - __m256i vacc0x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); - __m256i vacc1x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); - __m256i vacc2x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc2x0123456789ABCDEF, 1)); - __m256i vacc3x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc3x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc3x0123456789ABCDEF, 1)); - __m256i vacc4x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc4x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc4x0123456789ABCDEF, 1)); - __m256i vacc5x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc5x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc5x0123456789ABCDEF, 1)); - __m256i vacc6x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc6x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc6x0123456789ABCDEF, 1)); + __m256i vacc0x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); + __m256i vacc1x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); + __m256i vacc2x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc2x0123456789ABCDEF, 1)); + __m256i vacc3x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc3x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc3x0123456789ABCDEF, 1)); + __m256i vacc4x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc4x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc4x0123456789ABCDEF, 1)); + __m256i vacc5x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc5x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc5x0123456789ABCDEF, 1)); + __m256i vacc6x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc6x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc6x0123456789ABCDEF, 1)); - vacc0x0123456789AB4567CDEF = _mm256_adds_epi16(vacc0x0123456789AB4567CDEF, voutput_zero_point); - vacc1x0123456789AB4567CDEF = _mm256_adds_epi16(vacc1x0123456789AB4567CDEF, voutput_zero_point); - vacc2x0123456789AB4567CDEF = _mm256_adds_epi16(vacc2x0123456789AB4567CDEF, voutput_zero_point); - vacc3x0123456789AB4567CDEF = _mm256_adds_epi16(vacc3x0123456789AB4567CDEF, voutput_zero_point); - vacc4x0123456789AB4567CDEF = _mm256_adds_epi16(vacc4x0123456789AB4567CDEF, voutput_zero_point); - vacc5x0123456789AB4567CDEF = _mm256_adds_epi16(vacc5x0123456789AB4567CDEF, voutput_zero_point); - vacc6x0123456789AB4567CDEF = _mm256_adds_epi16(vacc6x0123456789AB4567CDEF, voutput_zero_point); + vacc0x012389AB4567CDEF = _mm256_adds_epi16(vacc0x012389AB4567CDEF, voutput_zero_point); + vacc1x012389AB4567CDEF = _mm256_adds_epi16(vacc1x012389AB4567CDEF, voutput_zero_point); + vacc2x012389AB4567CDEF = _mm256_adds_epi16(vacc2x012389AB4567CDEF, voutput_zero_point); + vacc3x012389AB4567CDEF = _mm256_adds_epi16(vacc3x012389AB4567CDEF, voutput_zero_point); + vacc4x012389AB4567CDEF = _mm256_adds_epi16(vacc4x012389AB4567CDEF, voutput_zero_point); + vacc5x012389AB4567CDEF = _mm256_adds_epi16(vacc5x012389AB4567CDEF, voutput_zero_point); + vacc6x012389AB4567CDEF = _mm256_adds_epi16(vacc6x012389AB4567CDEF, voutput_zero_point); - const __m128i vout0x0123456789AB4567CDEF = _mm_packus_epi16(_mm256_castsi256_si128(vacc0x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc0x0123456789AB4567CDEF, 1)); - const __m128i vout1x0123456789AB4567CDEF = _mm_packus_epi16(_mm256_castsi256_si128(vacc1x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc1x0123456789AB4567CDEF, 1)); - const __m128i vout2x0123456789AB4567CDEF = _mm_packus_epi16(_mm256_castsi256_si128(vacc2x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc2x0123456789AB4567CDEF, 1)); - const __m128i vout3x0123456789AB4567CDEF = _mm_packus_epi16(_mm256_castsi256_si128(vacc3x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc3x0123456789AB4567CDEF, 1)); - const __m128i vout4x0123456789AB4567CDEF = _mm_packus_epi16(_mm256_castsi256_si128(vacc4x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc4x0123456789AB4567CDEF, 1)); - const __m128i vout5x0123456789AB4567CDEF = _mm_packus_epi16(_mm256_castsi256_si128(vacc5x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc5x0123456789AB4567CDEF, 1)); - const __m128i vout6x0123456789AB4567CDEF = _mm_packus_epi16(_mm256_castsi256_si128(vacc6x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc6x0123456789AB4567CDEF, 1)); + const __m128i vout0x012389AB4567CDEF = _mm_packus_epi16(_mm256_castsi256_si128(vacc0x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x012389AB4567CDEF, 1)); + const __m128i vout1x012389AB4567CDEF = _mm_packus_epi16(_mm256_castsi256_si128(vacc1x012389AB4567CDEF), _mm256_extracti128_si256(vacc1x012389AB4567CDEF, 1)); + const __m128i vout2x012389AB4567CDEF = _mm_packus_epi16(_mm256_castsi256_si128(vacc2x012389AB4567CDEF), _mm256_extracti128_si256(vacc2x012389AB4567CDEF, 1)); + const __m128i vout3x012389AB4567CDEF = _mm_packus_epi16(_mm256_castsi256_si128(vacc3x012389AB4567CDEF), _mm256_extracti128_si256(vacc3x012389AB4567CDEF, 1)); + const __m128i vout4x012389AB4567CDEF = _mm_packus_epi16(_mm256_castsi256_si128(vacc4x012389AB4567CDEF), _mm256_extracti128_si256(vacc4x012389AB4567CDEF, 1)); + const __m128i vout5x012389AB4567CDEF = _mm_packus_epi16(_mm256_castsi256_si128(vacc5x012389AB4567CDEF), _mm256_extracti128_si256(vacc5x012389AB4567CDEF, 1)); + const __m128i vout6x012389AB4567CDEF = _mm_packus_epi16(_mm256_castsi256_si128(vacc6x012389AB4567CDEF), _mm256_extracti128_si256(vacc6x012389AB4567CDEF, 1)); - __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi32(vout0x0123456789AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); - __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi32(vout1x0123456789AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); - __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi32(vout2x0123456789AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); - __m128i vout3x0123456789ABCDEF = _mm_shuffle_epi32(vout3x0123456789AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); - __m128i vout4x0123456789ABCDEF = _mm_shuffle_epi32(vout4x0123456789AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); - __m128i vout5x0123456789ABCDEF = _mm_shuffle_epi32(vout5x0123456789AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); - __m128i vout6x0123456789ABCDEF = _mm_shuffle_epi32(vout6x0123456789AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi32(vout0x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi32(vout1x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi32(vout2x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout3x0123456789ABCDEF = _mm_shuffle_epi32(vout3x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout4x0123456789ABCDEF = _mm_shuffle_epi32(vout4x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout5x0123456789ABCDEF = _mm_shuffle_epi32(vout5x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout6x0123456789ABCDEF = _mm_shuffle_epi32(vout6x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); vout0x0123456789ABCDEF = _mm_max_epu8(vout0x0123456789ABCDEF, voutput_min); vout1x0123456789ABCDEF = _mm_max_epu8(vout1x0123456789ABCDEF, voutput_min); diff --git a/src/qu8-igemm/gen/qu8-igemm-7x16c8-minmax-fp32-avx512skx.c b/src/qu8-igemm/gen/qu8-igemm-7x16c8-minmax-fp32-avx512skx.c index 330607f0a7c..3940c395d2a 100644 --- a/src/qu8-igemm/gen/qu8-igemm-7x16c8-minmax-fp32-avx512skx.c +++ b/src/qu8-igemm/gen/qu8-igemm-7x16c8-minmax-fp32-avx512skx.c @@ -259,37 +259,37 @@ void xnn_qu8_igemm_minmax_fp32_ukernel_7x16c8__avx512skx( vacc5x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled5x0123456789ABCDEF); vacc6x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled6x0123456789ABCDEF); - __m256i vacc0x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); - __m256i vacc1x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); - __m256i vacc2x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc2x0123456789ABCDEF, 1)); - __m256i vacc3x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc3x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc3x0123456789ABCDEF, 1)); - __m256i vacc4x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc4x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc4x0123456789ABCDEF, 1)); - __m256i vacc5x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc5x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc5x0123456789ABCDEF, 1)); - __m256i vacc6x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc6x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc6x0123456789ABCDEF, 1)); + __m256i vacc0x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); + __m256i vacc1x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); + __m256i vacc2x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc2x0123456789ABCDEF, 1)); + __m256i vacc3x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc3x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc3x0123456789ABCDEF, 1)); + __m256i vacc4x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc4x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc4x0123456789ABCDEF, 1)); + __m256i vacc5x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc5x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc5x0123456789ABCDEF, 1)); + __m256i vacc6x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc6x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc6x0123456789ABCDEF, 1)); - vacc0x0123456789AB4567CDEF = _mm256_adds_epi16(vacc0x0123456789AB4567CDEF, voutput_zero_point); - vacc1x0123456789AB4567CDEF = _mm256_adds_epi16(vacc1x0123456789AB4567CDEF, voutput_zero_point); - vacc2x0123456789AB4567CDEF = _mm256_adds_epi16(vacc2x0123456789AB4567CDEF, voutput_zero_point); - vacc3x0123456789AB4567CDEF = _mm256_adds_epi16(vacc3x0123456789AB4567CDEF, voutput_zero_point); - vacc4x0123456789AB4567CDEF = _mm256_adds_epi16(vacc4x0123456789AB4567CDEF, voutput_zero_point); - vacc5x0123456789AB4567CDEF = _mm256_adds_epi16(vacc5x0123456789AB4567CDEF, voutput_zero_point); - vacc6x0123456789AB4567CDEF = _mm256_adds_epi16(vacc6x0123456789AB4567CDEF, voutput_zero_point); + vacc0x012389AB4567CDEF = _mm256_adds_epi16(vacc0x012389AB4567CDEF, voutput_zero_point); + vacc1x012389AB4567CDEF = _mm256_adds_epi16(vacc1x012389AB4567CDEF, voutput_zero_point); + vacc2x012389AB4567CDEF = _mm256_adds_epi16(vacc2x012389AB4567CDEF, voutput_zero_point); + vacc3x012389AB4567CDEF = _mm256_adds_epi16(vacc3x012389AB4567CDEF, voutput_zero_point); + vacc4x012389AB4567CDEF = _mm256_adds_epi16(vacc4x012389AB4567CDEF, voutput_zero_point); + vacc5x012389AB4567CDEF = _mm256_adds_epi16(vacc5x012389AB4567CDEF, voutput_zero_point); + vacc6x012389AB4567CDEF = _mm256_adds_epi16(vacc6x012389AB4567CDEF, voutput_zero_point); - const __m128i vout0x0123456789AB4567CDEF = _mm_packus_epi16(_mm256_castsi256_si128(vacc0x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc0x0123456789AB4567CDEF, 1)); - const __m128i vout1x0123456789AB4567CDEF = _mm_packus_epi16(_mm256_castsi256_si128(vacc1x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc1x0123456789AB4567CDEF, 1)); - const __m128i vout2x0123456789AB4567CDEF = _mm_packus_epi16(_mm256_castsi256_si128(vacc2x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc2x0123456789AB4567CDEF, 1)); - const __m128i vout3x0123456789AB4567CDEF = _mm_packus_epi16(_mm256_castsi256_si128(vacc3x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc3x0123456789AB4567CDEF, 1)); - const __m128i vout4x0123456789AB4567CDEF = _mm_packus_epi16(_mm256_castsi256_si128(vacc4x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc4x0123456789AB4567CDEF, 1)); - const __m128i vout5x0123456789AB4567CDEF = _mm_packus_epi16(_mm256_castsi256_si128(vacc5x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc5x0123456789AB4567CDEF, 1)); - const __m128i vout6x0123456789AB4567CDEF = _mm_packus_epi16(_mm256_castsi256_si128(vacc6x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc6x0123456789AB4567CDEF, 1)); + const __m128i vout0x012389AB4567CDEF = _mm_packus_epi16(_mm256_castsi256_si128(vacc0x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x012389AB4567CDEF, 1)); + const __m128i vout1x012389AB4567CDEF = _mm_packus_epi16(_mm256_castsi256_si128(vacc1x012389AB4567CDEF), _mm256_extracti128_si256(vacc1x012389AB4567CDEF, 1)); + const __m128i vout2x012389AB4567CDEF = _mm_packus_epi16(_mm256_castsi256_si128(vacc2x012389AB4567CDEF), _mm256_extracti128_si256(vacc2x012389AB4567CDEF, 1)); + const __m128i vout3x012389AB4567CDEF = _mm_packus_epi16(_mm256_castsi256_si128(vacc3x012389AB4567CDEF), _mm256_extracti128_si256(vacc3x012389AB4567CDEF, 1)); + const __m128i vout4x012389AB4567CDEF = _mm_packus_epi16(_mm256_castsi256_si128(vacc4x012389AB4567CDEF), _mm256_extracti128_si256(vacc4x012389AB4567CDEF, 1)); + const __m128i vout5x012389AB4567CDEF = _mm_packus_epi16(_mm256_castsi256_si128(vacc5x012389AB4567CDEF), _mm256_extracti128_si256(vacc5x012389AB4567CDEF, 1)); + const __m128i vout6x012389AB4567CDEF = _mm_packus_epi16(_mm256_castsi256_si128(vacc6x012389AB4567CDEF), _mm256_extracti128_si256(vacc6x012389AB4567CDEF, 1)); - __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi32(vout0x0123456789AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); - __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi32(vout1x0123456789AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); - __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi32(vout2x0123456789AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); - __m128i vout3x0123456789ABCDEF = _mm_shuffle_epi32(vout3x0123456789AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); - __m128i vout4x0123456789ABCDEF = _mm_shuffle_epi32(vout4x0123456789AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); - __m128i vout5x0123456789ABCDEF = _mm_shuffle_epi32(vout5x0123456789AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); - __m128i vout6x0123456789ABCDEF = _mm_shuffle_epi32(vout6x0123456789AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi32(vout0x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi32(vout1x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi32(vout2x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout3x0123456789ABCDEF = _mm_shuffle_epi32(vout3x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout4x0123456789ABCDEF = _mm_shuffle_epi32(vout4x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout5x0123456789ABCDEF = _mm_shuffle_epi32(vout5x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout6x0123456789ABCDEF = _mm_shuffle_epi32(vout6x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); vout0x0123456789ABCDEF = _mm_max_epu8(vout0x0123456789ABCDEF, voutput_min); vout1x0123456789ABCDEF = _mm_max_epu8(vout1x0123456789ABCDEF, voutput_min); diff --git a/src/qu8-igemm/gen/qu8-igemm-8x16c8-minmax-fp32-avx512skx-prfm.c b/src/qu8-igemm/gen/qu8-igemm-8x16c8-minmax-fp32-avx512skx-prfm.c index 14ad84d7472..490b05b9fa8 100644 --- a/src/qu8-igemm/gen/qu8-igemm-8x16c8-minmax-fp32-avx512skx-prfm.c +++ b/src/qu8-igemm/gen/qu8-igemm-8x16c8-minmax-fp32-avx512skx-prfm.c @@ -288,41 +288,41 @@ void xnn_qu8_igemm_minmax_fp32_ukernel_8x16c8__avx512skx_prfm( vacc6x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled6x0123456789ABCDEF); vacc7x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled7x0123456789ABCDEF); - __m256i vacc0x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); - __m256i vacc1x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); - __m256i vacc2x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc2x0123456789ABCDEF, 1)); - __m256i vacc3x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc3x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc3x0123456789ABCDEF, 1)); - __m256i vacc4x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc4x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc4x0123456789ABCDEF, 1)); - __m256i vacc5x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc5x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc5x0123456789ABCDEF, 1)); - __m256i vacc6x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc6x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc6x0123456789ABCDEF, 1)); - __m256i vacc7x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc7x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc7x0123456789ABCDEF, 1)); + __m256i vacc0x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); + __m256i vacc1x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); + __m256i vacc2x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc2x0123456789ABCDEF, 1)); + __m256i vacc3x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc3x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc3x0123456789ABCDEF, 1)); + __m256i vacc4x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc4x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc4x0123456789ABCDEF, 1)); + __m256i vacc5x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc5x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc5x0123456789ABCDEF, 1)); + __m256i vacc6x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc6x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc6x0123456789ABCDEF, 1)); + __m256i vacc7x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc7x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc7x0123456789ABCDEF, 1)); - vacc0x0123456789AB4567CDEF = _mm256_adds_epi16(vacc0x0123456789AB4567CDEF, voutput_zero_point); - vacc1x0123456789AB4567CDEF = _mm256_adds_epi16(vacc1x0123456789AB4567CDEF, voutput_zero_point); - vacc2x0123456789AB4567CDEF = _mm256_adds_epi16(vacc2x0123456789AB4567CDEF, voutput_zero_point); - vacc3x0123456789AB4567CDEF = _mm256_adds_epi16(vacc3x0123456789AB4567CDEF, voutput_zero_point); - vacc4x0123456789AB4567CDEF = _mm256_adds_epi16(vacc4x0123456789AB4567CDEF, voutput_zero_point); - vacc5x0123456789AB4567CDEF = _mm256_adds_epi16(vacc5x0123456789AB4567CDEF, voutput_zero_point); - vacc6x0123456789AB4567CDEF = _mm256_adds_epi16(vacc6x0123456789AB4567CDEF, voutput_zero_point); - vacc7x0123456789AB4567CDEF = _mm256_adds_epi16(vacc7x0123456789AB4567CDEF, voutput_zero_point); + vacc0x012389AB4567CDEF = _mm256_adds_epi16(vacc0x012389AB4567CDEF, voutput_zero_point); + vacc1x012389AB4567CDEF = _mm256_adds_epi16(vacc1x012389AB4567CDEF, voutput_zero_point); + vacc2x012389AB4567CDEF = _mm256_adds_epi16(vacc2x012389AB4567CDEF, voutput_zero_point); + vacc3x012389AB4567CDEF = _mm256_adds_epi16(vacc3x012389AB4567CDEF, voutput_zero_point); + vacc4x012389AB4567CDEF = _mm256_adds_epi16(vacc4x012389AB4567CDEF, voutput_zero_point); + vacc5x012389AB4567CDEF = _mm256_adds_epi16(vacc5x012389AB4567CDEF, voutput_zero_point); + vacc6x012389AB4567CDEF = _mm256_adds_epi16(vacc6x012389AB4567CDEF, voutput_zero_point); + vacc7x012389AB4567CDEF = _mm256_adds_epi16(vacc7x012389AB4567CDEF, voutput_zero_point); - const __m128i vout0x0123456789AB4567CDEF = _mm_packus_epi16(_mm256_castsi256_si128(vacc0x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc0x0123456789AB4567CDEF, 1)); - const __m128i vout1x0123456789AB4567CDEF = _mm_packus_epi16(_mm256_castsi256_si128(vacc1x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc1x0123456789AB4567CDEF, 1)); - const __m128i vout2x0123456789AB4567CDEF = _mm_packus_epi16(_mm256_castsi256_si128(vacc2x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc2x0123456789AB4567CDEF, 1)); - const __m128i vout3x0123456789AB4567CDEF = _mm_packus_epi16(_mm256_castsi256_si128(vacc3x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc3x0123456789AB4567CDEF, 1)); - const __m128i vout4x0123456789AB4567CDEF = _mm_packus_epi16(_mm256_castsi256_si128(vacc4x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc4x0123456789AB4567CDEF, 1)); - const __m128i vout5x0123456789AB4567CDEF = _mm_packus_epi16(_mm256_castsi256_si128(vacc5x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc5x0123456789AB4567CDEF, 1)); - const __m128i vout6x0123456789AB4567CDEF = _mm_packus_epi16(_mm256_castsi256_si128(vacc6x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc6x0123456789AB4567CDEF, 1)); - const __m128i vout7x0123456789AB4567CDEF = _mm_packus_epi16(_mm256_castsi256_si128(vacc7x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc7x0123456789AB4567CDEF, 1)); + const __m128i vout0x012389AB4567CDEF = _mm_packus_epi16(_mm256_castsi256_si128(vacc0x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x012389AB4567CDEF, 1)); + const __m128i vout1x012389AB4567CDEF = _mm_packus_epi16(_mm256_castsi256_si128(vacc1x012389AB4567CDEF), _mm256_extracti128_si256(vacc1x012389AB4567CDEF, 1)); + const __m128i vout2x012389AB4567CDEF = _mm_packus_epi16(_mm256_castsi256_si128(vacc2x012389AB4567CDEF), _mm256_extracti128_si256(vacc2x012389AB4567CDEF, 1)); + const __m128i vout3x012389AB4567CDEF = _mm_packus_epi16(_mm256_castsi256_si128(vacc3x012389AB4567CDEF), _mm256_extracti128_si256(vacc3x012389AB4567CDEF, 1)); + const __m128i vout4x012389AB4567CDEF = _mm_packus_epi16(_mm256_castsi256_si128(vacc4x012389AB4567CDEF), _mm256_extracti128_si256(vacc4x012389AB4567CDEF, 1)); + const __m128i vout5x012389AB4567CDEF = _mm_packus_epi16(_mm256_castsi256_si128(vacc5x012389AB4567CDEF), _mm256_extracti128_si256(vacc5x012389AB4567CDEF, 1)); + const __m128i vout6x012389AB4567CDEF = _mm_packus_epi16(_mm256_castsi256_si128(vacc6x012389AB4567CDEF), _mm256_extracti128_si256(vacc6x012389AB4567CDEF, 1)); + const __m128i vout7x012389AB4567CDEF = _mm_packus_epi16(_mm256_castsi256_si128(vacc7x012389AB4567CDEF), _mm256_extracti128_si256(vacc7x012389AB4567CDEF, 1)); - __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi32(vout0x0123456789AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); - __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi32(vout1x0123456789AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); - __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi32(vout2x0123456789AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); - __m128i vout3x0123456789ABCDEF = _mm_shuffle_epi32(vout3x0123456789AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); - __m128i vout4x0123456789ABCDEF = _mm_shuffle_epi32(vout4x0123456789AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); - __m128i vout5x0123456789ABCDEF = _mm_shuffle_epi32(vout5x0123456789AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); - __m128i vout6x0123456789ABCDEF = _mm_shuffle_epi32(vout6x0123456789AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); - __m128i vout7x0123456789ABCDEF = _mm_shuffle_epi32(vout7x0123456789AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi32(vout0x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi32(vout1x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi32(vout2x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout3x0123456789ABCDEF = _mm_shuffle_epi32(vout3x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout4x0123456789ABCDEF = _mm_shuffle_epi32(vout4x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout5x0123456789ABCDEF = _mm_shuffle_epi32(vout5x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout6x0123456789ABCDEF = _mm_shuffle_epi32(vout6x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout7x0123456789ABCDEF = _mm_shuffle_epi32(vout7x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); vout0x0123456789ABCDEF = _mm_max_epu8(vout0x0123456789ABCDEF, voutput_min); vout1x0123456789ABCDEF = _mm_max_epu8(vout1x0123456789ABCDEF, voutput_min); diff --git a/src/qu8-igemm/gen/qu8-igemm-8x16c8-minmax-fp32-avx512skx.c b/src/qu8-igemm/gen/qu8-igemm-8x16c8-minmax-fp32-avx512skx.c index 9620bc7c74b..b6e62ea5767 100644 --- a/src/qu8-igemm/gen/qu8-igemm-8x16c8-minmax-fp32-avx512skx.c +++ b/src/qu8-igemm/gen/qu8-igemm-8x16c8-minmax-fp32-avx512skx.c @@ -285,41 +285,41 @@ void xnn_qu8_igemm_minmax_fp32_ukernel_8x16c8__avx512skx( vacc6x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled6x0123456789ABCDEF); vacc7x0123456789ABCDEF = _mm512_cvtps_epi32(vscaled7x0123456789ABCDEF); - __m256i vacc0x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); - __m256i vacc1x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); - __m256i vacc2x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc2x0123456789ABCDEF, 1)); - __m256i vacc3x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc3x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc3x0123456789ABCDEF, 1)); - __m256i vacc4x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc4x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc4x0123456789ABCDEF, 1)); - __m256i vacc5x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc5x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc5x0123456789ABCDEF, 1)); - __m256i vacc6x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc6x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc6x0123456789ABCDEF, 1)); - __m256i vacc7x0123456789AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc7x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc7x0123456789ABCDEF, 1)); + __m256i vacc0x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0x0123456789ABCDEF, 1)); + __m256i vacc1x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc1x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc1x0123456789ABCDEF, 1)); + __m256i vacc2x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc2x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc2x0123456789ABCDEF, 1)); + __m256i vacc3x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc3x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc3x0123456789ABCDEF, 1)); + __m256i vacc4x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc4x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc4x0123456789ABCDEF, 1)); + __m256i vacc5x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc5x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc5x0123456789ABCDEF, 1)); + __m256i vacc6x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc6x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc6x0123456789ABCDEF, 1)); + __m256i vacc7x012389AB4567CDEF = _mm256_packs_epi32(_mm512_castsi512_si256(vacc7x0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc7x0123456789ABCDEF, 1)); - vacc0x0123456789AB4567CDEF = _mm256_adds_epi16(vacc0x0123456789AB4567CDEF, voutput_zero_point); - vacc1x0123456789AB4567CDEF = _mm256_adds_epi16(vacc1x0123456789AB4567CDEF, voutput_zero_point); - vacc2x0123456789AB4567CDEF = _mm256_adds_epi16(vacc2x0123456789AB4567CDEF, voutput_zero_point); - vacc3x0123456789AB4567CDEF = _mm256_adds_epi16(vacc3x0123456789AB4567CDEF, voutput_zero_point); - vacc4x0123456789AB4567CDEF = _mm256_adds_epi16(vacc4x0123456789AB4567CDEF, voutput_zero_point); - vacc5x0123456789AB4567CDEF = _mm256_adds_epi16(vacc5x0123456789AB4567CDEF, voutput_zero_point); - vacc6x0123456789AB4567CDEF = _mm256_adds_epi16(vacc6x0123456789AB4567CDEF, voutput_zero_point); - vacc7x0123456789AB4567CDEF = _mm256_adds_epi16(vacc7x0123456789AB4567CDEF, voutput_zero_point); + vacc0x012389AB4567CDEF = _mm256_adds_epi16(vacc0x012389AB4567CDEF, voutput_zero_point); + vacc1x012389AB4567CDEF = _mm256_adds_epi16(vacc1x012389AB4567CDEF, voutput_zero_point); + vacc2x012389AB4567CDEF = _mm256_adds_epi16(vacc2x012389AB4567CDEF, voutput_zero_point); + vacc3x012389AB4567CDEF = _mm256_adds_epi16(vacc3x012389AB4567CDEF, voutput_zero_point); + vacc4x012389AB4567CDEF = _mm256_adds_epi16(vacc4x012389AB4567CDEF, voutput_zero_point); + vacc5x012389AB4567CDEF = _mm256_adds_epi16(vacc5x012389AB4567CDEF, voutput_zero_point); + vacc6x012389AB4567CDEF = _mm256_adds_epi16(vacc6x012389AB4567CDEF, voutput_zero_point); + vacc7x012389AB4567CDEF = _mm256_adds_epi16(vacc7x012389AB4567CDEF, voutput_zero_point); - const __m128i vout0x0123456789AB4567CDEF = _mm_packus_epi16(_mm256_castsi256_si128(vacc0x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc0x0123456789AB4567CDEF, 1)); - const __m128i vout1x0123456789AB4567CDEF = _mm_packus_epi16(_mm256_castsi256_si128(vacc1x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc1x0123456789AB4567CDEF, 1)); - const __m128i vout2x0123456789AB4567CDEF = _mm_packus_epi16(_mm256_castsi256_si128(vacc2x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc2x0123456789AB4567CDEF, 1)); - const __m128i vout3x0123456789AB4567CDEF = _mm_packus_epi16(_mm256_castsi256_si128(vacc3x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc3x0123456789AB4567CDEF, 1)); - const __m128i vout4x0123456789AB4567CDEF = _mm_packus_epi16(_mm256_castsi256_si128(vacc4x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc4x0123456789AB4567CDEF, 1)); - const __m128i vout5x0123456789AB4567CDEF = _mm_packus_epi16(_mm256_castsi256_si128(vacc5x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc5x0123456789AB4567CDEF, 1)); - const __m128i vout6x0123456789AB4567CDEF = _mm_packus_epi16(_mm256_castsi256_si128(vacc6x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc6x0123456789AB4567CDEF, 1)); - const __m128i vout7x0123456789AB4567CDEF = _mm_packus_epi16(_mm256_castsi256_si128(vacc7x0123456789AB4567CDEF), _mm256_extracti128_si256(vacc7x0123456789AB4567CDEF, 1)); + const __m128i vout0x012389AB4567CDEF = _mm_packus_epi16(_mm256_castsi256_si128(vacc0x012389AB4567CDEF), _mm256_extracti128_si256(vacc0x012389AB4567CDEF, 1)); + const __m128i vout1x012389AB4567CDEF = _mm_packus_epi16(_mm256_castsi256_si128(vacc1x012389AB4567CDEF), _mm256_extracti128_si256(vacc1x012389AB4567CDEF, 1)); + const __m128i vout2x012389AB4567CDEF = _mm_packus_epi16(_mm256_castsi256_si128(vacc2x012389AB4567CDEF), _mm256_extracti128_si256(vacc2x012389AB4567CDEF, 1)); + const __m128i vout3x012389AB4567CDEF = _mm_packus_epi16(_mm256_castsi256_si128(vacc3x012389AB4567CDEF), _mm256_extracti128_si256(vacc3x012389AB4567CDEF, 1)); + const __m128i vout4x012389AB4567CDEF = _mm_packus_epi16(_mm256_castsi256_si128(vacc4x012389AB4567CDEF), _mm256_extracti128_si256(vacc4x012389AB4567CDEF, 1)); + const __m128i vout5x012389AB4567CDEF = _mm_packus_epi16(_mm256_castsi256_si128(vacc5x012389AB4567CDEF), _mm256_extracti128_si256(vacc5x012389AB4567CDEF, 1)); + const __m128i vout6x012389AB4567CDEF = _mm_packus_epi16(_mm256_castsi256_si128(vacc6x012389AB4567CDEF), _mm256_extracti128_si256(vacc6x012389AB4567CDEF, 1)); + const __m128i vout7x012389AB4567CDEF = _mm_packus_epi16(_mm256_castsi256_si128(vacc7x012389AB4567CDEF), _mm256_extracti128_si256(vacc7x012389AB4567CDEF, 1)); - __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi32(vout0x0123456789AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); - __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi32(vout1x0123456789AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); - __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi32(vout2x0123456789AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); - __m128i vout3x0123456789ABCDEF = _mm_shuffle_epi32(vout3x0123456789AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); - __m128i vout4x0123456789ABCDEF = _mm_shuffle_epi32(vout4x0123456789AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); - __m128i vout5x0123456789ABCDEF = _mm_shuffle_epi32(vout5x0123456789AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); - __m128i vout6x0123456789ABCDEF = _mm_shuffle_epi32(vout6x0123456789AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); - __m128i vout7x0123456789ABCDEF = _mm_shuffle_epi32(vout7x0123456789AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi32(vout0x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout1x0123456789ABCDEF = _mm_shuffle_epi32(vout1x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout2x0123456789ABCDEF = _mm_shuffle_epi32(vout2x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout3x0123456789ABCDEF = _mm_shuffle_epi32(vout3x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout4x0123456789ABCDEF = _mm_shuffle_epi32(vout4x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout5x0123456789ABCDEF = _mm_shuffle_epi32(vout5x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout6x0123456789ABCDEF = _mm_shuffle_epi32(vout6x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i vout7x0123456789ABCDEF = _mm_shuffle_epi32(vout7x012389AB4567CDEF, _MM_SHUFFLE(3, 1, 2, 0)); vout0x0123456789ABCDEF = _mm_max_epu8(vout0x0123456789ABCDEF, voutput_min); vout1x0123456789ABCDEF = _mm_max_epu8(vout1x0123456789ABCDEF, voutput_min); diff --git a/src/xnnpack/microparams.h b/src/xnnpack/microparams.h index 2486e6d55c8..531def443f8 100644 --- a/src/xnnpack/microparams.h +++ b/src/xnnpack/microparams.h @@ -391,7 +391,6 @@ union xnn_qs8_conv_minmax_params { XNN_ALIGN(64) float scale[16]; XNN_ALIGN(64) float output_max_less_zero_point[16]; XNN_ALIGN(64) int16_t output_zero_point[16]; - XNN_ALIGN(16) int8_t shuffle_control_mask[16]; XNN_ALIGN(16) int8_t output_min[16]; } fp32_avx512vnni; #endif // XNN_ARCH_X86 || XNN_ARCH_X86_64