Skip to content

Commit

Permalink
Fixed module storage bug: extend the lifetime of the LLVM context and module…
Browse files Browse the repository at this point in the history
… from LoadMIRModule
  • Loading branch information
9Tempest committed Nov 5, 2023
1 parent c535b58 commit 790d19f
Show file tree
Hide file tree
Showing 3 changed files with 56 additions and 47 deletions.
88 changes: 47 additions & 41 deletions gematria/datasets/bhive_importer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -38,12 +38,20 @@
#include "llvm/ADT/ArrayRef.h"
#include "llvm/MC/TargetRegistry.h"
#include "llvm/Support/Error.h"
#include "llvm/CodeGen/MachineModuleInfo.h"
#include "llvm/CodeGen/MachineFunction.h"
#include "llvm/Support/SourceMgr.h"
#include "llvm/Support/WithColor.h"
#include "llvm/Support/raw_ostream.h"

// Debug-logging switch for this file. While DEBUG is defined, LOG(X) streams X
// followed by a newline to llvm::errs(); otherwise LOG(X) expands to nothing
// and its argument is never evaluated.
#define DEBUG

#if defined(DEBUG)
// X is pasted unparenthesized into the stream expression, so it may itself be
// a `<<` chain such as `"text " << value`.
#define LOG(X) llvm::errs() << X << "\n"
#else
#define LOG(X)
#endif

namespace gematria {
namespace {

Expand All @@ -66,7 +74,8 @@ BHiveImporter::BHiveImporter(const Canonicalizer* canonicalizer)
mc_inst_printer_(target_machine_.getTarget().createMCInstPrinter(
target_machine_.getTargetTriple(), kDefaultSyntax,
*target_machine_.getMCAsmInfo(), *target_machine_.getMCInstrInfo(),
*target_machine_.getMCRegisterInfo())) {}
*target_machine_.getMCRegisterInfo())),
MMI_(dynamic_cast<const llvm::LLVMTargetMachine*>(&target_machine_)) {}

absl::StatusOr<BasicBlockProto> BHiveImporter::BasicBlockProtoFromMachineCode(
llvm::ArrayRef<uint8_t> machine_code, uint64_t base_address /*= 0*/) {
Expand Down Expand Up @@ -152,15 +161,30 @@ absl::StatusOr<BasicBlockWithThroughputProto> BHiveImporter::ParseBHiveCsvLine(
absl::StatusOr<BasicBlockProto> BHiveImporter::BasicBlockProtoFromMBBName(
std::string_view MBB_name, uint64_t base_address /*= 0*/) {
BasicBlockProto basic_block_proto;
// llvm::Expected<std::vector<DisassembledInstruction>> instructions =
// DisassembleAllInstructions(*disassembler_,
// *target_machine_.getMCInstrInfo(),
// *target_machine_.getMCRegisterInfo(),
// *target_machine_.getMCSubtargetInfo(),
// *mc_inst_printer_, base_address, machine_code);
// if (llvm::Error error = instructions.takeError()) {
// return LlvmErrorToStatus(std::move(error));
// }
// convert MBB_name to llvm::StringRef
llvm::StringRef MBB_name_ref(MBB_name.data(), MBB_name.size());

// lookup the MBB in the map, if not, return error
if (name_to_mbb_.find(MBB_name_ref) == name_to_mbb_.end()) {
LOG("Cannot find MBB, using key " << MBB_name);
return absl::InvalidArgumentError(
absl::StrCat("Could not find MBB with name ", MBB_name));
}

llvm::MachineBasicBlock* MBB = name_to_mbb_[MBB_name_ref];
LOG("MBB is " << *MBB);
for (llvm::MachineInstr& MI : *MBB){
// if MI is a control instruction(ret,branch,jmp), skip it
if (MI.isInlineAsm() || MI.isTerminator()) {
LOG("MI is a control instruction, skipping it " << MI);
continue;
}

// Assert MI cannot be a CALL instruction
assert(!MI.isCall() && "MI is a CALL instruction, bad dataset");
canonicalizer_.InstructionFromMachineInstr(MI);
// TODO: Add this to the basic block proto
}

// for (DisassembledInstruction& instruction : *instructions) {
// NOT VERY IMPORTANT THESE 3 LINES
Expand Down Expand Up @@ -202,11 +226,10 @@ absl::StatusOr<BasicBlockWithThroughputProto> BHiveImporter::ParseMIRCsvLine(

BasicBlockWithThroughputProto proto;

// TODO: change this to use the unique name to get the MBB
// absl::StatusOr<BasicBlockProto> block_proto_or_status =
// BasicBlockProtoFromMachineCodeHex(machine_code_hex, base_address);
// if (!block_proto_or_status.ok()) return block_proto_or_status.status();
// *proto.mutable_basic_block() = std::move(block_proto_or_status).value();
absl::StatusOr<BasicBlockProto> block_proto_or_status =
BasicBlockProtoFromMBBName(BB_unique_name, base_address);
if (!block_proto_or_status.ok()) return block_proto_or_status.status();
*proto.mutable_basic_block() = std::move(block_proto_or_status).value();

double throughput_cycles = 0.0;
if (!absl::SimpleAtod(throughput_str, &throughput_cycles)) {
Expand All @@ -227,48 +250,35 @@ absl::StatusOr<bool> BHiveImporter::LoadMIRModule(std::string_view file_name){
name_to_mbb_.clear();

// create MIR Parser and read all MBB to the map based on their unique name
llvm::LLVMContext context;
llvm::SMDiagnostic diag;

// Set attributes on functions as loaded from MIR from command line arguments.
// auto setMIRFunctionAttributes = [&CPUStr, &FeaturesStr](Function &F) {
// llvm::codegen::setFunctionAttributes(CPUStr, FeaturesStr, F);
// };

std::unique_ptr<llvm::MIRParser> mir_parser = llvm::createMIRParserFromFile(file_name, diag, context);
if (!mir_parser) {
diag.print("test ", llvm::WithColor::error(llvm::errs(), "test"));
mir_parser_ = llvm::createMIRParserFromFile(file_name, diag, llvm_context_);
if (!mir_parser_) {
return absl::InvalidArgumentError(
absl::StrCat("Could not create MIR parser for file ", file_name));
}

// Parse the LLVM IR module (if any)
std::unique_ptr<llvm::Module> mir_module = mir_parser->parseIRModule();
if (!mir_module) {
mir_module_ = mir_parser_->parseIRModule();
if (!mir_module_) {
// Handle error
return absl::InvalidArgumentError(
absl::StrCat("Could not parse MIR module for file ", file_name));
}

// Prepare MachineModuleInfo
auto *llvmTargetMachine = dynamic_cast<const llvm::LLVMTargetMachine*>(&target_machine_);
if (llvmTargetMachine == nullptr) {
return absl::InvalidArgumentError(
absl::StrCat("Could not cast target machine for file ", file_name));
}
llvm::MachineModuleInfo MMI(llvmTargetMachine);
MMI_.initialize();

// Parse the MachineFunctions and add them to MMI
if (mir_parser->parseMachineFunctions(*mir_module, MMI)) {
if (mir_parser_->parseMachineFunctions(*mir_module_, MMI_)) {
// Handle error
return absl::InvalidArgumentError(
absl::StrCat("Could not parse MachineFunctions for file ", file_name));
}

// Now iterate over the MachineFunctions and their MachineBasicBlocks
for (auto &F : *mir_module) {
for (auto &F : *mir_module_) {
if (F.isDeclaration()) continue;
llvm::MachineFunction &MF = MMI.getOrCreateMachineFunction(F);
llvm::MachineFunction &MF = MMI_.getOrCreateMachineFunction(F);
for (auto &MBB : MF) {
// assert name is unique
if (name_to_mbb_.find(MBB.getName()) != name_to_mbb_.end()) {
Expand All @@ -277,10 +287,6 @@ absl::StatusOr<bool> BHiveImporter::LoadMIRModule(std::string_view file_name){
} else {
name_to_mbb_[MBB.getName()] = &MBB;
}
// // Pretty print the machine block with its name
// llvm::outs() << "MachineBasicBlock: " << MBB.getName() << "\n";
// MBB.print(llvm::outs());
// llvm::outs() << "\n";
}
}

Expand Down
8 changes: 8 additions & 0 deletions gematria/datasets/bhive_importer.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
#include "llvm/CodeGen/MIRParser/MIRParser.h"
#include "llvm/CodeGen/MachineBasicBlock.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/CodeGen/MachineModuleInfo.h"

namespace gematria {

Expand Down Expand Up @@ -63,6 +64,9 @@ class BHiveImporter {
// corresponds to a three-byte sequence {0xAA, 0xBB, 0x11}.
absl::StatusOr<BasicBlockProto> BasicBlockProtoFromMachineCodeHex(
std::string_view machine_code_hex, uint64_t base_address = 0);

absl::StatusOr<BasicBlockProto> BasicBlockProtoFromMBBName(
std::string_view MBB_name, uint64_t base_address = 0);

// Parses a basic block with throughput from one BHive CSV line. Expects that
// the line has the format "{machine_code},{throughput}" where {machine_code}
Expand Down Expand Up @@ -101,6 +105,10 @@ class BHiveImporter {
std::unique_ptr<llvm::MCDisassembler> disassembler_;
std::unique_ptr<llvm::MCInstPrinter> mc_inst_printer_;
llvm::DenseMap<llvm::StringRef, llvm::MachineBasicBlock*> name_to_mbb_;
llvm::LLVMContext llvm_context_;
std::unique_ptr<llvm::Module> mir_module_;
llvm::MachineModuleInfo MMI_;
std::unique_ptr<llvm::MIRParser> mir_parser_;
};

} // namespace gematria
Expand Down
7 changes: 1 addition & 6 deletions gematria/datasets/bhive_importer_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -226,15 +226,10 @@ TEST_F(BHiveImporterTest, NonStandardColumns) {
})pb")));
}

// Checks that LoadMIRModule reports an OK status for the sample MIR dataset.
// NOTE(review): the dataset path below is an absolute, machine-specific
// location, so this test presumably only passes where that file exists —
// confirm, and consider a repository-relative test data path.
TEST_F(BHiveImporterTest, LoadMIRModule) {
EXPECT_THAT(x86_bhive_importer_->LoadMIRModule("/u9/z277zhu/research/gematria/sample_dataset/data.mir"),
IsOk());
}

TEST_F(BHiveImporterTest, MIRDatasetBasicTest) {
EXPECT_THAT(x86_bhive_importer_->LoadMIRModule("/u9/z277zhu/research/gematria/sample_dataset/data.mir"),
IsOk());
EXPECT_THAT(x86_bhive_importer_->ParseMIRCsvLine(kSourceName, "a,b,BB_10,0", 2,
EXPECT_THAT(x86_bhive_importer_->ParseMIRCsvLine(kSourceName, "a,b,BB_13,2.37", 2,
3, kScaling),
IsOk());
}
Expand Down

0 comments on commit 790d19f

Please sign in to comment.