Skip to content

Commit

Permalink
Merge pull request OpenMathLib#4883 from ChipKerchner/fixSGEMMUnitTes…
Browse files Browse the repository at this point in the history
…tZeroSize

Fix SBGEMM unit test to handle zero elements.
  • Loading branch information
martin-frbg authored Aug 17, 2024
2 parents f61930e + 89702e1 commit 4850275
Showing 1 changed file with 23 additions and 15 deletions.
38 changes: 23 additions & 15 deletions test/compare_sgemm_sbgemm.c
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,14 @@ float16to32 (bfloat16_bits f16)

#define SBGEMM_LARGEST 256

void *malloc_safe(size_t size)
{
if (size == 0)
return malloc(1);
else
return malloc(size);
}

int
main (int argc, char *argv[])
{
Expand All @@ -100,13 +108,13 @@ main (int argc, char *argv[])
{
if ((x > 100) && (x != SBGEMM_LARGEST)) continue;
m = k = n = x;
float *A = (float *)malloc(m * k * sizeof(FLOAT));
float *B = (float *)malloc(k * n * sizeof(FLOAT));
float *C = (float *)malloc(m * n * sizeof(FLOAT));
bfloat16_bits *AA = (bfloat16_bits *)malloc(m * k * sizeof(bfloat16_bits));
bfloat16_bits *BB = (bfloat16_bits *)malloc(k * n * sizeof(bfloat16_bits));
float *DD = (float *)malloc(m * n * sizeof(FLOAT));
float *CC = (float *)malloc(m * n * sizeof(FLOAT));
float *A = (float *)malloc_safe(m * k * sizeof(FLOAT));
float *B = (float *)malloc_safe(k * n * sizeof(FLOAT));
float *C = (float *)malloc_safe(m * n * sizeof(FLOAT));
bfloat16_bits *AA = (bfloat16_bits *)malloc_safe(m * k * sizeof(bfloat16_bits));
bfloat16_bits *BB = (bfloat16_bits *)malloc_safe(k * n * sizeof(bfloat16_bits));
float *DD = (float *)malloc_safe(m * n * sizeof(FLOAT));
float *CC = (float *)malloc_safe(m * n * sizeof(FLOAT));
if ((A == NULL) || (B == NULL) || (C == NULL) || (AA == NULL) || (BB == NULL) ||
(DD == NULL) || (CC == NULL))
return 1;
Expand Down Expand Up @@ -194,16 +202,16 @@ main (int argc, char *argv[])
return ret;
}

k = 1;
for (x = 1; x <= loop; x++)
{
float *A = (float *)malloc(x * x * sizeof(FLOAT));
float *B = (float *)malloc(x * sizeof(FLOAT));
float *C = (float *)malloc(x * sizeof(FLOAT));
bfloat16_bits *AA = (bfloat16_bits *)malloc(x * x * sizeof(bfloat16_bits));
bfloat16_bits *BB = (bfloat16_bits *)malloc(x * sizeof(bfloat16_bits));
float *DD = (float *)malloc(x * sizeof(FLOAT));
float *CC = (float *)malloc(x * sizeof(FLOAT));
k = (x == 0) ? 0 : 1;
float *A = (float *)malloc_safe(x * x * sizeof(FLOAT));
float *B = (float *)malloc_safe(x * sizeof(FLOAT));
float *C = (float *)malloc_safe(x * sizeof(FLOAT));
bfloat16_bits *AA = (bfloat16_bits *)malloc_safe(x * x * sizeof(bfloat16_bits));
bfloat16_bits *BB = (bfloat16_bits *)malloc_safe(x * sizeof(bfloat16_bits));
float *DD = (float *)malloc_safe(x * sizeof(FLOAT));
float *CC = (float *)malloc_safe(x * sizeof(FLOAT));
if ((A == NULL) || (B == NULL) || (C == NULL) || (AA == NULL) || (BB == NULL) ||
(DD == NULL) || (CC == NULL))
return 1;
Expand Down

0 comments on commit 4850275

Please sign in to comment.