Skip to content

Commit

Permalink
QS8 AVX512: use vpshufd instead of vpshufb
Browse files Browse the repository at this point in the history
- Saves 1 register and 1 params field (shuffle_control_mask is no longer initialized)

PiperOrigin-RevId: 591304558
  • Loading branch information
fbarchard authored and xnnpack-bot committed Dec 15, 2023
1 parent a68aa0a commit 86e0a93
Show file tree
Hide file tree
Showing 266 changed files with 6,682 additions and 7,146 deletions.
522 changes: 252 additions & 270 deletions src/amalgam/gen/avx512skx.c

Large diffs are not rendered by default.

352 changes: 168 additions & 184 deletions src/amalgam/gen/avx512vnni.c

Large diffs are not rendered by default.

8 changes: 0 additions & 8 deletions src/microparams-init.c
Original file line number Diff line number Diff line change
Expand Up @@ -395,10 +395,6 @@ size_t xnn_init_qs8_qc8w_conv_minmax_fp32_avx512vnni_params(
for (uint32_t i = 0; i < 16; i++) {
params->fp32_avx512vnni.output_zero_point[i] = output_zero_point;
}
const int8_t control_mask[16] = {0, 1, 2, 3, 8, 9, 10, 11, 4, 5, 6, 7, 12, 13, 14, 15};
for (uint32_t i = 0; i < 16; i++) {
params->fp32_avx512vnni.shuffle_control_mask[i] = control_mask[i];
}
for (uint32_t i = 0; i < 16; i++) {
params->fp32_avx512vnni.output_min[i] = output_min;
}
Expand All @@ -425,10 +421,6 @@ size_t xnn_init_qs8_conv_minmax_fp32_avx512vnni_params(
for (uint32_t i = 0; i < 16; i++) {
params->fp32_avx512vnni.output_zero_point[i] = output_zero_point;
}
const int8_t control_mask[16] = {0, 1, 2, 3, 8, 9, 10, 11, 4, 5, 6, 7, 12, 13, 14, 15};
for (uint32_t i = 0; i < 16; i++) {
params->fp32_avx512vnni.shuffle_control_mask[i] = control_mask[i];
}
for (uint32_t i = 0; i < 16; i++) {
params->fp32_avx512vnni.output_min[i] = output_min;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_1x16c4__avx512vnni_prfm(
const __m512 voutput_min = _mm512_set1_ps(params->avx512vnni.min);
const __m512 voutput_max = _mm512_set1_ps(params->avx512vnni.max);
const __m512i vsign_mask = _mm512_set1_epi8(params->avx512vnni.sign_mask); // 0x80
const __m512i vvalue_mask = _mm512_set1_epi8(params->avx512vnni.mask); // 0xF0
const __m512i vmask = _mm512_set1_epi8(params->avx512vnni.mask); // 0xF0
do {
const __m512i vksum0123456789ABCDEF = _mm512_load_epi32(w);
__m512i vacc0x0x0123456789ABCDEF = _mm512_mullo_epi32(vksum0123456789ABCDEF, vinput_zero_point0);
Expand All @@ -63,8 +63,8 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_1x16c4__avx512vnni_prfm(

const __m512i vbb0123456789ABCDEFx01234567 = _mm512_load_si512(w);
const __m512i vbs0123456789ABCDEFx0123 = _mm512_slli_epi32(vbb0123456789ABCDEFx01234567, 4);
const __m512i vb0123456789ABCDEFx4567 = _mm512_and_si512(vbb0123456789ABCDEFx01234567, vvalue_mask);
const __m512i vb0123456789ABCDEFx0123 = _mm512_and_si512(vbs0123456789ABCDEFx0123, vvalue_mask);
const __m512i vb0123456789ABCDEFx4567 = _mm512_and_si512(vbb0123456789ABCDEFx01234567, vmask);
const __m512i vb0123456789ABCDEFx0123 = _mm512_and_si512(vbs0123456789ABCDEFx0123, vmask);

vacc0x0x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x0x0123456789ABCDEF, va0x0123, vb0123456789ABCDEFx0123);
xnn_prefetch_to_l1((const int8_t*) w + 960);
Expand All @@ -73,25 +73,25 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_1x16c4__avx512vnni_prfm(
w = (const int8_t*) w + 64;
k -= 8 * sizeof(int8_t);
}
vacc0x0x0123456789ABCDEF = _mm512_add_epi32(vacc0x0x0123456789ABCDEF, vacc1x0x0123456789ABCDEF);
__m512i vacc0x0123456789ABCDEF = _mm512_add_epi32(vacc0x0x0123456789ABCDEF, vacc1x0x0123456789ABCDEF);

if (k != 0) {
const __m512i va0x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a0)), vsign_mask);
a0 += 4;

const __m512i vbb0123456789ABCDEF = _mm512_load_si512(w);
const __m512i vbs0123456789ABCDEF = _mm512_slli_epi32(vbb0123456789ABCDEF, 4);
const __m512i vb0123456789ABCDEF = _mm512_and_si512(vbs0123456789ABCDEF, vvalue_mask);
const __m512i vb0123456789ABCDEF = _mm512_and_si512(vbs0123456789ABCDEF, vmask);
xnn_prefetch_to_l1((const int8_t*) w + 960);

vacc0x0x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x0x0123456789ABCDEF, va0x0123, vb0123456789ABCDEF);
vacc0x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x0123456789ABCDEF, va0x0123, vb0123456789ABCDEF);

w = (const int8_t*) w + 64;
k -= 4 * sizeof(int8_t);
}

vacc0x0x0123456789ABCDEF = _mm512_srai_epi32(vacc0x0x0123456789ABCDEF, 4);
__m512 vscaled0x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x0x0123456789ABCDEF);
vacc0x0123456789ABCDEF = _mm512_srai_epi32(vacc0x0123456789ABCDEF, 4);
__m512 vscaled0x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x0123456789ABCDEF);

vscaled0x0123456789ABCDEF = _mm512_mul_ps(vscaled0x0123456789ABCDEF, _mm512_set1_ps(quantization_params[0].inv_scale));

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_1x16c4__avx512vnni(
const __m512 voutput_min = _mm512_set1_ps(params->avx512vnni.min);
const __m512 voutput_max = _mm512_set1_ps(params->avx512vnni.max);
const __m512i vsign_mask = _mm512_set1_epi8(params->avx512vnni.sign_mask); // 0x80
const __m512i vvalue_mask = _mm512_set1_epi8(params->avx512vnni.mask); // 0xF0
const __m512i vmask = _mm512_set1_epi8(params->avx512vnni.mask); // 0xF0
do {
const __m512i vksum0123456789ABCDEF = _mm512_load_epi32(w);
__m512i vacc0x0x0123456789ABCDEF = _mm512_mullo_epi32(vksum0123456789ABCDEF, vinput_zero_point0);
Expand All @@ -62,33 +62,33 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_1x16c4__avx512vnni(

const __m512i vbb0123456789ABCDEFx01234567 = _mm512_load_si512(w);
const __m512i vbs0123456789ABCDEFx0123 = _mm512_slli_epi32(vbb0123456789ABCDEFx01234567, 4);
const __m512i vb0123456789ABCDEFx4567 = _mm512_and_si512(vbb0123456789ABCDEFx01234567, vvalue_mask);
const __m512i vb0123456789ABCDEFx0123 = _mm512_and_si512(vbs0123456789ABCDEFx0123, vvalue_mask);
const __m512i vb0123456789ABCDEFx4567 = _mm512_and_si512(vbb0123456789ABCDEFx01234567, vmask);
const __m512i vb0123456789ABCDEFx0123 = _mm512_and_si512(vbs0123456789ABCDEFx0123, vmask);

vacc0x0x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x0x0123456789ABCDEF, va0x0123, vb0123456789ABCDEFx0123);
vacc1x0x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc1x0x0123456789ABCDEF, va0x4567, vb0123456789ABCDEFx4567);

w = (const int8_t*) w + 64;
k -= 8 * sizeof(int8_t);
}
vacc0x0x0123456789ABCDEF = _mm512_add_epi32(vacc0x0x0123456789ABCDEF, vacc1x0x0123456789ABCDEF);
__m512i vacc0x0123456789ABCDEF = _mm512_add_epi32(vacc0x0x0123456789ABCDEF, vacc1x0x0123456789ABCDEF);

if (k != 0) {
const __m512i va0x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a0)), vsign_mask);
a0 += 4;

const __m512i vbb0123456789ABCDEF = _mm512_load_si512(w);
const __m512i vbs0123456789ABCDEF = _mm512_slli_epi32(vbb0123456789ABCDEF, 4);
const __m512i vb0123456789ABCDEF = _mm512_and_si512(vbs0123456789ABCDEF, vvalue_mask);
const __m512i vb0123456789ABCDEF = _mm512_and_si512(vbs0123456789ABCDEF, vmask);

vacc0x0x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x0x0123456789ABCDEF, va0x0123, vb0123456789ABCDEF);
vacc0x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x0123456789ABCDEF, va0x0123, vb0123456789ABCDEF);

w = (const int8_t*) w + 64;
k -= 4 * sizeof(int8_t);
}

vacc0x0x0123456789ABCDEF = _mm512_srai_epi32(vacc0x0x0123456789ABCDEF, 4);
__m512 vscaled0x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x0x0123456789ABCDEF);
vacc0x0123456789ABCDEF = _mm512_srai_epi32(vacc0x0123456789ABCDEF, 4);
__m512 vscaled0x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x0123456789ABCDEF);

vscaled0x0123456789ABCDEF = _mm512_mul_ps(vscaled0x0123456789ABCDEF, _mm512_set1_ps(quantization_params[0].inv_scale));

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_2x16c4__avx512vnni_prfm(
const __m512 voutput_min = _mm512_set1_ps(params->avx512vnni.min);
const __m512 voutput_max = _mm512_set1_ps(params->avx512vnni.max);
const __m512i vsign_mask = _mm512_set1_epi8(params->avx512vnni.sign_mask); // 0x80
const __m512i vvalue_mask = _mm512_set1_epi8(params->avx512vnni.mask); // 0xF0
const __m512i vmask = _mm512_set1_epi8(params->avx512vnni.mask); // 0xF0
do {
const __m512i vksum0123456789ABCDEF = _mm512_load_epi32(w);
__m512i vacc0x0x0123456789ABCDEF = _mm512_mullo_epi32(vksum0123456789ABCDEF, vinput_zero_point0);
Expand All @@ -75,8 +75,8 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_2x16c4__avx512vnni_prfm(

const __m512i vbb0123456789ABCDEFx01234567 = _mm512_load_si512(w);
const __m512i vbs0123456789ABCDEFx0123 = _mm512_slli_epi32(vbb0123456789ABCDEFx01234567, 4);
const __m512i vb0123456789ABCDEFx4567 = _mm512_and_si512(vbb0123456789ABCDEFx01234567, vvalue_mask);
const __m512i vb0123456789ABCDEFx0123 = _mm512_and_si512(vbs0123456789ABCDEFx0123, vvalue_mask);
const __m512i vb0123456789ABCDEFx4567 = _mm512_and_si512(vbb0123456789ABCDEFx01234567, vmask);
const __m512i vb0123456789ABCDEFx0123 = _mm512_and_si512(vbs0123456789ABCDEFx0123, vmask);

vacc0x0x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x0x0123456789ABCDEF, va0x0123, vb0123456789ABCDEFx0123);
vacc0x1x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x1x0123456789ABCDEF, va1x0123, vb0123456789ABCDEFx0123);
Expand All @@ -87,8 +87,8 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_2x16c4__avx512vnni_prfm(
w = (const int8_t*) w + 64;
k -= 8 * sizeof(int8_t);
}
vacc0x0x0123456789ABCDEF = _mm512_add_epi32(vacc0x0x0123456789ABCDEF, vacc1x0x0123456789ABCDEF);
vacc0x1x0123456789ABCDEF = _mm512_add_epi32(vacc0x1x0123456789ABCDEF, vacc1x1x0123456789ABCDEF);
__m512i vacc0x0123456789ABCDEF = _mm512_add_epi32(vacc0x0x0123456789ABCDEF, vacc1x0x0123456789ABCDEF);
__m512i vacc1x0123456789ABCDEF = _mm512_add_epi32(vacc0x1x0123456789ABCDEF, vacc1x1x0123456789ABCDEF);

if (k != 0) {
const __m512i va0x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a0)), vsign_mask);
Expand All @@ -98,20 +98,20 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_2x16c4__avx512vnni_prfm(

const __m512i vbb0123456789ABCDEF = _mm512_load_si512(w);
const __m512i vbs0123456789ABCDEF = _mm512_slli_epi32(vbb0123456789ABCDEF, 4);
const __m512i vb0123456789ABCDEF = _mm512_and_si512(vbs0123456789ABCDEF, vvalue_mask);
const __m512i vb0123456789ABCDEF = _mm512_and_si512(vbs0123456789ABCDEF, vmask);
xnn_prefetch_to_l1((const int8_t*) w + 960);

vacc0x0x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x0x0123456789ABCDEF, va0x0123, vb0123456789ABCDEF);
vacc0x1x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x1x0123456789ABCDEF, va1x0123, vb0123456789ABCDEF);
vacc0x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x0123456789ABCDEF, va0x0123, vb0123456789ABCDEF);
vacc1x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc1x0123456789ABCDEF, va1x0123, vb0123456789ABCDEF);

w = (const int8_t*) w + 64;
k -= 4 * sizeof(int8_t);
}

vacc0x0x0123456789ABCDEF = _mm512_srai_epi32(vacc0x0x0123456789ABCDEF, 4);
vacc0x1x0123456789ABCDEF = _mm512_srai_epi32(vacc0x1x0123456789ABCDEF, 4);
__m512 vscaled0x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x0x0123456789ABCDEF);
__m512 vscaled1x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x1x0123456789ABCDEF);
vacc0x0123456789ABCDEF = _mm512_srai_epi32(vacc0x0123456789ABCDEF, 4);
vacc1x0123456789ABCDEF = _mm512_srai_epi32(vacc1x0123456789ABCDEF, 4);
__m512 vscaled0x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x0123456789ABCDEF);
__m512 vscaled1x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc1x0123456789ABCDEF);

vscaled0x0123456789ABCDEF = _mm512_mul_ps(vscaled0x0123456789ABCDEF, _mm512_set1_ps(quantization_params[0].inv_scale));
vscaled1x0123456789ABCDEF = _mm512_mul_ps(vscaled1x0123456789ABCDEF, _mm512_set1_ps(quantization_params[1].inv_scale));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_2x16c4__avx512vnni(
const __m512 voutput_min = _mm512_set1_ps(params->avx512vnni.min);
const __m512 voutput_max = _mm512_set1_ps(params->avx512vnni.max);
const __m512i vsign_mask = _mm512_set1_epi8(params->avx512vnni.sign_mask); // 0x80
const __m512i vvalue_mask = _mm512_set1_epi8(params->avx512vnni.mask); // 0xF0
const __m512i vmask = _mm512_set1_epi8(params->avx512vnni.mask); // 0xF0
do {
const __m512i vksum0123456789ABCDEF = _mm512_load_epi32(w);
__m512i vacc0x0x0123456789ABCDEF = _mm512_mullo_epi32(vksum0123456789ABCDEF, vinput_zero_point0);
Expand All @@ -74,8 +74,8 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_2x16c4__avx512vnni(

const __m512i vbb0123456789ABCDEFx01234567 = _mm512_load_si512(w);
const __m512i vbs0123456789ABCDEFx0123 = _mm512_slli_epi32(vbb0123456789ABCDEFx01234567, 4);
const __m512i vb0123456789ABCDEFx4567 = _mm512_and_si512(vbb0123456789ABCDEFx01234567, vvalue_mask);
const __m512i vb0123456789ABCDEFx0123 = _mm512_and_si512(vbs0123456789ABCDEFx0123, vvalue_mask);
const __m512i vb0123456789ABCDEFx4567 = _mm512_and_si512(vbb0123456789ABCDEFx01234567, vmask);
const __m512i vb0123456789ABCDEFx0123 = _mm512_and_si512(vbs0123456789ABCDEFx0123, vmask);

vacc0x0x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x0x0123456789ABCDEF, va0x0123, vb0123456789ABCDEFx0123);
vacc0x1x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x1x0123456789ABCDEF, va1x0123, vb0123456789ABCDEFx0123);
Expand All @@ -85,8 +85,8 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_2x16c4__avx512vnni(
w = (const int8_t*) w + 64;
k -= 8 * sizeof(int8_t);
}
vacc0x0x0123456789ABCDEF = _mm512_add_epi32(vacc0x0x0123456789ABCDEF, vacc1x0x0123456789ABCDEF);
vacc0x1x0123456789ABCDEF = _mm512_add_epi32(vacc0x1x0123456789ABCDEF, vacc1x1x0123456789ABCDEF);
__m512i vacc0x0123456789ABCDEF = _mm512_add_epi32(vacc0x0x0123456789ABCDEF, vacc1x0x0123456789ABCDEF);
__m512i vacc1x0123456789ABCDEF = _mm512_add_epi32(vacc0x1x0123456789ABCDEF, vacc1x1x0123456789ABCDEF);

if (k != 0) {
const __m512i va0x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a0)), vsign_mask);
Expand All @@ -96,19 +96,19 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_2x16c4__avx512vnni(

const __m512i vbb0123456789ABCDEF = _mm512_load_si512(w);
const __m512i vbs0123456789ABCDEF = _mm512_slli_epi32(vbb0123456789ABCDEF, 4);
const __m512i vb0123456789ABCDEF = _mm512_and_si512(vbs0123456789ABCDEF, vvalue_mask);
const __m512i vb0123456789ABCDEF = _mm512_and_si512(vbs0123456789ABCDEF, vmask);

vacc0x0x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x0x0123456789ABCDEF, va0x0123, vb0123456789ABCDEF);
vacc0x1x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x1x0123456789ABCDEF, va1x0123, vb0123456789ABCDEF);
vacc0x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x0123456789ABCDEF, va0x0123, vb0123456789ABCDEF);
vacc1x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc1x0123456789ABCDEF, va1x0123, vb0123456789ABCDEF);

w = (const int8_t*) w + 64;
k -= 4 * sizeof(int8_t);
}

vacc0x0x0123456789ABCDEF = _mm512_srai_epi32(vacc0x0x0123456789ABCDEF, 4);
vacc0x1x0123456789ABCDEF = _mm512_srai_epi32(vacc0x1x0123456789ABCDEF, 4);
__m512 vscaled0x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x0x0123456789ABCDEF);
__m512 vscaled1x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x1x0123456789ABCDEF);
vacc0x0123456789ABCDEF = _mm512_srai_epi32(vacc0x0123456789ABCDEF, 4);
vacc1x0123456789ABCDEF = _mm512_srai_epi32(vacc1x0123456789ABCDEF, 4);
__m512 vscaled0x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x0123456789ABCDEF);
__m512 vscaled1x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc1x0123456789ABCDEF);

vscaled0x0123456789ABCDEF = _mm512_mul_ps(vscaled0x0123456789ABCDEF, _mm512_set1_ps(quantization_params[0].inv_scale));
vscaled1x0123456789ABCDEF = _mm512_mul_ps(vscaled1x0123456789ABCDEF, _mm512_set1_ps(quantization_params[1].inv_scale));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_3x16c4__avx512vnni_prfm(
const __m512 voutput_min = _mm512_set1_ps(params->avx512vnni.min);
const __m512 voutput_max = _mm512_set1_ps(params->avx512vnni.max);
const __m512i vsign_mask = _mm512_set1_epi8(params->avx512vnni.sign_mask); // 0x80
const __m512i vvalue_mask = _mm512_set1_epi8(params->avx512vnni.mask); // 0xF0
const __m512i vmask = _mm512_set1_epi8(params->avx512vnni.mask); // 0xF0
do {
const __m512i vksum0123456789ABCDEF = _mm512_load_epi32(w);
__m512i vacc0x0x0123456789ABCDEF = _mm512_mullo_epi32(vksum0123456789ABCDEF, vinput_zero_point0);
Expand All @@ -87,8 +87,8 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_3x16c4__avx512vnni_prfm(

const __m512i vbb0123456789ABCDEFx01234567 = _mm512_load_si512(w);
const __m512i vbs0123456789ABCDEFx0123 = _mm512_slli_epi32(vbb0123456789ABCDEFx01234567, 4);
const __m512i vb0123456789ABCDEFx4567 = _mm512_and_si512(vbb0123456789ABCDEFx01234567, vvalue_mask);
const __m512i vb0123456789ABCDEFx0123 = _mm512_and_si512(vbs0123456789ABCDEFx0123, vvalue_mask);
const __m512i vb0123456789ABCDEFx4567 = _mm512_and_si512(vbb0123456789ABCDEFx01234567, vmask);
const __m512i vb0123456789ABCDEFx0123 = _mm512_and_si512(vbs0123456789ABCDEFx0123, vmask);

vacc0x0x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x0x0123456789ABCDEF, va0x0123, vb0123456789ABCDEFx0123);
vacc0x1x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x1x0123456789ABCDEF, va1x0123, vb0123456789ABCDEFx0123);
Expand All @@ -101,9 +101,9 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_3x16c4__avx512vnni_prfm(
w = (const int8_t*) w + 64;
k -= 8 * sizeof(int8_t);
}
vacc0x0x0123456789ABCDEF = _mm512_add_epi32(vacc0x0x0123456789ABCDEF, vacc1x0x0123456789ABCDEF);
vacc0x1x0123456789ABCDEF = _mm512_add_epi32(vacc0x1x0123456789ABCDEF, vacc1x1x0123456789ABCDEF);
vacc0x2x0123456789ABCDEF = _mm512_add_epi32(vacc0x2x0123456789ABCDEF, vacc1x2x0123456789ABCDEF);
__m512i vacc0x0123456789ABCDEF = _mm512_add_epi32(vacc0x0x0123456789ABCDEF, vacc1x0x0123456789ABCDEF);
__m512i vacc1x0123456789ABCDEF = _mm512_add_epi32(vacc0x1x0123456789ABCDEF, vacc1x1x0123456789ABCDEF);
__m512i vacc2x0123456789ABCDEF = _mm512_add_epi32(vacc0x2x0123456789ABCDEF, vacc1x2x0123456789ABCDEF);

if (k != 0) {
const __m512i va0x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a0)), vsign_mask);
Expand All @@ -115,23 +115,23 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_3x16c4__avx512vnni_prfm(

const __m512i vbb0123456789ABCDEF = _mm512_load_si512(w);
const __m512i vbs0123456789ABCDEF = _mm512_slli_epi32(vbb0123456789ABCDEF, 4);
const __m512i vb0123456789ABCDEF = _mm512_and_si512(vbs0123456789ABCDEF, vvalue_mask);
const __m512i vb0123456789ABCDEF = _mm512_and_si512(vbs0123456789ABCDEF, vmask);
xnn_prefetch_to_l1((const int8_t*) w + 960);

vacc0x0x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x0x0123456789ABCDEF, va0x0123, vb0123456789ABCDEF);
vacc0x1x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x1x0123456789ABCDEF, va1x0123, vb0123456789ABCDEF);
vacc0x2x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x2x0123456789ABCDEF, va2x0123, vb0123456789ABCDEF);
vacc0x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x0123456789ABCDEF, va0x0123, vb0123456789ABCDEF);
vacc1x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc1x0123456789ABCDEF, va1x0123, vb0123456789ABCDEF);
vacc2x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc2x0123456789ABCDEF, va2x0123, vb0123456789ABCDEF);

w = (const int8_t*) w + 64;
k -= 4 * sizeof(int8_t);
}

vacc0x0x0123456789ABCDEF = _mm512_srai_epi32(vacc0x0x0123456789ABCDEF, 4);
vacc0x1x0123456789ABCDEF = _mm512_srai_epi32(vacc0x1x0123456789ABCDEF, 4);
vacc0x2x0123456789ABCDEF = _mm512_srai_epi32(vacc0x2x0123456789ABCDEF, 4);
__m512 vscaled0x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x0x0123456789ABCDEF);
__m512 vscaled1x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x1x0123456789ABCDEF);
__m512 vscaled2x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x2x0123456789ABCDEF);
vacc0x0123456789ABCDEF = _mm512_srai_epi32(vacc0x0123456789ABCDEF, 4);
vacc1x0123456789ABCDEF = _mm512_srai_epi32(vacc1x0123456789ABCDEF, 4);
vacc2x0123456789ABCDEF = _mm512_srai_epi32(vacc2x0123456789ABCDEF, 4);
__m512 vscaled0x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x0123456789ABCDEF);
__m512 vscaled1x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc1x0123456789ABCDEF);
__m512 vscaled2x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc2x0123456789ABCDEF);

vscaled0x0123456789ABCDEF = _mm512_mul_ps(vscaled0x0123456789ABCDEF, _mm512_set1_ps(quantization_params[0].inv_scale));
vscaled1x0123456789ABCDEF = _mm512_mul_ps(vscaled1x0123456789ABCDEF, _mm512_set1_ps(quantization_params[1].inv_scale));
Expand Down
Loading

0 comments on commit 86e0a93

Please sign in to comment.