Skip to content

Commit

Permalink
Added bhive importer support for physical register interference info
Browse files Browse the repository at this point in the history
  • Loading branch information
9Tempest committed Dec 5, 2023
1 parent b61cd3c commit 29cc7a7
Show file tree
Hide file tree
Showing 4 changed files with 124 additions and 75 deletions.
171 changes: 102 additions & 69 deletions gematria/datasets/bhive_importer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -86,8 +86,17 @@ BHiveImporter::BHiveImporter(const Canonicalizer* canonicalizer)
// Append register definition line.
llvm::StringRef reg_name = MRI.getName(I);
name_to_reg_[reg_name.str()] = I;
// push itself to its own superreg2subreg_ list
superreg2subreg_[reg_name.str()].push_back(reg_name.str());
for (auto SuperReg : MRI.superregs(I)) {
if (MRI.isSubRegister(SuperReg, I)) {
llvm::StringRef super_reg_name = MRI.getName(SuperReg);
superreg2subreg_[super_reg_name.str()].push_back(reg_name.str());
}
}
}
// prettyPrintName2Reg();
prettyPrintSuperReg2SubReg();
}

absl::StatusOr<BasicBlockProto> BHiveImporter::BasicBlockProtoFromMachineCode(
Expand Down Expand Up @@ -247,7 +256,8 @@ absl::StatusOr<BasicBlockWithThroughputProto> BHiveImporter::ParseMIRCsvLine(
llvm::MachineBasicBlock* MBB = name_to_mbb_[MBB_name_ref];
std::string func_name = MBB->getParent()->getName().str();
assert(func_to_live_intervals_.find(func_name) !=
func_to_live_intervals_.end() && "Function not found in map");
func_to_live_intervals_.end() &&
"Function not found in map");
addInterferenceGraph(*block_proto_or_status,
func_to_live_intervals_[func_name],
func_to_live_intervals_[func_name]
Expand Down Expand Up @@ -519,98 +529,121 @@ static bool checkRegIntersectionsWithBBRange(
const BHiveImporter::RegLiveIntervals& reg_live_interval1,
const BHiveImporter::RegLiveIntervals& reg_live_interval2,
const BHiveImporter::BhiveLiveRange& bb_range) {
const BHiveImporter::BhiveLiveRange* range1HitsBB = nullptr;
for (auto& interval : reg_live_interval1.rangeList) {
if (areIntersected(interval, bb_range)) {
range1HitsBB = &interval;
}
const BHiveImporter::BhiveLiveRange* range1HitsBB = nullptr;
for (auto& interval : reg_live_interval1.rangeList) {
if (areIntersected(interval, bb_range)) {
range1HitsBB = &interval;
}
if (!range1HitsBB) {
return false;
}
for (auto& interval : reg_live_interval2.rangeList) {
if (areIntersected(interval, bb_range)) {
if (areIntersected(*range1HitsBB, interval)) {
return true;
}
}
if (!range1HitsBB) {
return false;
}
for (auto& interval : reg_live_interval2.rangeList) {
if (areIntersected(interval, bb_range)) {
if (areIntersected(*range1HitsBB, interval)) {
return true;
}
}
}
return false;
}

// void ToRepeatedPtrField(
// const std::vector<InstructionOperand>& operands,
// google::protobuf::RepeatedPtrField<CanonicalizedOperandProto>*
// repeated_field) {
// repeated_field->Reserve(operands.size());
// std::transform(operands.begin(), operands.end(),
// google::protobuf::RepeatedFieldBackInserter(repeated_field));
// }

void BHiveImporter::addInterferenceGraph(
BasicBlockProto& bb_proto,
BHiveImporter::FunctionLiveIntervalInfo& func_live_infos,
BHiveImporter::BhiveLiveRange& bb_range) {
std::set<std::string> live_virtual_registers;
std::set<std::string> live_physical_registers;

// helper function to update live_virtual_registers and
// live_physical_registers
auto update_live_regs = [&](const CanonicalizedOperandProto& operand) {
if (operand.operand_case() == CanonicalizedOperandProto::kVirtualRegister) {
live_virtual_registers.insert(operand.virtual_register().name());
} else if (operand.operand_case() ==
CanonicalizedOperandProto::kRegisterName) {
live_physical_registers.insert(operand.register_name());
}
};

auto add_interference = [&](CanonicalizedOperandProto& operand) {
if (operand.operand_case() == CanonicalizedOperandProto::kVirtualRegister) {
// add interference from other virtual registers to current operand
for (auto vReg : live_virtual_registers) {
if (vReg == operand.virtual_register().name()) continue;
assert(func_live_infos.virtual_register_live_range_func.find(vReg) !=
func_live_infos.virtual_register_live_range_func.end() &&
"Virtual register not found in map");
// If the live range of the two registers intersect, then add
// interference to proto
if (checkRegIntersectionsWithBBRange(
func_live_infos.virtual_register_live_range_func
[operand.virtual_register().name()],
func_live_infos.virtual_register_live_range_func[vReg],
bb_range)) {
operand.mutable_intefered_register()->Add(std::move(vReg));
}
}
// add interference from physical registers to current operand
for (auto pReg : live_physical_registers) {
auto subRegs = superreg2subreg_[pReg];
// if there's one subReg of Preg that has interference with current
// operand then add interference to proto
for (auto subReg : subRegs) {
if (func_live_infos.physical_register_live_range_func.find(subReg) ==
func_live_infos.physical_register_live_range_func.end())
continue;
// pretty print live range of subRegs
LOG("Live range of subReg: " << subReg);
for (auto& range :
func_live_infos.physical_register_live_range_func[subReg]
.rangeList) {
LOG(" " << range.first << ", " << range.second);
}
if (checkRegIntersectionsWithBBRange(
func_live_infos.virtual_register_live_range_func
[operand.virtual_register().name()],
func_live_infos.physical_register_live_range_func[subReg],
bb_range)) {
operand.mutable_intefered_register()->Add(std::move(pReg));
break;
}
}
}
}
};
// iterate over all operands in bb_proto, add virtual registers to
// live_virtual_registers
for (const auto& instruction : bb_proto.canonicalized_instructions()) {
for (const auto& operand : instruction.input_operands()) {
if (operand.operand_case() ==
CanonicalizedOperandProto::kVirtualRegister) {
live_virtual_registers.insert(operand.virtual_register().name());
}
update_live_regs(operand);
}
for (const auto& operand : instruction.implicit_input_operands()) {
update_live_regs(operand);
}
for (const auto& operand : instruction.output_operands()) {
if (operand.operand_case() ==
CanonicalizedOperandProto::kVirtualRegister) {
live_virtual_registers.insert(operand.virtual_register().name());
}
update_live_regs(operand);
}
for (const auto& operand : instruction.implicit_output_operands()) {
update_live_regs(operand);
}
}

// pretty print physical registers
LOG("Physical Registers: ");
for (auto& reg : live_physical_registers) {
LOG("Physical Register: " << reg);
}

// Iterate over all operands in bb_proto, add interference registers to each operand

// Iterate over all operands in bb_proto, add interference registers to each
// operand
for (auto& instruction : *bb_proto.mutable_canonicalized_instructions()) {
// LOG("before: " << instruction.DebugString());
for (auto& operand : *instruction.mutable_input_operands()) {
if (operand.operand_case() ==
CanonicalizedOperandProto::kVirtualRegister) {
for (auto vRegs : live_virtual_registers){
if (vRegs == operand.virtual_register().name()) continue;
assert(func_live_infos.virtual_register_live_range_func.find(vRegs) !=
func_live_infos.virtual_register_live_range_func.end() &&
"Virtual register not found in map");
// If the live range of the two registers intersect, then add
// interference to proto
if (checkRegIntersectionsWithBBRange(
func_live_infos.virtual_register_live_range_func[
operand.virtual_register().name()],
func_live_infos.virtual_register_live_range_func[vRegs],
bb_range)) {
operand.mutable_intefered_register()->Add(std::move(vRegs));
}
}
}
add_interference(operand);
}
for (auto& operand : *instruction.mutable_output_operands()) {
if (operand.operand_case() ==
CanonicalizedOperandProto::kVirtualRegister) {
for (auto vRegs : live_virtual_registers){
if (vRegs == operand.virtual_register().name()) continue;
assert(func_live_infos.virtual_register_live_range_func.find(vRegs) !=
func_live_infos.virtual_register_live_range_func.end() &&
"Virtual register not found in map");
// If the live range of the two registers intersect, then add
// interference to proto
if (checkRegIntersectionsWithBBRange(
func_live_infos.virtual_register_live_range_func[
operand.virtual_register().name()],
func_live_infos.virtual_register_live_range_func[vRegs],
bb_range)) {
operand.mutable_intefered_register()->Add(std::move(vRegs));
}
}
}
add_interference(operand);
}
// LOG("after: " << instruction.DebugString());
}
Expand Down
22 changes: 19 additions & 3 deletions gematria/datasets/bhive_importer.h
Original file line number Diff line number Diff line change
Expand Up @@ -139,8 +139,10 @@ class BHiveImporter {

// A struct that store all intervals in a function as well as ranges of BB
struct FunctionLiveIntervalInfo {
std::unordered_map<std::string, RegLiveIntervals> virtual_register_live_range_func;
std::unordered_map<std::string, RegLiveIntervals> physical_register_live_range_func;
std::unordered_map<std::string, RegLiveIntervals>
virtual_register_live_range_func;
std::unordered_map<std::string, RegLiveIntervals>
physical_register_live_range_func;
std::unordered_map<std::string, BhiveLiveRange> BBRangeList;
};

Expand All @@ -150,6 +152,17 @@ class BHiveImporter {
}
}

// pretty print superreg2subreg_
void prettyPrintSuperReg2SubReg() {
LOG("SuperReg2SubReg: ");
for (auto& [superreg, subreg] : superreg2subreg_) {
LOG(superreg << ": ");
for (auto& sub : subreg) {
LOG("\t" << sub);
}
}
}

// Now we are able to obtain the live range for each register
// We want to for each pair of regsiter, find out if their live range overlap
// Edge case 1: one live range may have multiple live ranges,
Expand All @@ -159,7 +172,9 @@ class BHiveImporter {
// to take in machine instruction/ fucntion
absl::StatusOr<bool> InteferenceGraphParser(std::string_view file_name);

void addInterferenceGraph(BasicBlockProto& bb_proto, FunctionLiveIntervalInfo& func_live_infos, BhiveLiveRange& bb_range);
void addInterferenceGraph(BasicBlockProto& bb_proto,
FunctionLiveIntervalInfo& func_live_infos,
BhiveLiveRange& bb_range);

private:
const Canonicalizer& canonicalizer_;
Expand All @@ -171,6 +186,7 @@ class BHiveImporter {
std::unordered_map<std::string, FunctionLiveIntervalInfo>
func_to_live_intervals_;
std::unordered_map<std::string, llvm::MCPhysReg> name_to_reg_;
std::unordered_map<std::string, std::vector<std::string>> superreg2subreg_;
llvm::LLVMContext llvm_context_;
std::unique_ptr<llvm::Module> mir_module_;
llvm::MachineModuleInfo MMI_;
Expand Down
4 changes: 2 additions & 2 deletions gematria/datasets/bhive_importer_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -234,8 +234,8 @@ TEST_F(BHiveImporterTest, MIRDatasetTest2) {
x86_bhive_importer_->InteferenceGraphParser("sample_dataset/liveinfo"),
IsOk());
EXPECT_THAT(x86_bhive_importer_->ParseMIRCsvLine(
kSourceName, "a,b,BB_21,2.37", 2, 3, kScaling),
IsOk());
kSourceName, "a,b,BB_21,2.37", 2, 3, kScaling),
IsOk());
}

} // namespace
Expand Down
2 changes: 1 addition & 1 deletion gematria/llvm/canonicalizer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -477,7 +477,7 @@ void X86Canonicalizer::AddOperand(const llvm::MachineInstr& mi, int operand_inde
GetRegisterNameOrEmpty(
mi.getOperand(operand_index + llvm::X86::AddrBaseReg), base_register, tmp_size);
} else if (mi.getOperand(operand_index + llvm::X86::AddrBaseReg).isFI()){
base_register = "rbp";
base_register = "RBP";
} else {
assert(false && "unsupported base register type");
LOG(mi);
Expand Down

0 comments on commit 29cc7a7

Please sign in to comment.