Skip to content

Commit

Permalink
Change malloc zero to return one byte and update the SBGEMM test to a…
Browse files Browse the repository at this point in the history
…gain use sizes of zero.
  • Loading branch information
ChipKerchner committed Aug 16, 2024
1 parent b1802f4 commit 868aa85
Showing 1 changed file with 24 additions and 16 deletions.
40 changes: 24 additions & 16 deletions test/compare_sgemm_sbgemm.c
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,14 @@ float16to32 (bfloat16_bits f16)

#define SBGEMM_LARGEST 256

void *malloc_safe(size_t size)
{
if (size == 0)
return malloc(1);
else
return malloc(size);
}

int
main (int argc, char *argv[])
{
Expand All @@ -96,17 +104,17 @@ main (int argc, char *argv[])
char transA = 'N', transB = 'N';
float alpha = 1.0, beta = 0.0;

for (x = 1; x <= loop; x++)
for (x = 0; x <= loop; x++)
{
if ((x > 100) && (x != SBGEMM_LARGEST)) continue;
m = k = n = x;
float *A = (float *)malloc(m * k * sizeof(FLOAT));
float *B = (float *)malloc(k * n * sizeof(FLOAT));
float *C = (float *)malloc(m * n * sizeof(FLOAT));
bfloat16_bits *AA = (bfloat16_bits *)malloc(m * k * sizeof(bfloat16_bits));
bfloat16_bits *BB = (bfloat16_bits *)malloc(k * n * sizeof(bfloat16_bits));
float *DD = (float *)malloc(m * n * sizeof(FLOAT));
float *CC = (float *)malloc(m * n * sizeof(FLOAT));
float *A = (float *)malloc_safe(m * k * sizeof(FLOAT));
float *B = (float *)malloc_safe(k * n * sizeof(FLOAT));
float *C = (float *)malloc_safe(m * n * sizeof(FLOAT));
bfloat16_bits *AA = (bfloat16_bits *)malloc_safe(m * k * sizeof(bfloat16_bits));
bfloat16_bits *BB = (bfloat16_bits *)malloc_safe(k * n * sizeof(bfloat16_bits));
float *DD = (float *)malloc_safe(m * n * sizeof(FLOAT));
float *CC = (float *)malloc_safe(m * n * sizeof(FLOAT));
if ((A == NULL) || (B == NULL) || (C == NULL) || (AA == NULL) || (BB == NULL) ||
(DD == NULL) || (CC == NULL))
return 1;
Expand Down Expand Up @@ -195,15 +203,15 @@ main (int argc, char *argv[])
}

k = 1;
for (x = 1; x <= loop; x++)
for (x = 0; x <= loop; x++)
{
float *A = (float *)malloc(x * x * sizeof(FLOAT));
float *B = (float *)malloc(x * sizeof(FLOAT));
float *C = (float *)malloc(x * sizeof(FLOAT));
bfloat16_bits *AA = (bfloat16_bits *)malloc(x * x * sizeof(bfloat16_bits));
bfloat16_bits *BB = (bfloat16_bits *)malloc(x * sizeof(bfloat16_bits));
float *DD = (float *)malloc(x * sizeof(FLOAT));
float *CC = (float *)malloc(x * sizeof(FLOAT));
float *A = (float *)malloc_safe(x * x * sizeof(FLOAT));
float *B = (float *)malloc_safe(x * sizeof(FLOAT));
float *C = (float *)malloc_safe(x * sizeof(FLOAT));
bfloat16_bits *AA = (bfloat16_bits *)malloc_safe(x * x * sizeof(bfloat16_bits));
bfloat16_bits *BB = (bfloat16_bits *)malloc_safe(x * sizeof(bfloat16_bits));
float *DD = (float *)malloc_safe(x * sizeof(FLOAT));
float *CC = (float *)malloc_safe(x * sizeof(FLOAT));
if ((A == NULL) || (B == NULL) || (C == NULL) || (AA == NULL) || (BB == NULL) ||
(DD == NULL) || (CC == NULL))
return 1;
Expand Down

0 comments on commit 868aa85

Please sign in to comment.