diff --git a/gematria/datasets/bhive_importer.cc b/gematria/datasets/bhive_importer.cc index f957da6c..122e14fe 100644 --- a/gematria/datasets/bhive_importer.cc +++ b/gematria/datasets/bhive_importer.cc @@ -81,14 +81,14 @@ BHiveImporter::BHiveImporter(const Canonicalizer* canonicalizer) *target_machine_.getMCAsmInfo(), *target_machine_.getMCInstrInfo(), *target_machine_.getMCRegisterInfo())), MMI_(dynamic_cast(&target_machine_)) { - const llvm::MCRegisterInfo& MRI = *target_machine_.getMCRegisterInfo(); - for (llvm::MCPhysReg I = 1, E = MRI.getNumRegs(); I != E; ++I) { - // Append register definition line. - llvm::StringRef reg_name = MRI.getName(I); - name_to_reg_[reg_name.str()] = I; - } - prettyPrintName2Reg(); + const llvm::MCRegisterInfo& MRI = *target_machine_.getMCRegisterInfo(); + for (llvm::MCPhysReg I = 1, E = MRI.getNumRegs(); I != E; ++I) { + // Append register definition line. + llvm::StringRef reg_name = MRI.getName(I); + name_to_reg_[reg_name.str()] = I; } + prettyPrintName2Reg(); +} absl::StatusOr BHiveImporter::BasicBlockProtoFromMachineCode( llvm::ArrayRef machine_code, uint64_t base_address /*= 0*/) { @@ -236,6 +236,22 @@ absl::StatusOr BHiveImporter::ParseMIRCsvLine( absl::StatusOr block_proto_or_status = BasicBlockProtoFromMBBName(BB_unique_name, base_address); if (!block_proto_or_status.ok()) return block_proto_or_status.status(); + + llvm::StringRef MBB_name_ref(BB_unique_name.data(), BB_unique_name.size()); + // lookup the MBB in the map, if not, return error + if (name_to_mbb_.find(MBB_name_ref) == name_to_mbb_.end()) { + return absl::InvalidArgumentError( + absl::StrCat("Could not find MBB with name ", BB_unique_name)); + } + + llvm::MachineBasicBlock* MBB = name_to_mbb_[MBB_name_ref]; + std::string func_name = MBB->getParent()->getName().str(); + assert(func_to_live_intervals_.find(func_name) != + func_to_live_intervals_.end() && "Function not found in map"); + addInterferenceGraph(*block_proto_or_status, + func_to_live_intervals_[func_name], + func_to_live_intervals_[func_name] + .BBRangeList[std::string(BB_unique_name)]); *proto.mutable_basic_block() = std::move(block_proto_or_status).value(); double throughput_cycles = 0.0; @@ -330,12 +346,14 @@ void printMap( std::cerr << "Function Name: " << functionInfoPair.first << "\n"; // Print live range of register - for (auto& pairInfo : functionInfoPair.second.virtual_register_live_range_func) { + for (auto& pairInfo : + functionInfoPair.second.virtual_register_live_range_func) { LOG("Virtual Register Name: " << pairInfo.first); printRegLiveIntervals(pairInfo.second); } - for (auto& pairInfo : functionInfoPair.second.physical_register_live_range_func) { + for (auto& pairInfo : + functionInfoPair.second.physical_register_live_range_func) { LOG("Physical Register Name: " << pairInfo.first); printRegLiveIntervals(pairInfo.second); } @@ -351,6 +369,15 @@ void printMap( } } +static bool areIntersected(const BHiveImporter::BhiveLiveRange& range1, + const BHiveImporter::BhiveLiveRange& range2) { + // Check if one range starts after the other ends or vice versa. + if (range1.second <= range2.first || range2.second <= range1.first) { + return false; // No intersection. + } + return true; // Ranges are intersected. +} + absl::StatusOr BHiveImporter::InteferenceGraphParser( std::string_view file_name) { // Boilerplate for reading input @@ -408,15 +435,17 @@ absl::StatusOr BHiveImporter::InteferenceGraphParser( // Since LLVM do not support [] operator we need to find it first auto resultRegLiveIntervals = - (is_virtual) ? info->virtual_register_live_range_func.find(currentRegister) - :info->physical_register_live_range_func.find(currentRegister); + (is_virtual) + ? info->virtual_register_live_range_func.find(currentRegister) + : info->physical_register_live_range_func.find(currentRegister); // If you find the current register in the register_live_range_func, // you insert a BhiveLiveRange with {start, end} in the range list of // the find return If not, then you insert a new pair: {currentRegister, // RegLiveIntervals} - if (is_virtual){ - if (resultRegLiveIntervals != info->virtual_register_live_range_func.end()) { + if (is_virtual) { + if (resultRegLiveIntervals != + info->virtual_register_live_range_func.end()) { // If you find the register, then you insert a new range resultRegLiveIntervals->second.rangeList.push_back({start, end}); } else { @@ -429,7 +458,8 @@ absl::StatusOr BHiveImporter::InteferenceGraphParser( {currentRegister, newRegLiveIntervals}); } } else { - if (resultRegLiveIntervals != info->physical_register_live_range_func.end()) { + if (resultRegLiveIntervals != + info->physical_register_live_range_func.end()) { // If you find the register, then you insert a new range resultRegLiveIntervals->second.rangeList.push_back({start, end}); } else { @@ -482,49 +512,30 @@ absl::StatusOr BHiveImporter::InteferenceGraphParser( // Now we want to debug and print things inside the FunctionLiveIntervalMap printMap(func_to_live_intervals_); + return true; +} - // // This stores the information of the whole function - // std::vector FunctionInfoList; - - // // At this time, we already processed all information in the file - // // Now we want to construct the interference graph - // // We first create an object that represents inference graph in a basic - // block struct InferenceBB { - // std::map> adjacencyList; - // }; - - // // This is a vector that stores information of a BB in each function of the - // function list std::vector> AllFunction; - - // // We still need to find what is the name of each basic block - // for (FunctionInfo functionInfo : FunctionInfoList) { - - // std::vector functionAllBB; - - // // Consider a basic block at a time - // for (std::pair BBInformation : - // functionInfo.BBRangeList ) { - // // We create an object that stores adjacency list of a the inference - // graph of a single BB InferenceBB adjacencySingleBB; - - // // Now for each pair of register - // // First decide whether they are in this basic block or not - // // and then decide whether they intersect () - // for (RegLiveInterval Reg1 : functionInfo.register_live_range_func) { - // for (RegLiveInterval Reg2 : functionInfo.register_live_range_func) { - // if (intersect(Reg1, Reg2, BBInformation)) { - // adjacencySingleBB.adjacencyList[Reg1.name].push_back(Reg2.name); - // adjacencySingleBB.adjacencyList[Reg2.name].push_back(Reg1.name); - // } - // } - // } - - // // Now we add the adjacency of a single BB into the functionAllBB - // functionAllBB.push_back(adjacencySingleBB); - // } - - // // Add the inference graph of all BB in a function to the whole list - // AllFunction.push_back(functionAllBB); +absl::StatusOr BHiveImporter::addInterferenceGraph( + BasicBlockProto& bb_proto, + const BHiveImporter::FunctionLiveIntervalInfo& func_live_infos, + const BHiveImporter::BhiveLiveRange& bb_range) { + std::set live_virtual_registers; + // iterate over all operands in bb_proto, add virtual registers to + // live_virtual_registers + for (const auto& instruction : bb_proto.canonicalized_instructions()) { + for (const auto& operand : instruction.input_operands()) { + if (operand.operand_case() == + CanonicalizedOperandProto::kVirtualRegister) { + live_virtual_registers.insert(operand.virtual_register().name()); + } + } + for (const auto& operand : instruction.output_operands()) { + if (operand.operand_case() == + CanonicalizedOperandProto::kVirtualRegister) { + live_virtual_registers.insert(operand.virtual_register().name()); + } + } + } return true; } diff --git a/gematria/datasets/bhive_importer.h b/gematria/datasets/bhive_importer.h index 2423b7a1..f9c406a0 100644 --- a/gematria/datasets/bhive_importer.h +++ b/gematria/datasets/bhive_importer.h @@ -159,6 +159,8 @@ class BHiveImporter { // to take in machine instruction/ fucntion absl::StatusOr InteferenceGraphParser(std::string_view file_name); + absl::StatusOr addInterferenceGraph(BasicBlockProto& bb_proto, const FunctionLiveIntervalInfo& func_live_infos, const BhiveLiveRange& bb_range); + private: const Canonicalizer& canonicalizer_; const llvm::TargetMachine& target_machine_; diff --git a/gematria/datasets/bhive_importer_test.cc b/gematria/datasets/bhive_importer_test.cc index 3c471a01..c7eb95a7 100644 --- a/gematria/datasets/bhive_importer_test.cc +++ b/gematria/datasets/bhive_importer_test.cc @@ -226,24 +226,16 @@ TEST_F(BHiveImporterTest, NonStandardColumns) { })pb"))); } -TEST_F(BHiveImporterTest, MIRDatasetBasicTest) { - EXPECT_THAT(x86_bhive_importer_->LoadMIRModule("sample_dataset/data.mir"), - IsOk()); - EXPECT_THAT(x86_bhive_importer_->ParseMIRCsvLine( - kSourceName, "a,b,BB_13,2.37", 2, 3, kScaling), - IsOk()); -} - TEST_F(BHiveImporterTest, MIRDatasetTest2) { EXPECT_THAT( x86_bhive_importer_->LoadMIRModule("sample_dataset/native_test.mir"), IsOk()); - EXPECT_THAT(x86_bhive_importer_->InteferenceGraphParser( - "sample_dataset/singleliveinfo"), - IsOk()); EXPECT_THAT( x86_bhive_importer_->InteferenceGraphParser("sample_dataset/liveinfo"), IsOk()); + EXPECT_THAT(x86_bhive_importer_->ParseMIRCsvLine( + kSourceName, "a,b,BB_13,2.37", 2, 3, kScaling), + IsOk()); } } // namespace