From 17b0caa2b2bff439feb6d2b39cfa16e7591882b0 Mon Sep 17 00:00:00 2001 From: "Field G. Van Zee" Date: Thu, 14 Jul 2022 17:55:34 -0500 Subject: [PATCH] Fixed out-of-bounds read in haswell gemmsup kernels. MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Details: - Fixed memory access bugs in the bli_sgemmsup_rv_haswell_asm_Mx2() kernels, where M = {1,2,3,4,5,6}. The bugs were caused by loading four single-precision elements of C, via instructions such as: vfmadd231ps(mem(rcx, 0*32), xmm3, xmm4) in situations where only two elements are guaranteed to exist. (These bugs may not have manifested in earlier tests due to the leading dimension alignment that BLIS employs by default.) The issue was fixed by replacing lines like the one above with: vmovsd(mem(rcx), xmm0) vfmadd231ps(xmm0, xmm3, xmm4) Thus, we use vmovsd to explicitly load only two elements of C into registers, and then operate on those values using register addressing. Thanks to Daniël de Kok for reporting these bugs in #635, and to Bhaskar Nallani for proposing the fix. - CREDITS file update. 
--- CREDITS | 1 + .../s6x16/bli_gemmsup_rv_haswell_asm_sMx2.c | 63 ++++++++++++------- 2 files changed, 43 insertions(+), 21 deletions(-) diff --git a/CREDITS b/CREDITS index 43c7b3ed53..bb2b3798fc 100644 --- a/CREDITS +++ b/CREDITS @@ -23,6 +23,7 @@ but many others have contributed code and feedback, including Dilyn Corner @dilyn-corner Mat Cross @matcross (NAG) @decandia50 + Daniël de Kok @danieldk (Explosion) Kay Dewhurst @jkd2016 (Max Planck Institute, Halle, Germany) Jeff Diamond (Oracle) Johannes Dieterich @iotamudelta diff --git a/kernels/haswell/3/sup/s6x16/bli_gemmsup_rv_haswell_asm_sMx2.c b/kernels/haswell/3/sup/s6x16/bli_gemmsup_rv_haswell_asm_sMx2.c index 53a70d15f0..efb3363950 100644 --- a/kernels/haswell/3/sup/s6x16/bli_gemmsup_rv_haswell_asm_sMx2.c +++ b/kernels/haswell/3/sup/s6x16/bli_gemmsup_rv_haswell_asm_sMx2.c @@ -389,32 +389,38 @@ void bli_sgemmsup_rv_haswell_asm_6x2 label(.SROWSTORED) - vfmadd231ps(mem(rcx, 0*32), xmm3, xmm4) + vmovsd(mem(rcx), xmm0) + vfmadd231ps(xmm0, xmm3, xmm4) vmovsd(xmm4, mem(rcx, 0*32)) add(rdi, rcx) - vfmadd231ps(mem(rcx, 0*32), xmm3, xmm6) + vmovsd(mem(rcx), xmm0) + vfmadd231ps(xmm0, xmm3, xmm6) vmovsd(xmm6, mem(rcx, 0*32)) add(rdi, rcx) - vfmadd231ps(mem(rcx, 0*32), xmm3, xmm8) + vmovsd(mem(rcx), xmm0) + vfmadd231ps(xmm0, xmm3, xmm8) vmovsd(xmm8, mem(rcx, 0*32)) add(rdi, rcx) - vfmadd231ps(mem(rcx, 0*32), xmm3, xmm10) + vmovsd(mem(rcx), xmm0) + vfmadd231ps(xmm0, xmm3, xmm10) vmovsd(xmm10, mem(rcx, 0*32)) add(rdi, rcx) - vfmadd231ps(mem(rcx, 0*32), xmm3, xmm12) + vmovsd(mem(rcx), xmm0) + vfmadd231ps(xmm0, xmm3, xmm12) vmovsd(xmm12, mem(rcx, 0*32)) add(rdi, rcx) - vfmadd231ps(mem(rcx, 0*32), xmm3, xmm14) + vmovsd(mem(rcx), xmm0) + vfmadd231ps(xmm0, xmm3, xmm14) vmovsd(xmm14, mem(rcx, 0*32)) //add(rdi, rcx) @@ -848,27 +854,32 @@ void bli_sgemmsup_rv_haswell_asm_5x2 label(.SROWSTORED) - vfmadd231ps(mem(rcx, 0*32), xmm3, xmm4) + vmovsd(mem(rcx), xmm0) + vfmadd231ps(xmm0, xmm3, xmm4) vmovsd(xmm4, mem(rcx, 0*32)) add(rdi, 
rcx) - vfmadd231ps(mem(rcx, 0*32), xmm3, xmm6) + vmovsd(mem(rcx), xmm0) + vfmadd231ps(xmm0, xmm3, xmm6) vmovsd(xmm6, mem(rcx, 0*32)) add(rdi, rcx) - vfmadd231ps(mem(rcx, 0*32), xmm3, xmm8) + vmovsd(mem(rcx), xmm0) + vfmadd231ps(xmm0, xmm3, xmm8) vmovsd(xmm8, mem(rcx, 0*32)) add(rdi, rcx) - vfmadd231ps(mem(rcx, 0*32), xmm3, xmm10) + vmovsd(mem(rcx), xmm0) + vfmadd231ps(xmm0, xmm3, xmm10) vmovsd(xmm10, mem(rcx, 0*32)) add(rdi, rcx) - vfmadd231ps(mem(rcx, 0*32), xmm3, xmm12) + vmovsd(mem(rcx), xmm0) + vfmadd231ps(xmm0, xmm3, xmm12) vmovsd(xmm12, mem(rcx, 0*32)) //add(rdi, rcx) @@ -1288,22 +1299,26 @@ void bli_sgemmsup_rv_haswell_asm_4x2 label(.SROWSTORED) - vfmadd231ps(mem(rcx, 0*32), xmm3, xmm4) + vmovsd(mem(rcx), xmm0) + vfmadd231ps(xmm0, xmm3, xmm4) vmovsd(xmm4, mem(rcx, 0*32)) add(rdi, rcx) - vfmadd231ps(mem(rcx, 0*32), xmm3, xmm6) + vmovsd(mem(rcx), xmm0) + vfmadd231ps(xmm0, xmm3, xmm6) vmovsd(xmm6, mem(rcx, 0*32)) add(rdi, rcx) - vfmadd231ps(mem(rcx, 0*32), xmm3, xmm8) + vmovsd(mem(rcx), xmm0) + vfmadd231ps(xmm0, xmm3, xmm8) vmovsd(xmm8, mem(rcx, 0*32)) add(rdi, rcx) - vfmadd231ps(mem(rcx, 0*32), xmm3, xmm10) + vmovsd(mem(rcx), xmm0) + vfmadd231ps(xmm0, xmm3, xmm10) vmovsd(xmm10, mem(rcx, 0*32)) //add(rdi, rcx) @@ -1683,17 +1698,20 @@ void bli_sgemmsup_rv_haswell_asm_3x2 label(.SROWSTORED) - vfmadd231ps(mem(rcx, 0*32), xmm3, xmm4) + vmovsd(mem(rcx), xmm0) + vfmadd231ps(xmm0, xmm3, xmm4) vmovsd(xmm4, mem(rcx, 0*32)) add(rdi, rcx) - vfmadd231ps(mem(rcx, 0*32), xmm3, xmm6) + vmovsd(mem(rcx), xmm0) + vfmadd231ps(xmm0, xmm3, xmm6) vmovsd(xmm6, mem(rcx, 0*32)) add(rdi, rcx) - vfmadd231ps(mem(rcx, 0*32), xmm3, xmm8) + vmovsd(mem(rcx), xmm0) + vfmadd231ps(xmm0, xmm3, xmm8) vmovsd(xmm8, mem(rcx, 0*32)) //add(rdi, rcx) @@ -2066,12 +2084,14 @@ void bli_sgemmsup_rv_haswell_asm_2x2 label(.SROWSTORED) - vfmadd231ps(mem(rcx, 0*32), xmm3, xmm4) + vmovsd(mem(rcx), xmm0) + vfmadd231ps(xmm0, xmm3, xmm4) vmovsd(xmm4, mem(rcx, 0*32)) add(rdi, rcx) - vfmadd231ps(mem(rcx, 0*32), xmm3, 
xmm6) + vmovsd(mem(rcx), xmm0) + vfmadd231ps(xmm0, xmm3, xmm6) vmovsd(xmm6, mem(rcx, 0*32)) //add(rdi, rcx) @@ -2404,7 +2424,8 @@ void bli_sgemmsup_rv_haswell_asm_1x2 label(.SROWSTORED) - vfmadd231ps(mem(rcx, 0*32), xmm3, xmm4) + vmovsd(mem(rcx), xmm0) + vfmadd231ps(xmm0, xmm3, xmm4) vmovsd(xmm4, mem(rcx, 0*32)) //add(rdi, rcx)