Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Lukezhuz/mbb dataset #1

Merged
merged 4 commits into from
Nov 27, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
136 changes: 109 additions & 27 deletions gematria/datasets/bhive_importer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -38,12 +38,20 @@
#include "llvm/ADT/ArrayRef.h"
#include "llvm/MC/TargetRegistry.h"
#include "llvm/Support/Error.h"
#include "llvm/CodeGen/MachineModuleInfo.h"
#include "llvm/CodeGen/MachineFunction.h"
#include "llvm/Support/SourceMgr.h"
#include "llvm/Support/WithColor.h"
#include "llvm/Support/raw_ostream.h"

// Debug-logging switch for this translation unit. DEBUG is defined
// unconditionally here, so the logging variant of LOG below is always the one
// selected in this file.
// NOTE(review): an unconditional `#define DEBUG` looks like leftover debugging
// state — confirm whether it should instead be controlled by the build.
#define DEBUG

#ifdef DEBUG
// LOG(X) streams X followed by a newline to llvm::errs(). X is pasted into the
// stream expression without parentheses, so callers can pass a `a << b` chain;
// for the same reason X must not contain operators that bind more loosely
// than `<<`.
#define LOG(X) \
llvm::errs() << X << "\n"
#else
// With DEBUG undefined, LOG(X) expands to nothing and X is never evaluated.
#define LOG(X)
#endif

namespace gematria {
namespace {

Expand All @@ -66,7 +74,8 @@ BHiveImporter::BHiveImporter(const Canonicalizer* canonicalizer)
mc_inst_printer_(target_machine_.getTarget().createMCInstPrinter(
target_machine_.getTargetTriple(), kDefaultSyntax,
*target_machine_.getMCAsmInfo(), *target_machine_.getMCInstrInfo(),
*target_machine_.getMCRegisterInfo())) {}
*target_machine_.getMCRegisterInfo())),
MMI_(dynamic_cast<const llvm::LLVMTargetMachine*>(&target_machine_)) {}

absl::StatusOr<BasicBlockProto> BHiveImporter::BasicBlockProtoFromMachineCode(
llvm::ArrayRef<uint8_t> machine_code, uint64_t base_address /*= 0*/) {
Expand Down Expand Up @@ -149,50 +158,127 @@ absl::StatusOr<BasicBlockWithThroughputProto> BHiveImporter::ParseBHiveCsvLine(
return proto;
}

// Builds a BasicBlockProto for the machine basic block that was registered
// under `MBB_name` in `name_to_mbb_`.
//
// Returns InvalidArgumentError when no block with that name is registered, or
// when the block contains a call instruction (such a block is treated as bad
// dataset input). Inline-assembly and terminator instructions are skipped;
// every other instruction is passed to the canonicalizer.
//
// NOTE: the canonicalized instructions are not stored in the proto yet (see
// the TODO below), so the returned proto is currently empty. `base_address` is
// accepted for signature parity with the machine-code based importers but is
// not used by this function yet.
absl::StatusOr<BasicBlockProto> BHiveImporter::BasicBlockProtoFromMBBName(
    std::string_view MBB_name, uint64_t base_address /*= 0*/) {
  BasicBlockProto basic_block_proto;
  // The map is keyed by llvm::StringRef, so view the name through that type.
  const llvm::StringRef MBB_name_ref(MBB_name.data(), MBB_name.size());

  // Single lookup: reuse the iterator instead of find() followed by
  // operator[].
  const auto mbb_it = name_to_mbb_.find(MBB_name_ref);
  if (mbb_it == name_to_mbb_.end()) {
    LOG("Cannot find MBB, using key " << MBB_name);
    return absl::InvalidArgumentError(
        absl::StrCat("Could not find MBB with name ", MBB_name));
  }

  llvm::MachineBasicBlock* MBB = mbb_it->second;
  LOG("MBB is " << *MBB);
  for (llvm::MachineInstr& MI : *MBB) {
    // Inline assembly and terminators (ret, branch, jmp) are not imported.
    if (MI.isInlineAsm() || MI.isTerminator()) {
      LOG("MI is a control instruction, skipping it " << MI);
      continue;
    }

    // A call inside the block means the dataset entry is invalid. This is
    // reported as a status rather than an assert so that it is also caught in
    // builds where asserts are compiled out.
    if (MI.isCall()) {
      return absl::InvalidArgumentError(absl::StrCat(
          "MI is a CALL instruction, bad dataset: MBB ", MBB_name));
    }

    canonicalizer_.InstructionFromMachineInstr(MI);
    // TODO: Add this to the basic block proto. The machine-code based
    // importer fills two repeated fields per instruction:
    //   - machine_instructions (address, assembly, machine_code), which is
    //     the less important part here, and
    //   - canonicalized_instructions, via ProtoFromInstruction() applied to
    //     the canonicalizer result, which is the important part and should be
    //     done first.
  }

  return basic_block_proto;
}

// Parses one comma-separated line that pairs a MIR basic block name with an
// inverse throughput value.
//
// `source_name` is recorded as the source of the throughput value.
// `BB_name_index` and `throughput_column_index` are the zero-based columns of
// `line` holding the basic block's unique name and its throughput; they must
// differ and both must exist in `line`. The parsed throughput is multiplied by
// `throughput_scaling`. `base_address` is forwarded to
// BasicBlockProtoFromMBBName().
//
// Returns InvalidArgumentError when the line has too few columns, when both
// indices are equal, or when the throughput column is not a number; errors
// from BasicBlockProtoFromMBBName() are returned unchanged.
absl::StatusOr<BasicBlockWithThroughputProto> BHiveImporter::ParseMIRCsvLine(
    std::string_view source_name, std::string_view line,
    size_t BB_name_index, size_t throughput_column_index,
    double throughput_scaling /*= 1.0*/, uint64_t base_address /*= 0*/) {
  const absl::InlinedVector<std::string_view, 2> columns =
      absl::StrSplit(line, ',');
  // Compare against the largest index itself and keep everything in size_t:
  // this avoids narrowing to int, a signed/unsigned comparison, and a `+ 1`
  // that could wrap around and let an out-of-range index through.
  const size_t max_column_index =
      std::max(BB_name_index, throughput_column_index);
  if (columns.size() <= max_column_index) {
    return absl::InvalidArgumentError(absl::StrFormat(
        "Expected `line` to have at least %d columns, found %d: %s",
        max_column_index + 1, columns.size(), line));
  }
  if (BB_name_index == throughput_column_index) {
    return absl::InvalidArgumentError(absl::StrFormat(
        "Expected BB name column and throughput column indices to be "
        "different, but were both %d: %s",
        BB_name_index, line));
  }
  const std::string_view BB_unique_name = columns[BB_name_index];
  const std::string_view throughput_str = columns[throughput_column_index];

  BasicBlockWithThroughputProto proto;

  // Resolve the basic block first; its error (e.g. unknown block name) takes
  // precedence over a malformed throughput value.
  absl::StatusOr<BasicBlockProto> block_proto_or_status =
      BasicBlockProtoFromMBBName(BB_unique_name, base_address);
  if (!block_proto_or_status.ok()) return block_proto_or_status.status();
  *proto.mutable_basic_block() = std::move(block_proto_or_status).value();

  double throughput_cycles = 0.0;
  if (!absl::SimpleAtod(throughput_str, &throughput_cycles)) {
    return absl::InvalidArgumentError(
        absl::StrCat("Could not parse throughput value ", throughput_str));
  }

  ThroughputWithSourceProto& throughput = *proto.add_inverse_throughputs();
  throughput.set_source(source_name);
  throughput.add_inverse_throughput_cycles(throughput_cycles *
                                           throughput_scaling);

  return proto;
}

absl::StatusOr<bool> BHiveImporter::LoadMIRModule(std::string_view file_name){
// clear previous loaded module
name_to_mbb_.clear();

// create MIR Parser and read all MBB to the map based on their unique name
llvm::LLVMContext context;
llvm::SMDiagnostic diag;

// Set attributes on functions as loaded from MIR from command line arguments.
// auto setMIRFunctionAttributes = [&CPUStr, &FeaturesStr](Function &F) {
// llvm::codegen::setFunctionAttributes(CPUStr, FeaturesStr, F);
// };

std::unique_ptr<llvm::MIRParser> mir_parser = llvm::createMIRParserFromFile(file_name, diag, context);
if (!mir_parser) {
diag.print("test ", llvm::WithColor::error(llvm::errs(), "test"));
mir_parser_ = llvm::createMIRParserFromFile(file_name, diag, llvm_context_);
if (!mir_parser_) {
return absl::InvalidArgumentError(
absl::StrCat("Could not create MIR parser for file ", file_name));
}

// Parse the LLVM IR module (if any)
std::unique_ptr<llvm::Module> mir_module = mir_parser->parseIRModule();
if (!mir_module) {
mir_module_ = mir_parser_->parseIRModule();
if (!mir_module_) {
// Handle error
return absl::InvalidArgumentError(
absl::StrCat("Could not parse MIR module for file ", file_name));
}

// Prepare MachineModuleInfo
auto *llvmTargetMachine = dynamic_cast<const llvm::LLVMTargetMachine*>(&target_machine_);
if (llvmTargetMachine == nullptr) {
return absl::InvalidArgumentError(
absl::StrCat("Could not cast target machine for file ", file_name));
}
llvm::MachineModuleInfo MMI(llvmTargetMachine);
MMI_.initialize();

// Parse the MachineFunctions and add them to MMI
if (mir_parser->parseMachineFunctions(*mir_module, MMI)) {
if (mir_parser_->parseMachineFunctions(*mir_module_, MMI_)) {
// Handle error
return absl::InvalidArgumentError(
absl::StrCat("Could not parse MachineFunctions for file ", file_name));
}

// Now iterate over the MachineFunctions and their MachineBasicBlocks
for (auto &F : *mir_module) {
for (auto &F : *mir_module_) {
if (F.isDeclaration()) continue;
llvm::MachineFunction &MF = MMI.getOrCreateMachineFunction(F);
llvm::MachineFunction &MF = MMI_.getOrCreateMachineFunction(F);
for (auto &MBB : MF) {
// assert name is unique
if (name_to_mbb_.find(MBB.getName()) != name_to_mbb_.end()) {
Expand All @@ -201,10 +287,6 @@ absl::StatusOr<bool> BHiveImporter::LoadMIRModule(std::string_view file_name){
} else {
name_to_mbb_[MBB.getName()] = &MBB;
}
// // Pretty print the machine block with its name
// llvm::outs() << "MachineBasicBlock: " << MBB.getName() << "\n";
// MBB.print(llvm::outs());
// llvm::outs() << "\n";
}
}

Expand Down
22 changes: 22 additions & 0 deletions gematria/datasets/bhive_importer.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
#include "llvm/CodeGen/MIRParser/MIRParser.h"
#include "llvm/CodeGen/MachineBasicBlock.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/CodeGen/MachineModuleInfo.h"

namespace gematria {

Expand Down Expand Up @@ -63,6 +64,9 @@ class BHiveImporter {
// corresponds to a three-byte sequence {0xAA, 0xBB, 0x11}.
absl::StatusOr<BasicBlockProto> BasicBlockProtoFromMachineCodeHex(
std::string_view machine_code_hex, uint64_t base_address = 0);

absl::StatusOr<BasicBlockProto> BasicBlockProtoFromMBBName(
std::string_view MBB_name, uint64_t base_address = 0);

// Parses a basic block with throughput from one BHive CSV line. Expects that
// the line has the format "{machine_code},{throughput}" where {machine_code}
Expand All @@ -80,13 +84,31 @@ class BHiveImporter {
absl::StatusOr<bool> LoadMIRModule(
std::string_view file_name
);

// Parses a MIR basic block with throughput from one CSV line. Expects that
// the line is comma-separated, that the column at `BB_name_index` holds the
// unique name of a machine basic block from the module loaded by
// LoadMIRModule(), and that the column at `throughput_column_index` holds the
// inverse throughput of the basic block in text format.
// Multiplies the parsed throughput by `throughput_scaling`, and forwards
// `base_address` to BasicBlockProtoFromMBBName().
// NOTE: LoadMIRModule() must be called before calling this function.
absl::StatusOr<BasicBlockWithThroughputProto> ParseMIRCsvLine(
std::string_view source_name, std::string_view line,
size_t BB_name_index, size_t throughput_column_index,
double throughput_scaling = 1.0, uint64_t base_address = 0);

private:
const Canonicalizer& canonicalizer_;
const llvm::TargetMachine& target_machine_;
std::unique_ptr<llvm::MCContext> context_;
std::unique_ptr<llvm::MCDisassembler> disassembler_;
std::unique_ptr<llvm::MCInstPrinter> mc_inst_printer_;
llvm::DenseMap<llvm::StringRef, llvm::MachineBasicBlock*> name_to_mbb_;
llvm::LLVMContext llvm_context_;
std::unique_ptr<llvm::Module> mir_module_;
llvm::MachineModuleInfo MMI_;
std::unique_ptr<llvm::MIRParser> mir_parser_;
};

} // namespace gematria
Expand Down
5 changes: 4 additions & 1 deletion gematria/datasets/bhive_importer_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -226,9 +226,12 @@ TEST_F(BHiveImporterTest, NonStandardColumns) {
})pb")));
}

TEST_F(BHiveImporterTest, LoadMIRModule) {
TEST_F(BHiveImporterTest, MIRDatasetBasicTest) {
EXPECT_THAT(x86_bhive_importer_->LoadMIRModule("/u9/z277zhu/research/gematria/sample_dataset/data.mir"),
IsOk());
EXPECT_THAT(x86_bhive_importer_->ParseMIRCsvLine(kSourceName, "a,b,BB_13,2.37", 2,
3, kScaling),
IsOk());
}

} // namespace
Expand Down
2 changes: 2 additions & 0 deletions gematria/llvm/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ cc_library(
"@llvm-project//llvm:Support",
"@llvm-project//llvm:Target",
"@llvm-project//llvm:X86UtilsAndDesc",
"@llvm-project//llvm:CodeGen",
],
)

Expand All @@ -65,6 +66,7 @@ cc_test(
"@com_google_googletest//:gtest_main",
"@llvm-project//llvm:MC",
"@llvm-project//llvm:ir_headers",
"@llvm-project//llvm:CodeGen",
],
)

Expand Down
Loading
Loading