Skip to content

Commit

Permalink
Fix MlasSgemmKernel: properly process more than 2 rows (microsoft#22125)
Browse files Browse the repository at this point in the history
This change fixes multiple tests like QDQTransformerTests.MatMul_U8S8S8,
for all architectures where architecture-specific
optimized function is not available yet, like s390x.

### Description
Matrix B is packed by 16 elements, thus new row starts 16 items later.
Also, for next C increment index only by 1 for each increment of C.


### Motivation and Context
This change fixes mlas sgemm fallback implementation for all
architectures which don't have architecture-specific implementations
available, like s390x.
  • Loading branch information
AlekseiNikiforovIBM authored Nov 21, 2024
1 parent 712bee1 commit e430795
Showing 1 changed file with 21 additions and 15 deletions.
36 changes: 21 additions & 15 deletions onnxruntime/core/mlas/lib/scalar/SgemmKernelScalar.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,8 @@ Return Value:

#endif

int countb = 0;

do {

float BElements00;
Expand Down Expand Up @@ -116,6 +118,7 @@ Return Value:
//

const float* a = A;
const float* b = B;
size_t k = CountK;

while (k >= 2) {
Expand All @@ -128,10 +131,10 @@ Return Value:
Row1AElements1 = a[lda + 1];
}

BElements00 = B[0];
BElements01 = B[1];
BElements02 = B[2];
BElements03 = B[3];
BElements00 = b[0];
BElements01 = b[1];
BElements02 = b[2];
BElements03 = b[3];
Row0Block00 = Row0Block00 + BElements00 * Row0AElements0;
Row0Block01 = Row0Block01 + BElements01 * Row0AElements0;
Row0Block02 = Row0Block02 + BElements02 * Row0AElements0;
Expand All @@ -144,10 +147,10 @@ Return Value:
Row1Block03 = Row1Block03 + BElements03 * Row1AElements0;
}

BElements00 = B[4];
BElements01 = B[5];
BElements02 = B[6];
BElements03 = B[7];
BElements00 = b[16];
BElements01 = b[17];
BElements02 = b[18];
BElements03 = b[19];
Row0Block00 = Row0Block00 + BElements00 * Row0AElements1;
Row0Block01 = Row0Block01 + BElements01 * Row0AElements1;
Row0Block02 = Row0Block02 + BElements02 * Row0AElements1;
Expand All @@ -161,7 +164,7 @@ Return Value:
}

a += 2;
B += 8;
b += 32;
k -= 2;
}

Expand All @@ -173,10 +176,10 @@ Return Value:
Row1AElements0 = a[lda];
}

BElements00 = B[0];
BElements01 = B[1];
BElements02 = B[2];
BElements03 = B[3];
BElements00 = b[0];
BElements01 = b[1];
BElements02 = b[2];
BElements03 = b[3];
Row0Block00 = Row0Block00 + BElements00 * Row0AElements0;
Row0Block01 = Row0Block01 + BElements01 * Row0AElements0;
Row0Block02 = Row0Block02 + BElements02 * Row0AElements0;
Expand All @@ -188,8 +191,6 @@ Return Value:
Row1Block02 = Row1Block02 + BElements02 * Row1AElements0;
Row1Block03 = Row1Block03 + BElements03 * Row1AElements0;
}

B += 4;
}

//
Expand Down Expand Up @@ -295,9 +296,14 @@ Return Value:
break;
}

B += 4;
C += 4;
CountN -= 4;

countb = (countb + 1) % 4;
if (countb == 0) {
B += CountK * 16 - 16;
}
} while (CountN > 0);

return ProcessTwoRows ? 2 : 1;
Expand Down

0 comments on commit e430795

Please sign in to comment.