AMX Extension #93

Draft
wants to merge 44 commits into
base: master
Changes from 40 commits
Commits
95f1c83
[ADD] AMX implementation and Sapphire rapids config
cyssi-cb Sep 20, 2024
1fdf36f
[FIX] update asmjit calls
cyssi-cb Sep 20, 2024
2b9adf9
[FIX] include Sapphire Rapids config
cyssi-cb Sep 20, 2024
a91ea44
[FIX] asmjit call in sapphire rapids config
cyssi-cb Sep 20, 2024
e2a8731
[FIX] add new files to cmake
cyssi-cb Sep 20, 2024
529bb7f
[FIX] adapted workload to new asmjit api
cyssi-cb Sep 20, 2024
b3b94f8
[FIX] adapted workload to new asmjit api
cyssi-cb Sep 20, 2024
8e8e82d
[FIX] add missing compiler flags
cyssi-cb Sep 20, 2024
892f234
[FIX] register SApphire Rapids config
cyssi-cb Sep 20, 2024
5c5f935
[REMOVED] unneded prints
cyssi-cb Sep 20, 2024
edf806c
[ADD] use bf16
cyssi-cb Sep 20, 2024
af302d6
[FIX] typo
cyssi-cb Sep 20, 2024
7caf6cb
[ADD] limit init value
cyssi-cb Sep 23, 2024
38fe61d
[REMOVED] unnecessary defines
cyssi-cb Sep 23, 2024
0427834
[ADD] use AVX512 config with AMX
cyssi-cb Sep 23, 2024
9fc69b6
[FIX] correct CPUID model, function as static, logging behavior
cyssi-cb Nov 18, 2024
b5fa958
[FIX] includes
cyssi-cb Nov 18, 2024
fb60acd
[FIX] includes
cyssi-cb Nov 18, 2024
5576797
[ADD] check for AMX during compilePayload
cyssi-cb Nov 18, 2024
ee707cf
[FIX] spelling
cyssi-cb Nov 18, 2024
6e66d26
[FIX] merge AMX into AVX512 workload with runtime check for AMX feature
cyssi-cb Nov 18, 2024
824520c
[FIX] CMakeLists
cyssi-cb Nov 18, 2024
87d6af9
[FIX] naming convention
cyssi-cb Nov 18, 2024
b775b8c
[FIX] spelling
cyssi-cb Nov 18, 2024
ce86848
[FIX] spelling
cyssi-cb Nov 18, 2024
3a789bf
[FIX] spelling
cyssi-cb Nov 18, 2024
83e91eb
[REMOVE] unnecesary payload definition, now merged into AVX512Payload…
cyssi-cb Nov 18, 2024
fdf01db
Merge branch 'master' into amx
cyssi-cb Nov 18, 2024
ce35b28
[FIX] Cmake
cyssi-cb Nov 18, 2024
b1577ce
[FIX] move __tilecfg definition into header
cyssi-cb Nov 18, 2024
15822f2
Merge remote-tracking branch 'origin/master' into cyrill.amx_integration
marenz2569 Nov 22, 2024
31226ad
Squash merge branch code-style-enforcing into cyrill.amx_integration
marenz2569 Nov 22, 2024
e73d1ef
Merge branch 'code-style-enforcing' into cyrill.amx_integration
marenz2569 Nov 22, 2024
27ce53e
reduce diff of merge
marenz2569 Nov 22, 2024
98efa54
Specialize AVX512 Payload for AMX extension
marenz2569 Nov 22, 2024
345ed8e
remove AMX instruction from default AVX512 payload
marenz2569 Nov 22, 2024
0396c64
fix merge error
marenz2569 Nov 22, 2024
c4e8ca6
minimize diff
marenz2569 Nov 22, 2024
287d87a
cleanup include
marenz2569 Nov 22, 2024
997565d
rename tileconfig
marenz2569 Nov 22, 2024
872ffa1
[ADD] small fixes
cyssi-cb Nov 25, 2024
a30f850
[ADD] small fixes>
cyssi-cb Nov 25, 2024
b4618e5
[FIX] casts and types
cyssi-cb Nov 25, 2024
13202bd
[FIX] cast and include
cyssi-cb Nov 25, 2024
2 changes: 1 addition & 1 deletion CMakeLists.txt
@@ -57,7 +57,7 @@ git_submodule_update()

if("${CMAKE_CXX_COMPILER_ID}" STREQUAL "MSVC")
else()
- SET(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wall -Wextra -O2 -fdata-sections -ffunction-sections")
+ SET(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -mamx-tile -Wall -Wextra -O2 -fdata-sections -ffunction-sections")
Member Author

This change will not be necessary when the ldtilecfg instruction is integrated into the assembler kernel

endif()

if(CMAKE_SYSTEM_NAME STREQUAL "Darwin")
17 changes: 15 additions & 2 deletions include/firestarter/Environment/X86/Payload/AVX512Payload.hpp
@@ -25,8 +25,17 @@

namespace firestarter::environment::x86::payload {

// Define struct that is used as config and loaded through ldtilecfg()
struct TileConfig {
uint8_t palette_id;
uint8_t start_row;
uint8_t reserved_0[14];
uint16_t colsb[16];
uint8_t rows[16];
};
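The struct above has to match the 64-byte memory operand that `ldtilecfg` reads. A minimal sketch, assuming a standalone copy of the struct, that checks the layout at compile time and fills all eight tiles with 16 rows of 64 bytes, as `create_AMX_config` in this PR does. The helper name `makeFullTileConfig` is hypothetical and not part of the PR, and the sketch does not load the configuration.

```cpp
#include <cstddef>
#include <cstdint>

// Standalone copy of the struct from the diff, for illustration only.
struct TileConfig {
  std::uint8_t palette_id;
  std::uint8_t start_row;
  std::uint8_t reserved_0[14];
  std::uint16_t colsb[16];
  std::uint8_t rows[16];
};

// The tile configuration operand is 64 bytes: colsb starts at byte 16, rows at byte 48.
static_assert(sizeof(TileConfig) == 64, "TileConfig must be exactly 64 bytes");
static_assert(offsetof(TileConfig, colsb) == 16, "colsb must start at byte 16");
static_assert(offsetof(TileConfig, rows) == 48, "rows must start at byte 48");

// Hypothetical helper: palette 1, eight tiles with 16 rows of 64 bytes each.
inline TileConfig makeFullTileConfig() {
  TileConfig Config{}; // zero-initializes the reserved bytes and unused tiles
  Config.palette_id = 1;
  Config.start_row = 0;
  for (int I = 0; I < 8; ++I) {
    Config.colsb[I] = 64;
    Config.rows[I] = 16;
  }
  return Config;
}
```

The static assertions would catch accidental padding or a reordered member before the struct is ever handed to the hardware.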

/// This payload is designed for the AVX512 foundation CPU extension.
- class AVX512Payload final : public X86Payload {
+ class AVX512Payload : public X86Payload {
public:
AVX512Payload() noexcept
: X86Payload(/*FeatureRequests=*/{asmjit::CpuFeatures::X86::kAVX512_F}, /*Name=*/"AVX512", /*RegisterSize=*/8,
@@ -59,9 +68,13 @@ class AVX512Payload final : public X86Payload {
/// \returns The compiled payload that provides access to the init and load functions.
[[nodiscard]] auto compilePayload(const environment::payload::PayloadSettings& Settings, bool DumpRegisters,
bool ErrorDetection) const
- -> environment::payload::CompiledPayload::UniquePtr override;
+ -> environment::payload::CompiledPayload::UniquePtr final;

private:
static void create_AMX_config(TileConfig* tileinfo);
static void request_permission();
static void init_buffer_rand(uintptr_t buf1, uintptr_t buf2);

/// Function to initialize the memory used by the high load function.
/// \arg MemoryAddr The pointer to the memory.
/// \arg BufferSize The number of doubles that is allocated in MemoryAddr.
40 changes: 40 additions & 0 deletions include/firestarter/Environment/X86/Payload/AVX512WithAMXPayload.hpp
@@ -0,0 +1,40 @@
/******************************************************************************
* FIRESTARTER - A Processor Stress Test Utility
* Copyright (C) 2024 TU Dresden, Center for Information Services and High
* Performance Computing
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*
* Contact: [email protected]
*****************************************************************************/

#pragma once

#include "firestarter/Environment/X86/Payload/AVX512Payload.hpp"

#include <asmjit/asmjit.h>

namespace firestarter::environment::x86::payload {

/// This payload is designed for the AVX512 foundation CPU extension specialized for AMX.
class AVX512WithAMXPayload : public AVX512Payload {
public:
AVX512WithAMXPayload() noexcept {
// Enable the AMX instruction in the AVX512 Payload and request AMX_TILE and AMX_BF16 feature.
addInstructionFlops("AMX", 512);
addFeatureRequest(asmjit::CpuFeatures::X86::kAMX_TILE);
addFeatureRequest(asmjit::CpuFeatures::X86::kAMX_BF16);
}
Comment on lines +33 to +38
Member Author

I added this wrapper to the AVX512Payload to allow for checking the AMX_TILE and AMX_BF16 features.
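The wrapper pattern described here can be sketched in isolation. The following assumes simplified, hypothetical stand-ins (`BasePayload`, `BaseWithAmxPayload`, and the `Feature` enum are not FIRESTARTER or asmjit types); only the two protected adders mirror the names used in this PR.

```cpp
#include <map>
#include <string>
#include <vector>

// Hypothetical feature identifiers standing in for asmjit::CpuFeatures::X86 ids.
enum class Feature { Avx512F, AmxTile, AmxBf16 };

// Hypothetical, simplified stand-in for the base payload.
class BasePayload {
public:
  BasePayload() : FeatureRequests{Feature::Avx512F} {}
  virtual ~BasePayload() = default;

  const std::map<std::string, unsigned>& instructionFlops() const { return InstructionFlops; }
  const std::vector<Feature>& featureRequests() const { return FeatureRequests; }

protected:
  // Only derived payloads may extend the instruction set and the feature requests.
  void addInstructionFlops(const std::string& Name, unsigned Flops) { InstructionFlops[Name] = Flops; }
  void addFeatureRequest(Feature Request) { FeatureRequests.push_back(Request); }

private:
  std::map<std::string, unsigned> InstructionFlops;
  std::vector<Feature> FeatureRequests;
};

// The specialization only adds to what the base class already requests.
class BaseWithAmxPayload : public BasePayload {
public:
  BaseWithAmxPayload() {
    addInstructionFlops("AMX", 512);
    addFeatureRequest(Feature::AmxTile);
    addFeatureRequest(Feature::AmxBf16);
  }
};
```

Keeping the adders protected means the plain base payload never advertises the AMX instruction, while the derived payload is only selected when all three features are available.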

};
} // namespace firestarter::environment::x86::payload
14 changes: 14 additions & 0 deletions include/firestarter/Environment/X86/Payload/X86Payload.hpp
@@ -90,6 +90,20 @@ class X86Payload : public environment::payload::Payload {
};

protected:
/// Add another instruction to the InstructionFlops map. This makes sure this instruction is shown in the set of
/// available instructions and we can correctly calculate the FLOPS for it. It is required when specializing Payloads
/// with new CPU extensions (e.g. AVX512 with newer CPU Extension like AMX).
/// \arg InstructionName The name of the instruction.
/// \arg Flops The Flops that are computed for this function.
void addInstructionFlops(const std::string& InstructionName, unsigned Flops) {
InstructionFlops[InstructionName] = Flops;
}

/// Add another feature request of this payload after it has been initialized. This can be used to specialize Payloads
/// (e.g. AVX512 with newer CPU Extension like AMX) and add more requested features as needed.
/// \arg Request The requested Cpu Feature.
void addFeatureRequest(asmjit::CpuFeatures::X86::Id Request) { FeatureRequests.push_back(Request); }

/// Emit the code to dump the xmm, ymm or zmm registers into memory for the dump registers feature.
/// \tparam Vec the type of the vector register used.
/// \arg Cb The asmjit code builder that is used to emit the assembler code.
49 changes: 49 additions & 0 deletions include/firestarter/Environment/X86/Platform/SapphireRapidsConfig.hpp
@@ -0,0 +1,49 @@
/******************************************************************************
* FIRESTARTER - A Processor Stress Test Utility
* Copyright (C) 2024 TU Dresden, Center for Information Services and High
* Performance Computing
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*
* Contact: [email protected]
*****************************************************************************/

#pragma once

#include "firestarter/Environment/X86/Payload/AVX512WithAMXPayload.hpp"
#include "firestarter/Environment/X86/Platform/X86PlatformConfig.hpp"

namespace firestarter::environment::x86::platform {
class SapphireRapidsConfig final : public X86PlatformConfig {
public:
SapphireRapidsConfig() noexcept
: X86PlatformConfig(/*Name=*/"SPR_XEONEP", /*Family=*/6, /*Models=*/{143},
/*Settings=*/
environment::payload::PayloadSettings(/*Threads=*/{1, 2},
/*DataCacheBufferSize=*/{32768, 1048576, 1441792},
/*RamBufferSize=*/1048576000, /*Lines=*/1536,
/*InstructionGroups=*/
{{"RAM_S", 3},
{"RAM_P", 1},
{"L3_S", 1},
{"L3_P", 1},
{"L2_S", 4},
{"L2_L", 70},
{"L1_S", 0},
{"L1_L", 40},
{"REG", 140},
{"AMX", 1}}),
Member Author

These values should be updated; e.g., the L1 data cache size changed from SkylakeSP to Sapphire Rapids.

/*Payload=*/std::make_shared<const payload::AVX512WithAMXPayload>()) {}
};
} // namespace firestarter::environment::x86::platform
15 changes: 9 additions & 6 deletions include/firestarter/Environment/X86/X86Environment.hpp
@@ -32,6 +32,7 @@
#include "firestarter/Environment/X86/Platform/RomeConfig.hpp"
#include "firestarter/Environment/X86/Platform/SandyBridgeConfig.hpp"
#include "firestarter/Environment/X86/Platform/SandyBridgeEPConfig.hpp"
#include "firestarter/Environment/X86/Platform/SapphireRapidsConfig.hpp"
#include "firestarter/Environment/X86/Platform/SkylakeConfig.hpp"
#include "firestarter/Environment/X86/Platform/SkylakeSPConfig.hpp"
#include "firestarter/Environment/X86/Platform/X86PlatformConfig.hpp"
@@ -102,7 +103,8 @@ class X86Environment final : public Environment {
std::make_shared<platform::HaswellEPConfig>(), std::make_shared<platform::SandyBridgeConfig>(),
std::make_shared<platform::SandyBridgeEPConfig>(), std::make_shared<platform::NehalemConfig>(),
std::make_shared<platform::NehalemEPConfig>(), std::make_shared<platform::BulldozerConfig>(),
- std::make_shared<platform::NaplesConfig>(), std::make_shared<platform::RomeConfig>()};
+ std::make_shared<platform::NaplesConfig>(), std::make_shared<platform::RomeConfig>(),
+ std::make_shared<platform::SapphireRapidsConfig>()};

/// The list of configs that are fallbacks. If none of the PlatformConfigs is the default one on the current CPU, we
/// select the first one from this list that is available on the current system. If multiple configs can be available
@@ -111,11 +113,12 @@ class X86Environment final : public Environment {
/// AVX512 takes precedence. This list should contain one entry for each of the supported CPU extensions by the
/// FIRESTARTER payloads.
const std::list<std::shared_ptr<platform::X86PlatformConfig>> FallbackPlatformConfigs = {
- std::make_shared<platform::SkylakeSPConfig>(), // AVX512
- std::make_shared<platform::BulldozerConfig>(), // FMA4
- std::make_shared<platform::HaswellConfig>(), // FMA
- std::make_shared<platform::SandyBridgeConfig>(), // AVX
- std::make_shared<platform::NehalemConfig>() // SSE2
+ std::make_shared<platform::SapphireRapidsConfig>(), // AVX512 with AMX
+ std::make_shared<platform::SkylakeSPConfig>(), // AVX512
+ std::make_shared<platform::BulldozerConfig>(), // FMA4
+ std::make_shared<platform::HaswellConfig>(), // FMA
+ std::make_shared<platform::SandyBridgeConfig>(), // AVX
+ std::make_shared<platform::NehalemConfig>() // SSE2
};
};
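The order of the fallback list matters because the first available entry wins, which is why the AMX config is placed ahead of the plain AVX512 one. A minimal sketch of that selection rule, assuming a hypothetical `FallbackConfig` type with a feature bitmask (the real `X86PlatformConfig` availability check is not shown in this diff):

```cpp
#include <algorithm>
#include <cstdint>
#include <list>
#include <string>

// Hypothetical feature bits, for illustration only.
constexpr std::uint32_t FeatureAvx512 = 1U << 0;
constexpr std::uint32_t FeatureAmx = 1U << 1;

// Hypothetical stand-in for a platform config: a name plus the features it needs.
struct FallbackConfig {
  std::string Name;
  std::uint32_t RequiredFeatures;
};

// Return the first config whose required features are all available, or nullptr.
inline const FallbackConfig* selectFallback(const std::list<FallbackConfig>& Configs,
                                            std::uint32_t AvailableFeatures) {
  const auto It = std::find_if(Configs.begin(), Configs.end(), [AvailableFeatures](const FallbackConfig& Config) {
    return (Config.RequiredFeatures & AvailableFeatures) == Config.RequiredFeatures;
  });
  return It == Configs.end() ? nullptr : &*It;
}
```

With the more specific config listed first, a CPU that supports both AVX512 and AMX picks the AMX variant, while a CPU with only AVX512 falls through to the next entry.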

125 changes: 123 additions & 2 deletions src/firestarter/Environment/X86/Payload/AVX512Payload.cpp
@@ -22,13 +22,34 @@
#include "firestarter/Environment/X86/Payload/AVX512Payload.hpp"
#include "firestarter/Environment/X86/Payload/CompiledX86Payload.hpp"

#include <asm/prctl.h> /* Definition of ARCH_* constants */
#include <immintrin.h>
Member Author

We should be able to remove these includes

#include <sys/syscall.h>
Member Author

Also guard the include of this header; see the comment in the request_permission function.


#define XFEATURE_XTILECFG 17
Member Author

marenz2569 Nov 22, 2024

Please use an enum class for constants, or constexpr const... if it's just a single constant.

#define XFEATURE_XTILEDATA 18
#define XFEATURE_MASK_XTILECFG (1 << XFEATURE_XTILECFG)
#define XFEATURE_MASK_XTILEDATA (1 << XFEATURE_XTILEDATA)
#define XFEATURE_MASK_XTILE (XFEATURE_MASK_XTILECFG | XFEATURE_MASK_XTILEDATA)

#define ARCH_GET_XCOMP_PERM 0x1022
#define ARCH_REQ_XCOMP_PERM 0x1023

#define MAX 1024
#define MAX_ROWS 16
#define MAX_COLS 64
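One way the macros above could be expressed as typed constants, in the spirit of the review comment. This is a sketch only: the namespace `amx` and the constant names are hypothetical, and the values are copied from the `#define` lines in this diff.

```cpp
#include <cstdint>

namespace amx {
// XSAVE state component numbers for the AMX tile configuration and tile data.
constexpr unsigned XfeatureXtilecfg = 17;
constexpr unsigned XfeatureXtiledata = 18;

constexpr std::uint64_t XfeatureMaskXtilecfg = std::uint64_t{1} << XfeatureXtilecfg;
constexpr std::uint64_t XfeatureMaskXtiledata = std::uint64_t{1} << XfeatureXtiledata;
constexpr std::uint64_t XfeatureMaskXtile = XfeatureMaskXtilecfg | XfeatureMaskXtiledata;

// arch_prctl options used to query and request permission for dynamic XSAVE features.
constexpr long ArchGetXcompPerm = 0x1022;
constexpr long ArchReqXcompPerm = 0x1023;

// Buffer and tile dimensions used by the payload.
constexpr unsigned MaxElements = 1024;
constexpr unsigned MaxRows = 16;
constexpr unsigned MaxCols = 64;
} // namespace amx

static_assert(amx::XfeatureMaskXtile == 0x60000, "bits 17 and 18 must be set");
static_assert(amx::MaxRows * amx::MaxCols == amx::MaxElements, "buffer size must match the tile dimensions");
```

Unlike the macros, these names are scoped, typed, and cannot collide with a generic identifier such as `MAX` from another header.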

namespace firestarter::environment::x86::payload {

auto AVX512Payload::compilePayload(const environment::payload::PayloadSettings& Settings, bool DumpRegisters,
bool ErrorDetection) const -> environment::payload::CompiledPayload::UniquePtr {
using Imm = asmjit::Imm;
using Tmm = asmjit::x86::Tmm;
using Zmm = asmjit::x86::Zmm;
// NOLINTBEGIN(readability-identifier-naming)
constexpr auto tmm6 = asmjit::x86::tmm6;
constexpr auto tmm7 = asmjit::x86::tmm7;

constexpr asmjit::x86::Mem (*zmmword_ptr)(const asmjit::x86::Gp&, int32_t) = asmjit::x86::zmmword_ptr;
constexpr auto zmm0 = asmjit::x86::zmm0;
constexpr auto zmm1 = asmjit::x86::zmm1;
@@ -158,6 +179,42 @@ auto AVX512Payload::compilePayload(const environment::payload::PayloadSettings&
for (auto const& Reg : ShiftReg32) {
Cb.mov(Reg, Imm(0xAAAAAAAA));
}

// Init AMX registers and config
TileConfig tile_data = {0};
request_permission();
Member Author

This function should be called at the start of the assembler kernel. You should be able to provide the pointer to a static function in asmjit and call it.

create_AMX_config(&tile_data); // Create tilecfg and fill it
Member Author

Same for this one


static bool init = true;
Member Author

unused variable

uintptr_t src1, src2;
uint64_t src3;
unsigned int aligned_alloc_size = static_cast<unsigned int>(MAX * sizeof(__bfloat16));
if (aligned_alloc_size % 1024) { // aligned_alloc expects size to be multiple of alignment (aka 1024)
aligned_alloc_size = aligned_alloc_size + (1024 - (aligned_alloc_size % 1024));
}
src1 = (uintptr_t)aligned_alloc(1024, aligned_alloc_size);
src2 = (uintptr_t)aligned_alloc(1024, aligned_alloc_size);
src3 = (uint64_t)aligned_alloc(1024, aligned_alloc_size);
if (((void*)src1 == nullptr) || (void*)src2 == nullptr ||
(void*)src3 == nullptr) { // uintptr_t guarantees we can cast it to void* and back
std::cout << "[ERROR]: Allocation of source and target buffer for AMX failed. Aborting...\n";
exit(1);
}
Member Author

This memory should be allocated in the LoadWorkerMemory class. A platform-independent aligned alloc abstraction is available there. You might need to change it slightly so that these variables are aligned to 1024 B instead of 64 B. This change will allocate these arrays not only for the AVX512/AMX payload but for all payloads. There should, however, be no negative effect other than an increased allocated RAM size for all payloads.
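The `LoadWorkerMemory` interface is not part of this diff, so the following is only a generic sketch of an owning, 1024-byte-aligned allocation that replaces the raw `aligned_alloc` calls and the manual size rounding. The names `roundUp`, `allocateAligned`, and `AlignedBuffer` are hypothetical; `std::aligned_alloc` requires C++17 and is not provided by MSVC, which is one reason a platform abstraction is needed.

```cpp
#include <cstddef>
#include <cstdlib>
#include <memory>

// Deleter so that memory obtained from std::aligned_alloc is released with std::free.
struct FreeDeleter {
  void operator()(void* Ptr) const noexcept { std::free(Ptr); }
};

using AlignedBuffer = std::unique_ptr<void, FreeDeleter>;

// Round Size up to the next multiple of Alignment (Alignment must be non-zero).
constexpr std::size_t roundUp(std::size_t Size, std::size_t Alignment) {
  return (Size + Alignment - 1) / Alignment * Alignment;
}

// std::aligned_alloc expects the size to be a multiple of the alignment,
// so the requested size is rounded up first. Returns an empty buffer on failure.
inline AlignedBuffer allocateAligned(std::size_t Size, std::size_t Alignment) {
  return AlignedBuffer(std::aligned_alloc(Alignment, roundUp(Size, Alignment)));
}
```

Owning the buffers this way also avoids leaking them, which the raw pointers in the current hunk would do once `compilePayload` returns.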


// Init buffers
init_buffer_rand(src1, src2);
memset((void*)src3, 0, aligned_alloc_size);

Cb.tileloaddt1(tmm6, asmjit::x86::ptr(src1));
Cb.tileloaddt1(tmm7, asmjit::x86::ptr(src2)); // Ensure no overflows through loading x and -x in src2

Cb.tileloaddt1(asmjit::x86::tmm0, asmjit::x86::ptr(src3)); // Preload with 0
Cb.tileloaddt1(asmjit::x86::tmm1, asmjit::x86::ptr(src3));
Cb.tileloaddt1(asmjit::x86::tmm2, asmjit::x86::ptr(src3));
Cb.tileloaddt1(asmjit::x86::tmm3, asmjit::x86::ptr(src3));
Cb.tileloaddt1(asmjit::x86::tmm4, asmjit::x86::ptr(src3));
Cb.tileloaddt1(asmjit::x86::tmm5, asmjit::x86::ptr(src3));

// Initialize AVX512-Registers for FMA Operations
Cb.vmovapd(zmm0, zmmword_ptr(PointerReg, 0));
Cb.vmovapd(zmm1, zmmword_ptr(PointerReg, 64));
@@ -196,6 +253,7 @@ auto AVX512Payload::compilePayload(const environment::payload::PayloadSettings&
auto AddDest = AddStart + 1;
auto MovDst = TransStart;
unsigned L1Offset = 0;
unsigned AmxRegisterCounter = 0;

const auto L1Increment = [&Cb, &L1Offset, &L1Size, &L1Addr, &OffsetReg, &PointerReg]() {
L1Offset += 64;
@@ -212,7 +270,9 @@ auto AVX512Payload::compilePayload(const environment::payload::PayloadSettings&

for (auto Count = 0U; Count < Repetitions; Count++) {
for (const auto& Item : Sequence) {
- if (Item == "REG") {
+ if (Item == "AMX") {
+ Cb.tdpbf16ps(Tmm(AmxRegisterCounter++ % 6), tmm6, tmm7);
+ } else if (Item == "REG") {
Cb.vfmadd231pd(Zmm(AddDest), zmm0, zmm2);
Cb.vfmadd231pd(Zmm(MovDst), zmm2, zmm1);
Cb.xor_(ShiftReg[(ShiftPos + NrShiftRegs - 1) % NrShiftRegs], TempReg);
@@ -388,4 +448,65 @@ void AVX512Payload::init(double* MemoryAddr, uint64_t BufferSize) const {
X86Payload::initMemory(MemoryAddr, BufferSize, 0.27948995982e-4, 0.27948995982e-4);
}

- } // namespace firestarter::environment::x86::payload
void AVX512Payload::create_AMX_config(TileConfig* tileinfo) {
// Create tile_cfg, fill it and return
int i;
tileinfo->palette_id = 1;
tileinfo->start_row = 0;

for (i = 0; i < 8; ++i) {
tileinfo->colsb[i] = MAX_COLS;
tileinfo->rows[i] = MAX_ROWS;
}

_tile_loadconfig(tileinfo);
}

void AVX512Payload::request_permission() {

long rc;
unsigned long bitmask;
Member Author

marenz2569 Nov 22, 2024

I assume that the syscall is required to enable AMX at the OS level. However, this will cause the code to fail to compile on Windows and macOS. Please guard this code with a Linux ifdef and call workerLog::fatal on Windows/macOS.

Member Author

Looking at PR microsoft/onnxruntime#14042, it seems that AMX is simply supported on Windows. The only question is whether the compiler sets some flag in the binary for the operating system or whether it just works with JIT-generated assembler code.

rc = syscall(SYS_arch_prctl, ARCH_REQ_XCOMP_PERM, XFEATURE_XTILEDATA);

if (rc) {
workerLog::error() << "XTILE_DATA request failed: " << rc;
}

rc = syscall(SYS_arch_prctl, ARCH_GET_XCOMP_PERM, &bitmask);
if (rc) {
workerLog::error() << "prctl(ARCH_GET_XCOMP_PERM) error: " << rc;
}
Member Author

Should this check whether XFEATURE_MASK_XTILEDATA is set? The current check passes if either XFEATURE_MASK_XTILECFG or XFEATURE_MASK_XTILEDATA (or both) is set.

if (bitmask & XFEATURE_MASK_XTILE) {
workerLog::trace() << "ARCH_REQ_XCOMP_PERM XTILE_DATA successful.";
} else {
workerLog::error() << "[ERROR] ARCH_REQ_XCOMP_PERM XTILE_DATA unsuccessful!";
}
}
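A sketch of how the two review points on this function could be combined: the syscall is only compiled where `SYS_arch_prctl` exists, and success is decided by the XTILEDATA bit alone. The function name `requestAmxPermission` and the constant names are hypothetical. Logging is left to the caller here, whereas the review asks for `workerLog::fatal` on unsupported platforms.

```cpp
#if defined(__linux__)
#include <sys/syscall.h>
#include <unistd.h>
#endif

// Values copied from the defines in this diff.
constexpr unsigned long XfeatureXtiledata = 18;
constexpr unsigned long XfeatureMaskXtiledata = 1UL << XfeatureXtiledata;
constexpr long ArchGetXcompPerm = 0x1022;
constexpr long ArchReqXcompPerm = 0x1023;

// Returns true only if the kernel granted permission to use the AMX tile data state.
// On platforms without arch_prctl this sketch simply reports failure.
inline bool requestAmxPermission() {
#if defined(__linux__) && defined(SYS_arch_prctl)
  if (syscall(SYS_arch_prctl, ArchReqXcompPerm, XfeatureXtiledata) != 0) {
    return false; // request rejected, e.g. kernel too old or no AMX hardware
  }
  unsigned long Bitmask = 0;
  if (syscall(SYS_arch_prctl, ArchGetXcompPerm, &Bitmask) != 0) {
    return false;
  }
  // Check the tile data bit specifically, not the combined XTILE mask.
  return (Bitmask & XfeatureMaskXtiledata) != 0;
#else
  return false;
#endif
}
```

Returning a bool lets the caller decide whether a missing permission is fatal, instead of logging an error and continuing into code that would fault on the first tile instruction.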

Member Author

The pointers passed to the function should be __bfloat16*. This may, however, not be supported by all compilers, so we might need to initialize this memory differently.

void AVX512Payload::init_buffer_rand(uintptr_t src1, uintptr_t src2) {

// Initialize buffer with random values
// Multiplication always produces either 1 or -1
// Accumulation operation always on (1 + -1) = 0 ensures stable values

__bfloat16* buf1 = (__bfloat16*)src1;
Member Author

Please do not use C-style casts.

__bfloat16* buf2 = (__bfloat16*)src2;

// TODO: Change MAX_ROWS/MAX_COLS from constant to maximum size check by asmJit
// Currently not supported by asmJit
// Alternative: Manually parse CPUID

for (int i = 0; i < MAX_ROWS; i++) {
__bfloat16 random_init = (__bfloat16)(rand() % 65536); // Limit maximum size as 1/x needs to fit bfloat16
for (int j = 0; j < MAX_COLS; j++) {
buf1[i * MAX_COLS + j] = (__bfloat16)(random_init);
if (!(j % 2)) {
buf2[i * MAX_COLS + j] = (__bfloat16)((-1) / random_init);
} else if (j % 2) {
buf2[i * MAX_COLS + j] = (__bfloat16)(1 / random_init);
}
}
}
Member Author

This function should be called in the AVX512WithAMXPayload::init function. It should override and call the AVX512Payload::init. The pointer to memory is the same as in the assembler kernel.

}
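One possible way to initialize the buffers without `__bfloat16` and without C-style casts is to store bfloat16 values as raw 16-bit patterns. The sketch below is an assumption, not the PR's method: `Bf16Bits`, `floatToBf16`, and `initBuffers` are hypothetical names, the conversion truncates instead of rounding, and the random value is drawn from 1 to 65535 so that 1/x is always finite (the `rand() % 65536` in the hunk above can yield 0).

```cpp
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <random>

// bfloat16 stored as its raw bit pattern: the upper 16 bits of an IEEE-754 float.
using Bf16Bits = std::uint16_t;

// Convert by truncation (drops the lower 16 mantissa bits, no rounding).
inline Bf16Bits floatToBf16(float Value) {
  std::uint32_t Bits = 0;
  std::memcpy(&Bits, &Value, sizeof(Bits));
  return static_cast<Bf16Bits>(Bits >> 16);
}

// Fill Buf1 with one random value x per row and Buf2 with alternating -1/x and 1/x,
// so that neighbouring products cancel: x * (-1/x) + x * (1/x) = 0.
inline void initBuffers(Bf16Bits* Buf1, Bf16Bits* Buf2, std::size_t Rows, std::size_t Cols, unsigned Seed) {
  std::mt19937 Generator(Seed);
  std::uniform_int_distribution<int> Distribution(1, 65535); // zero excluded
  for (std::size_t I = 0; I < Rows; ++I) {
    const float X = static_cast<float>(Distribution(Generator));
    for (std::size_t J = 0; J < Cols; ++J) {
      const float Sign = (J % 2 == 0) ? -1.0F : 1.0F;
      Buf1[I * Cols + J] = floatToBf16(X);
      Buf2[I * Cols + J] = floatToBf16(Sign / X);
    }
  }
}
```

Because truncation only touches the magnitude bits, the positive and negative reciprocals keep identical magnitudes, which preserves the cancellation the original comments rely on.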

} // namespace firestarter::environment::x86::payload