From 790d19f230e06c368bf380d41423e321a056f1ad Mon Sep 17 00:00:00 2001 From: z277zhu Date: Sun, 5 Nov 2023 13:45:01 -0500 Subject: [PATCH] Fixed module storage bug: extend lifetime of LLVM context and module from LoadMIRModule --- gematria/datasets/bhive_importer.cc | 88 +++++++++++++----------- gematria/datasets/bhive_importer.h | 8 +++ gematria/datasets/bhive_importer_test.cc | 7 +- 3 files changed, 56 insertions(+), 47 deletions(-) diff --git a/gematria/datasets/bhive_importer.cc b/gematria/datasets/bhive_importer.cc index 04d494de..6f2bad71 100644 --- a/gematria/datasets/bhive_importer.cc +++ b/gematria/datasets/bhive_importer.cc @@ -38,12 +38,20 @@ #include "llvm/ADT/ArrayRef.h" #include "llvm/MC/TargetRegistry.h" #include "llvm/Support/Error.h" -#include "llvm/CodeGen/MachineModuleInfo.h" #include "llvm/CodeGen/MachineFunction.h" #include "llvm/Support/SourceMgr.h" #include "llvm/Support/WithColor.h" #include "llvm/Support/raw_ostream.h" +#define DEBUG + +#ifdef DEBUG +#define LOG(X) \ + llvm::errs() << X << "\n" +#else +#define LOG(X) +#endif + namespace gematria { namespace { @@ -66,7 +74,8 @@ BHiveImporter::BHiveImporter(const Canonicalizer* canonicalizer) mc_inst_printer_(target_machine_.getTarget().createMCInstPrinter( target_machine_.getTargetTriple(), kDefaultSyntax, *target_machine_.getMCAsmInfo(), *target_machine_.getMCInstrInfo(), - *target_machine_.getMCRegisterInfo())) {} + *target_machine_.getMCRegisterInfo())), + MMI_(dynamic_cast(&target_machine_)) {} absl::StatusOr BHiveImporter::BasicBlockProtoFromMachineCode( llvm::ArrayRef machine_code, uint64_t base_address /*= 0*/) { @@ -152,15 +161,30 @@ absl::StatusOr BHiveImporter::ParseBHiveCsvLine( absl::StatusOr BHiveImporter::BasicBlockProtoFromMBBName( std::string_view MBB_name, uint64_t base_address /*= 0*/) { BasicBlockProto basic_block_proto; - // llvm::Expected> instructions = - // DisassembleAllInstructions(*disassembler_, - // *target_machine_.getMCInstrInfo(), - // 
*target_machine_.getMCRegisterInfo(), - // *target_machine_.getMCSubtargetInfo(), - // *mc_inst_printer_, base_address, machine_code); - // if (llvm::Error error = instructions.takeError()) { - // return LlvmErrorToStatus(std::move(error)); - // } + // convert MBB_name to llvm::StringRef + llvm::StringRef MBB_name_ref(MBB_name.data(), MBB_name.size()); + + // lookup the MBB in the map, if not, return error + if (name_to_mbb_.find(MBB_name_ref) == name_to_mbb_.end()) { + LOG("Cannot find MBB, using key " << MBB_name); + return absl::InvalidArgumentError( + absl::StrCat("Could not find MBB with name ", MBB_name)); + } + + llvm::MachineBasicBlock* MBB = name_to_mbb_[MBB_name_ref]; + LOG("MBB is " << *MBB); + for (llvm::MachineInstr& MI : *MBB){ + // if MI is a control instruction(ret,branch,jmp), skip it + if (MI.isInlineAsm() || MI.isTerminator()) { + LOG("MI is a control instruction, skipping it " << MI); + continue; + } + + // Assert MI cannot be a CALL instruction + assert(!MI.isCall() && "MI is a CALL instruction, bad dataset"); + canonicalizer_.InstructionFromMachineInstr(MI); + // TODO: Add this to the basic block proto + } // for (DisassembledInstruction& instruction : *instructions) { // NOT VERY IMPORTANT THESE 3 LINES @@ -202,11 +226,10 @@ absl::StatusOr BHiveImporter::ParseMIRCsvLine( BasicBlockWithThroughputProto proto; - // TODO: change this to use the unique name to get the MBB - // absl::StatusOr block_proto_or_status = - // BasicBlockProtoFromMachineCodeHex(machine_code_hex, base_address); - // if (!block_proto_or_status.ok()) return block_proto_or_status.status(); - // *proto.mutable_basic_block() = std::move(block_proto_or_status).value(); + absl::StatusOr block_proto_or_status = + BasicBlockProtoFromMBBName(BB_unique_name, base_address); + if (!block_proto_or_status.ok()) return block_proto_or_status.status(); + *proto.mutable_basic_block() = std::move(block_proto_or_status).value(); double throughput_cycles = 0.0; if 
(!absl::SimpleAtod(throughput_str, &throughput_cycles)) { @@ -227,48 +250,35 @@ absl::StatusOr BHiveImporter::LoadMIRModule(std::string_view file_name){ name_to_mbb_.clear(); // create MIR Parser and read all MBB to the map based on their unique name - llvm::LLVMContext context; llvm::SMDiagnostic diag; - // Set attributes on functions as loaded from MIR from command line arguments. - // auto setMIRFunctionAttributes = [&CPUStr, &FeaturesStr](Function &F) { - // llvm::codegen::setFunctionAttributes(CPUStr, FeaturesStr, F); - // }; - - std::unique_ptr mir_parser = llvm::createMIRParserFromFile(file_name, diag, context); - if (!mir_parser) { - diag.print("test ", llvm::WithColor::error(llvm::errs(), "test")); + mir_parser_ = llvm::createMIRParserFromFile(file_name, diag, llvm_context_); + if (!mir_parser_) { return absl::InvalidArgumentError( absl::StrCat("Could not create MIR parser for file ", file_name)); } // Parse the LLVM IR module (if any) - std::unique_ptr mir_module = mir_parser->parseIRModule(); - if (!mir_module) { + mir_module_ = mir_parser_->parseIRModule(); + if (!mir_module_) { // Handle error return absl::InvalidArgumentError( absl::StrCat("Could not parse MIR module for file ", file_name)); } - // Prepare MachineModuleInfo - auto *llvmTargetMachine = dynamic_cast(&target_machine_); - if (llvmTargetMachine == nullptr) { - return absl::InvalidArgumentError( - absl::StrCat("Could not cast target machine for file ", file_name)); - } - llvm::MachineModuleInfo MMI(llvmTargetMachine); + MMI_.initialize(); // Parse the MachineFunctions and add them to MMI - if (mir_parser->parseMachineFunctions(*mir_module, MMI)) { + if (mir_parser_->parseMachineFunctions(*mir_module_, MMI_)) { // Handle error return absl::InvalidArgumentError( absl::StrCat("Could not parse MachineFunctions for file ", file_name)); } // Now iterate over the MachineFunctions and their MachineBasicBlocks - for (auto &F : *mir_module) { + for (auto &F : *mir_module_) { if (F.isDeclaration()) 
continue; - llvm::MachineFunction &MF = MMI.getOrCreateMachineFunction(F); + llvm::MachineFunction &MF = MMI_.getOrCreateMachineFunction(F); for (auto &MBB : MF) { // assert name is unique if (name_to_mbb_.find(MBB.getName()) != name_to_mbb_.end()) { @@ -277,10 +287,6 @@ absl::StatusOr BHiveImporter::LoadMIRModule(std::string_view file_name){ } else { name_to_mbb_[MBB.getName()] = &MBB; } - // // Pretty print the machine block with its name - // llvm::outs() << "MachineBasicBlock: " << MBB.getName() << "\n"; - // MBB.print(llvm::outs()); - // llvm::outs() << "\n"; } } diff --git a/gematria/datasets/bhive_importer.h b/gematria/datasets/bhive_importer.h index 10512b78..f385ab91 100644 --- a/gematria/datasets/bhive_importer.h +++ b/gematria/datasets/bhive_importer.h @@ -34,6 +34,7 @@ #include "llvm/CodeGen/MIRParser/MIRParser.h" #include "llvm/CodeGen/MachineBasicBlock.h" #include "llvm/ADT/DenseMap.h" +#include "llvm/CodeGen/MachineModuleInfo.h" namespace gematria { @@ -63,6 +64,9 @@ class BHiveImporter { // corresponds to a three-byte sequence {0xAA, 0xBB, 0x11}. absl::StatusOr BasicBlockProtoFromMachineCodeHex( std::string_view machine_code_hex, uint64_t base_address = 0); + + absl::StatusOr BasicBlockProtoFromMBBName( + std::string_view MBB_name, uint64_t base_address = 0); // Parses a basic block with throughput from one BHive CSV line. 
Expects that // the line has the format "{machine_code},{throughput}" where {machine_code} @@ -101,6 +105,10 @@ class BHiveImporter { std::unique_ptr disassembler_; std::unique_ptr mc_inst_printer_; llvm::DenseMap name_to_mbb_; + llvm::LLVMContext llvm_context_; + std::unique_ptr mir_module_; + llvm::MachineModuleInfo MMI_; + std::unique_ptr mir_parser_; }; } // namespace gematria diff --git a/gematria/datasets/bhive_importer_test.cc b/gematria/datasets/bhive_importer_test.cc index 562cc04f..1094aefe 100644 --- a/gematria/datasets/bhive_importer_test.cc +++ b/gematria/datasets/bhive_importer_test.cc @@ -226,15 +226,10 @@ TEST_F(BHiveImporterTest, NonStandardColumns) { })pb"))); } -TEST_F(BHiveImporterTest, LoadMIRModule) { - EXPECT_THAT(x86_bhive_importer_->LoadMIRModule("/u9/z277zhu/research/gematria/sample_dataset/data.mir"), - IsOk()); -} - TEST_F(BHiveImporterTest, MIRDatasetBasicTest) { EXPECT_THAT(x86_bhive_importer_->LoadMIRModule("/u9/z277zhu/research/gematria/sample_dataset/data.mir"), IsOk()); - EXPECT_THAT(x86_bhive_importer_->ParseMIRCsvLine(kSourceName, "a,b,BB_10,0", 2, + EXPECT_THAT(x86_bhive_importer_->ParseMIRCsvLine(kSourceName, "a,b,BB_13,2.37", 2, 3, kScaling), IsOk()); }