Skip to content

Commit

Permalink
Added intersection API and framework for instrument proto
Browse files Browse the repository at this point in the history
  • Loading branch information
9Tempest committed Dec 5, 2023
1 parent 89ee21c commit 9867517
Show file tree
Hide file tree
Showing 3 changed files with 72 additions and 67 deletions.
123 changes: 67 additions & 56 deletions gematria/datasets/bhive_importer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -81,14 +81,14 @@ BHiveImporter::BHiveImporter(const Canonicalizer* canonicalizer)
*target_machine_.getMCAsmInfo(), *target_machine_.getMCInstrInfo(),
*target_machine_.getMCRegisterInfo())),
MMI_(dynamic_cast<const llvm::LLVMTargetMachine*>(&target_machine_)) {
const llvm::MCRegisterInfo& MRI = *target_machine_.getMCRegisterInfo();
for (llvm::MCPhysReg I = 1, E = MRI.getNumRegs(); I != E; ++I) {
// Append register definition line.
llvm::StringRef reg_name = MRI.getName(I);
name_to_reg_[reg_name.str()] = I;
}
prettyPrintName2Reg();
const llvm::MCRegisterInfo& MRI = *target_machine_.getMCRegisterInfo();
for (llvm::MCPhysReg I = 1, E = MRI.getNumRegs(); I != E; ++I) {
// Append register definition line.
llvm::StringRef reg_name = MRI.getName(I);
name_to_reg_[reg_name.str()] = I;
}
prettyPrintName2Reg();
}

absl::StatusOr<BasicBlockProto> BHiveImporter::BasicBlockProtoFromMachineCode(
llvm::ArrayRef<uint8_t> machine_code, uint64_t base_address /*= 0*/) {
Expand Down Expand Up @@ -236,6 +236,22 @@ absl::StatusOr<BasicBlockWithThroughputProto> BHiveImporter::ParseMIRCsvLine(
absl::StatusOr<BasicBlockProto> block_proto_or_status =
BasicBlockProtoFromMBBName(BB_unique_name, base_address);
if (!block_proto_or_status.ok()) return block_proto_or_status.status();

llvm::StringRef MBB_name_ref(BB_unique_name.data(), BB_unique_name.size());
// lookup the MBB in the map, if not, return error
if (name_to_mbb_.find(MBB_name_ref) == name_to_mbb_.end()) {
return absl::InvalidArgumentError(
absl::StrCat("Could not find MBB with name ", BB_unique_name));
}

llvm::MachineBasicBlock* MBB = name_to_mbb_[MBB_name_ref];
std::string func_name = MBB->getParent()->getName().str();
assert(func_to_live_intervals_.find(func_name) !=
func_to_live_intervals_.end() && "Function not found in map");
addInterferenceGraph(*block_proto_or_status,
func_to_live_intervals_[func_name],
func_to_live_intervals_[func_name]
.BBRangeList[std::string(BB_unique_name)]);
*proto.mutable_basic_block() = std::move(block_proto_or_status).value();

double throughput_cycles = 0.0;
Expand Down Expand Up @@ -330,12 +346,14 @@ void printMap(
std::cerr << "Function Name: " << functionInfoPair.first << "\n";

// Print live range of register
for (auto& pairInfo : functionInfoPair.second.virtual_register_live_range_func) {
for (auto& pairInfo :
functionInfoPair.second.virtual_register_live_range_func) {
LOG("Virtual Register Name: " << pairInfo.first);
printRegLiveIntervals(pairInfo.second);
}

for (auto& pairInfo : functionInfoPair.second.physical_register_live_range_func) {
for (auto& pairInfo :
functionInfoPair.second.physical_register_live_range_func) {
LOG("Physical Register Name: " << pairInfo.first);
printRegLiveIntervals(pairInfo.second);
}
Expand All @@ -351,6 +369,15 @@ void printMap(
}
}

static bool areIntersected(const BHiveImporter::BhiveLiveRange& range1,
const BHiveImporter::BhiveLiveRange& range2) {
// Check if one range starts after the other ends or vice versa.
if (range1.second <= range2.first || range2.second <= range1.first) {
return false; // No intersection.
}
return true; // Ranges are intersected.
}

absl::StatusOr<bool> BHiveImporter::InteferenceGraphParser(
std::string_view file_name) {
// Boilerplate for reading input
Expand Down Expand Up @@ -408,15 +435,17 @@ absl::StatusOr<bool> BHiveImporter::InteferenceGraphParser(

// Since LLVM do not support [] operator we need to find it first
auto resultRegLiveIntervals =
(is_virtual) ? info->virtual_register_live_range_func.find(currentRegister)
:info->physical_register_live_range_func.find(currentRegister);
(is_virtual)
? info->virtual_register_live_range_func.find(currentRegister)
: info->physical_register_live_range_func.find(currentRegister);

// If you find the current register in the register_live_range_func,
// you insert a BhiveLiveRange with {start, end} in the range list of
// the find return If not, then you insert a new pair: {currentRegister,
// RegLiveIntervals}
if (is_virtual){
if (resultRegLiveIntervals != info->virtual_register_live_range_func.end()) {
if (is_virtual) {
if (resultRegLiveIntervals !=
info->virtual_register_live_range_func.end()) {
// If you find the register, then you insert a new range
resultRegLiveIntervals->second.rangeList.push_back({start, end});
} else {
Expand All @@ -429,7 +458,8 @@ absl::StatusOr<bool> BHiveImporter::InteferenceGraphParser(
{currentRegister, newRegLiveIntervals});
}
} else {
if (resultRegLiveIntervals != info->physical_register_live_range_func.end()) {
if (resultRegLiveIntervals !=
info->physical_register_live_range_func.end()) {
// If you find the register, then you insert a new range
resultRegLiveIntervals->second.rangeList.push_back({start, end});
} else {
Expand Down Expand Up @@ -482,49 +512,30 @@ absl::StatusOr<bool> BHiveImporter::InteferenceGraphParser(

// Now we want to debug and print things inside the FunctionLiveIntervalMap
printMap(func_to_live_intervals_);
return true;
}

// // This stores the information of the whole function
// std::vector<FunctionInfo> FunctionInfoList;

// // At this time, we already processed all information in the file
// // Now we want to construct the interference graph
// // We first create an object that represents inference graph in a basic
// block struct InferenceBB {
// std::map<std::string, std::vector<std::string>> adjacencyList;
// };

// // This is a vector that stores information of a BB in each function of the
// function list std::vector<std::vector<InferenceBB>> AllFunction;

// // We still need to find what is the name of each basic block
// for (FunctionInfo functionInfo : FunctionInfoList) {

// std::vector<InferenceBB> functionAllBB;

// // Consider a basic block at a time
// for (std::pair<std::string, std::string> BBInformation :
// functionInfo.BBRangeList ) {
// // We create an object that stores adjacency list of a the inference
// graph of a single BB InferenceBB adjacencySingleBB;

// // Now for each pair of register
// // First decide whether they are in this basic block or not
// // and then decide whether they intersect ()
// for (RegLiveInterval Reg1 : functionInfo.register_live_range_func) {
// for (RegLiveInterval Reg2 : functionInfo.register_live_range_func) {
// if (intersect(Reg1, Reg2, BBInformation)) {
// adjacencySingleBB.adjacencyList[Reg1.name].push_back(Reg2.name);
// adjacencySingleBB.adjacencyList[Reg2.name].push_back(Reg1.name);
// }
// }
// }

// // Now we add the adjacency of a single BB into the functionAllBB
// functionAllBB.push_back(adjacencySingleBB);
// }

// // Add the inference graph of all BB in a function to the whole list
// AllFunction.push_back(functionAllBB);
absl::StatusOr<bool> BHiveImporter::addInterferenceGraph(
BasicBlockProto& bb_proto,
const BHiveImporter::FunctionLiveIntervalInfo& func_live_infos,
const BHiveImporter::BhiveLiveRange& bb_range) {
std::set<std::string> live_virtual_registers;
// iterate over all operands in bb_proto, add virtual registers to
// live_virtual_registers
for (const auto& instruction : bb_proto.canonicalized_instructions()) {
for (const auto& operand : instruction.input_operands()) {
if (operand.operand_case() ==
CanonicalizedOperandProto::kVirtualRegister) {
live_virtual_registers.insert(operand.virtual_register().name());
}
}
for (const auto& operand : instruction.output_operands()) {
if (operand.operand_case() ==
CanonicalizedOperandProto::kVirtualRegister) {
live_virtual_registers.insert(operand.virtual_register().name());
}
}
}
return true;
}

Expand Down
2 changes: 2 additions & 0 deletions gematria/datasets/bhive_importer.h
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,8 @@ class BHiveImporter {
// to take in machine instruction/ fucntion
absl::StatusOr<bool> InteferenceGraphParser(std::string_view file_name);

absl::StatusOr<bool> addInterferenceGraph(BasicBlockProto& bb_proto, const FunctionLiveIntervalInfo& func_live_infos, const BhiveLiveRange& bb_range);

private:
const Canonicalizer& canonicalizer_;
const llvm::TargetMachine& target_machine_;
Expand Down
14 changes: 3 additions & 11 deletions gematria/datasets/bhive_importer_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -226,24 +226,16 @@ TEST_F(BHiveImporterTest, NonStandardColumns) {
})pb")));
}

TEST_F(BHiveImporterTest, MIRDatasetBasicTest) {
EXPECT_THAT(x86_bhive_importer_->LoadMIRModule("sample_dataset/data.mir"),
IsOk());
EXPECT_THAT(x86_bhive_importer_->ParseMIRCsvLine(
kSourceName, "a,b,BB_13,2.37", 2, 3, kScaling),
IsOk());
}

TEST_F(BHiveImporterTest, MIRDatasetTest2) {
EXPECT_THAT(
x86_bhive_importer_->LoadMIRModule("sample_dataset/native_test.mir"),
IsOk());
EXPECT_THAT(x86_bhive_importer_->InteferenceGraphParser(
"sample_dataset/singleliveinfo"),
IsOk());
EXPECT_THAT(
x86_bhive_importer_->InteferenceGraphParser("sample_dataset/liveinfo"),
IsOk());
EXPECT_THAT(x86_bhive_importer_->ParseMIRCsvLine(
kSourceName, "a,b,BB_13,2.37", 2, 3, kScaling),
IsOk());
}

} // namespace
Expand Down

0 comments on commit 9867517

Please sign in to comment.