Skip to content

Commit

Permalink
Fixed out-of-bounds read in haswell gemmsup kernels.
Browse files Browse the repository at this point in the history
Details:
- Fixed memory access bugs in the bli_sgemmsup_rv_haswell_asm_Mx2()
  kernels, where M = {1,2,3,4,5,6}. The bugs were caused by loading four
  single-precision elements of C, via instructions such as:

	vfmadd231ps(mem(rcx, 0*32), xmm3, xmm4)

  in situations where only two elements are guaranteed to exist. (These
  bugs may not have manifested in earlier tests due to the leading
  dimension alignment that BLIS employs by default.) The issue was fixed
  by replacing lines like the one above with:

	vmovsd(mem(rcx), xmm0)
	vfmadd231ps(xmm0, xmm3, xmm4)

  Thus, we use vmovsd to explicitly load only two elements of C into
  registers, and then operate on those values using register addressing.
  Thanks to Daniël de Kok for reporting these bugs in flame#635, and to
  Bhaskar Nallani for proposing the fix).
- CREDITS file update.
  • Loading branch information
fgvanzee committed Jul 14, 2022
1 parent cc260fd commit 17b0caa
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 21 deletions.
1 change: 1 addition & 0 deletions CREDITS
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ but many others have contributed code and feedback, including
Dilyn Corner @dilyn-corner
Mat Cross @matcross (NAG)
@decandia50
Daniël de Kok @danieldk (Explosion)
Kay Dewhurst @jkd2016 (Max Planck Institute, Halle, Germany)
Jeff Diamond (Oracle)
Johannes Dieterich @iotamudelta
Expand Down
63 changes: 42 additions & 21 deletions kernels/haswell/3/sup/s6x16/bli_gemmsup_rv_haswell_asm_sMx2.c
Original file line number Diff line number Diff line change
Expand Up @@ -389,32 +389,38 @@ void bli_sgemmsup_rv_haswell_asm_6x2
label(.SROWSTORED)


vfmadd231ps(mem(rcx, 0*32), xmm3, xmm4)
vmovsd(mem(rcx), xmm0)
vfmadd231ps(xmm0, xmm3, xmm4)
vmovsd(xmm4, mem(rcx, 0*32))
add(rdi, rcx)


vfmadd231ps(mem(rcx, 0*32), xmm3, xmm6)
vmovsd(mem(rcx), xmm0)
vfmadd231ps(xmm0, xmm3, xmm6)
vmovsd(xmm6, mem(rcx, 0*32))
add(rdi, rcx)


vfmadd231ps(mem(rcx, 0*32), xmm3, xmm8)
vmovsd(mem(rcx), xmm0)
vfmadd231ps(xmm0, xmm3, xmm8)
vmovsd(xmm8, mem(rcx, 0*32))
add(rdi, rcx)


vfmadd231ps(mem(rcx, 0*32), xmm3, xmm10)
vmovsd(mem(rcx), xmm0)
vfmadd231ps(xmm0, xmm3, xmm10)
vmovsd(xmm10, mem(rcx, 0*32))
add(rdi, rcx)


vfmadd231ps(mem(rcx, 0*32), xmm3, xmm12)
vmovsd(mem(rcx), xmm0)
vfmadd231ps(xmm0, xmm3, xmm12)
vmovsd(xmm12, mem(rcx, 0*32))
add(rdi, rcx)


vfmadd231ps(mem(rcx, 0*32), xmm3, xmm14)
vmovsd(mem(rcx), xmm0)
vfmadd231ps(xmm0, xmm3, xmm14)
vmovsd(xmm14, mem(rcx, 0*32))
//add(rdi, rcx)

Expand Down Expand Up @@ -848,27 +854,32 @@ void bli_sgemmsup_rv_haswell_asm_5x2
label(.SROWSTORED)


vfmadd231ps(mem(rcx, 0*32), xmm3, xmm4)
vmovsd(mem(rcx), xmm0)
vfmadd231ps(xmm0, xmm3, xmm4)
vmovsd(xmm4, mem(rcx, 0*32))
add(rdi, rcx)


vfmadd231ps(mem(rcx, 0*32), xmm3, xmm6)
vmovsd(mem(rcx), xmm0)
vfmadd231ps(xmm0, xmm3, xmm6)
vmovsd(xmm6, mem(rcx, 0*32))
add(rdi, rcx)


vfmadd231ps(mem(rcx, 0*32), xmm3, xmm8)
vmovsd(mem(rcx), xmm0)
vfmadd231ps(xmm0, xmm3, xmm8)
vmovsd(xmm8, mem(rcx, 0*32))
add(rdi, rcx)


vfmadd231ps(mem(rcx, 0*32), xmm3, xmm10)
vmovsd(mem(rcx), xmm0)
vfmadd231ps(xmm0, xmm3, xmm10)
vmovsd(xmm10, mem(rcx, 0*32))
add(rdi, rcx)


vfmadd231ps(mem(rcx, 0*32), xmm3, xmm12)
vmovsd(mem(rcx), xmm0)
vfmadd231ps(xmm0, xmm3, xmm12)
vmovsd(xmm12, mem(rcx, 0*32))
//add(rdi, rcx)

Expand Down Expand Up @@ -1288,22 +1299,26 @@ void bli_sgemmsup_rv_haswell_asm_4x2
label(.SROWSTORED)


vfmadd231ps(mem(rcx, 0*32), xmm3, xmm4)
vmovsd(mem(rcx), xmm0)
vfmadd231ps(xmm0, xmm3, xmm4)
vmovsd(xmm4, mem(rcx, 0*32))
add(rdi, rcx)


vfmadd231ps(mem(rcx, 0*32), xmm3, xmm6)
vmovsd(mem(rcx), xmm0)
vfmadd231ps(xmm0, xmm3, xmm6)
vmovsd(xmm6, mem(rcx, 0*32))
add(rdi, rcx)


vfmadd231ps(mem(rcx, 0*32), xmm3, xmm8)
vmovsd(mem(rcx), xmm0)
vfmadd231ps(xmm0, xmm3, xmm8)
vmovsd(xmm8, mem(rcx, 0*32))
add(rdi, rcx)


vfmadd231ps(mem(rcx, 0*32), xmm3, xmm10)
vmovsd(mem(rcx), xmm0)
vfmadd231ps(xmm0, xmm3, xmm10)
vmovsd(xmm10, mem(rcx, 0*32))
//add(rdi, rcx)

Expand Down Expand Up @@ -1683,17 +1698,20 @@ void bli_sgemmsup_rv_haswell_asm_3x2
label(.SROWSTORED)


vfmadd231ps(mem(rcx, 0*32), xmm3, xmm4)
vmovsd(mem(rcx), xmm0)
vfmadd231ps(xmm0, xmm3, xmm4)
vmovsd(xmm4, mem(rcx, 0*32))
add(rdi, rcx)


vfmadd231ps(mem(rcx, 0*32), xmm3, xmm6)
vmovsd(mem(rcx), xmm0)
vfmadd231ps(xmm0, xmm3, xmm6)
vmovsd(xmm6, mem(rcx, 0*32))
add(rdi, rcx)


vfmadd231ps(mem(rcx, 0*32), xmm3, xmm8)
vmovsd(mem(rcx), xmm0)
vfmadd231ps(xmm0, xmm3, xmm8)
vmovsd(xmm8, mem(rcx, 0*32))
//add(rdi, rcx)

Expand Down Expand Up @@ -2066,12 +2084,14 @@ void bli_sgemmsup_rv_haswell_asm_2x2
label(.SROWSTORED)


vfmadd231ps(mem(rcx, 0*32), xmm3, xmm4)
vmovsd(mem(rcx), xmm0)
vfmadd231ps(xmm0, xmm3, xmm4)
vmovsd(xmm4, mem(rcx, 0*32))
add(rdi, rcx)


vfmadd231ps(mem(rcx, 0*32), xmm3, xmm6)
vmovsd(mem(rcx), xmm0)
vfmadd231ps(xmm0, xmm3, xmm6)
vmovsd(xmm6, mem(rcx, 0*32))
//add(rdi, rcx)

Expand Down Expand Up @@ -2404,7 +2424,8 @@ void bli_sgemmsup_rv_haswell_asm_1x2
label(.SROWSTORED)


vfmadd231ps(mem(rcx, 0*32), xmm3, xmm4)
vmovsd(mem(rcx), xmm0)
vfmadd231ps(xmm0, xmm3, xmm4)
vmovsd(xmm4, mem(rcx, 0*32))
//add(rdi, rcx)

Expand Down

0 comments on commit 17b0caa

Please sign in to comment.