Skip to content

Commit

Permalink
[FIX] casts and types
Browse files Browse the repository at this point in the history
  • Loading branch information
cyssi-cb committed Nov 25, 2024
1 parent a30f850 commit b4618e5
Showing 1 changed file with 19 additions and 18 deletions.
37 changes: 19 additions & 18 deletions src/firestarter/Environment/X86/Payload/AVX512Payload.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -187,9 +187,10 @@ auto AVX512Payload::compilePayload(const environment::payload::PayloadSettings&
request_permission();
create_AMX_config(&tile_data); // Create tilecfg and fill it

uintptr_t src1, src2;
uint64_t src3;
unsigned int aligned_alloc_size = static_cast<unsigned int>(MaxSize::ELEMENTS * sizeof(__bfloat16));
__bfloat16* src1;
__bfloat16* src2;
uintptr_t src3;
unsigned int aligned_alloc_size = static_cast<unsigned int>(static_cast<uint32_t>(MaxSize::ELEMENTS) * sizeof(__bfloat16));
if (aligned_alloc_size % 1024) { // aligned_alloc expects size to be multiple of alignment (aka 1024)
aligned_alloc_size = aligned_alloc_size + (1024 - (aligned_alloc_size % 1024));
}
Expand All @@ -206,15 +207,15 @@ auto AVX512Payload::compilePayload(const environment::payload::PayloadSettings&
init_buffer_rand(src1, src2);
memset(static_cast<void*>(src3), 0, aligned_alloc_size);

Cb.tileloaddt1(tmm6, asmjit::x86::ptr(src1));
Cb.tileloaddt1(tmm7, asmjit::x86::ptr(src2)); // Ensure no overflows through loading x and -x in src2
Cb.tileloaddt1(tmm6, asmjit::x86::ptr(reinterpret_cast<uintptr_t>(src1)));
Cb.tileloaddt1(tmm7, asmjit::x86::ptr(reinterpret_cast<uintptr_t>(src2))); // Ensure no overflows through loading x and -x in src2

Cb.tileloaddt1(asmjit::x86::tmm0, asmjit::x86::ptr(src3)); // Preload with 0
Cb.tileloaddt1(asmjit::x86::tmm1, asmjit::x86::ptr(src3));
Cb.tileloaddt1(asmjit::x86::tmm2, asmjit::x86::ptr(src3));
Cb.tileloaddt1(asmjit::x86::tmm3, asmjit::x86::ptr(src3));
Cb.tileloaddt1(asmjit::x86::tmm4, asmjit::x86::ptr(src3));
Cb.tileloaddt1(asmjit::x86::tmm5, asmjit::x86::ptr(src3));
Cb.tileloaddt1(asmjit::x86::tmm0, asmjit::x86::ptr(reinterpret_cast<uintptr_t>(src3))); // Preload with 0
Cb.tileloaddt1(asmjit::x86::tmm1, asmjit::x86::ptr(reinterpret_cast<uintptr_t>(src3)));
Cb.tileloaddt1(asmjit::x86::tmm2, asmjit::x86::ptr(reinterpret_cast<uintptr_t>(src3)));
Cb.tileloaddt1(asmjit::x86::tmm3, asmjit::x86::ptr(reinterpret_cast<uintptr_t>(src3)));
Cb.tileloaddt1(asmjit::x86::tmm4, asmjit::x86::ptr(reinterpret_cast<uintptr_t>(src3)));
Cb.tileloaddt1(asmjit::x86::tmm5, asmjit::x86::ptr(reinterpret_cast<uintptr_t>(src3)));

// Initialize AVX512-Registers for FMA Operations
Cb.vmovapd(zmm0, zmmword_ptr(PointerReg, 0));
Expand Down Expand Up @@ -456,8 +457,8 @@ void AVX512Payload::create_AMX_config(TileConfig* tileinfo) {
tileinfo->start_row = 0;

for (i = 0; i < 8; ++i) {
tileinfo->colsb[i] = MaxSize::COLS;
tileinfo->rows[i] = MaxSize::ROWS;
tileinfo->colsb[i] = static_cast<uint16_t>(MaxSize::COLS);
tileinfo->rows[i] = static_cast<uint8_t>(MaxSize::ROWS);
}

_tile_loadconfig(tileinfo);
Expand Down Expand Up @@ -496,14 +497,14 @@ void AVX512Payload::init_buffer_rand(__bfloat16* src1, __bfloat16* src2) {
// Currently not supported by asmJit
// Alternative: Manually parse CPUID

for (int i = 0; i < MaxSize::ROWS; i++) {
for (int i = 0; i < static_cast<uint8_t>(MaxSize::ROWS); i++) {
__bfloat16 random_init = static_cast<__bfloat16>((rand() % 65536)); // Limit maximum size as 1/x needs to fit bfloat16
for (int j = 0; j < MaxSize::COLS; j++) {
buf1[i * MaxSize::COLS + j] = static_cast<__bfloat16>(random_init);
for (int j = 0; j < static_cast<uint16_t>(MaxSize::COLS); j++) {
buf1[i * static_cast<uint16_t>(MaxSize::COLS) + j] = static_cast<__bfloat16>(random_init);
if (!(j % 2)) {
buf2[i * MaxSize::COLS + j] = static_cast<__bfloat16>(((-1) / random_init));
buf2[i * static_cast<uint16_t>(MaxSize::COLS) + j] = static_cast<__bfloat16>(((-1) / random_init));
} else if (j % 2) {
buf2[i * MaxSize::COLS + j] = static_cast<__bfloat16>((1 / random_init));
buf2[i * static_cast<uint16_t>(MaxSize::COLS) + j] = static_cast<__bfloat16>((1 / random_init));
}
}
}
Expand Down

0 comments on commit b4618e5

Please sign in to comment.