Skip to content

Commit

Permalink
Merge pull request #8 from GranLte/lukezhuz/phys_interference_info
Browse files Browse the repository at this point in the history
Added bhive importer support for physical register interference info
  • Loading branch information
9Tempest authored Dec 5, 2023
2 parents b61cd3c + 8f13408 commit 981cdd3
Show file tree
Hide file tree
Showing 9 changed files with 154 additions and 88 deletions.
3 changes: 2 additions & 1 deletion gematria/basic_block/basic_block.cc
Original file line number Diff line number Diff line change
Expand Up @@ -108,11 +108,12 @@ bool InstructionOperand::operator==(const InstructionOperand& other) const {
}

InstructionOperand InstructionOperand::VirtualRegister(
const std::string register_name, size_t size) {
const std::string register_name, size_t size, const std::vector<std::string>& interfered_registers) {
InstructionOperand result;
result.type_ = OperandType::kVirtualRegister;
result.register_name_ = std::move(register_name);
result.size_ = size;
result.interfered_registers_ = std::move(interfered_registers);
return result;
}

Expand Down
8 changes: 7 additions & 1 deletion gematria/basic_block/basic_block.h
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ class InstructionOperand {

// The operands must be created through one of the factory functions.
static InstructionOperand VirtualRegister(std::string register_name,
size_t size);
size_t size, const std::vector<std::string>& interfered_registers);
static InstructionOperand Register(std::string register_name);
static InstructionOperand ImmediateValue(uint64_t immediate_value);
static InstructionOperand FpImmediateValue(double fp_immediate_value);
Expand All @@ -169,6 +169,11 @@ class InstructionOperand {
// Returns the list of tokens representing this instruction.
std::vector<std::string> AsTokenList() const;

std::vector<std::string> getInterferedRegisters() const {
assert(type_ == OperandType::kVirtualRegister);
return interfered_registers_;
}

// Returns a human-readable representation of the operand.
//
// This method implements the __str__() and __repr__() methods in the Python
Expand Down Expand Up @@ -225,6 +230,7 @@ class InstructionOperand {
double fp_immediate_value_ = 0.0;
AddressTuple address_;
int alias_group_id_ = 0;
std::vector<std::string> interfered_registers_;
};

std::ostream& operator<<(std::ostream& os, const InstructionOperand& operand);
Expand Down
16 changes: 13 additions & 3 deletions gematria/basic_block/basic_block_protos.cc
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,13 @@

namespace gematria {

namespace {
std::vector<std::string> ToVector(
const google::protobuf::RepeatedPtrField<std::string>& protos) {
return std::vector<std::string>(protos.begin(), protos.end());
}
}

AddressTuple AddressTupleFromProto(
const CanonicalizedOperandProto::AddressTuple& proto) {
return AddressTuple(
Expand Down Expand Up @@ -64,8 +71,12 @@ InstructionOperand InstructionOperandFromProto(
return InstructionOperand::MemoryLocation(
proto.memory().alias_group_id());
case CanonicalizedOperandProto::kVirtualRegister:
return InstructionOperand::VirtualRegister(
proto.virtual_register().name(), proto.virtual_register().size());
{
std::vector<std::string> interfered_registers = ToVector(proto.intefered_register());
return InstructionOperand::VirtualRegister(
proto.virtual_register().name(), proto.virtual_register().size(), interfered_registers);
}

}
}

Expand Down Expand Up @@ -102,7 +113,6 @@ CanonicalizedOperandProto ProtoFromInstructionOperand(
}

namespace {

std::vector<InstructionOperand> ToVector(
const google::protobuf::RepeatedPtrField<CanonicalizedOperandProto>&
protos) {
Expand Down
8 changes: 6 additions & 2 deletions gematria/basic_block/basic_block_protos_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -239,7 +239,11 @@ TEST(BasicBlockFromProtoTest, VRegInstructions) {
canonicalized_instructions {
mnemonic: "CMP64RI32"
llvm_mnemonic: "CMP64ri32"
input_operands { virtual_register { name: "%60" size: 64 } }
input_operands {
virtual_register { name: "%60" size: 64 }
intefered_register: "%61"
intefered_register: "%62"
}
input_operands { immediate_value: 0 }
implicit_output_operands { register_name: "EFLAGS" }
}
Expand All @@ -250,7 +254,7 @@ TEST(BasicBlockFromProtoTest, VRegInstructions) {
/* mnemonic = */ "CMP64RI32", /* llvm_mnemonic = */ "CMP64ri32",
/* prefixes = */ {},
/* input_operands = */
{InstructionOperand::VirtualRegister("%60", 64),
{InstructionOperand::VirtualRegister("%60", 64, {"%61, %62"}),
InstructionOperand::ImmediateValue(0)},
/* implicit_input_operands = */ {},
/* output_operands = */ {},
Expand Down
2 changes: 1 addition & 1 deletion gematria/basic_block/python/basic_block.cc
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,7 @@ PYBIND11_MODULE(basic_block, m) {
py::arg("fp_immediate_value"))
.def_static("from_virtual_register",
&InstructionOperand::VirtualRegister,
py::arg("register_name"), py::arg("size") = 0)
py::arg("register_name"), py::arg("size"), py::arg("interfered_registers"))
.def_static<InstructionOperand (*)(
std::string /* base_register */, int64_t /* displacement */,
std::string /* index_register */, int /* scaling */,
Expand Down
171 changes: 102 additions & 69 deletions gematria/datasets/bhive_importer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -86,8 +86,17 @@ BHiveImporter::BHiveImporter(const Canonicalizer* canonicalizer)
// Append register definition line.
llvm::StringRef reg_name = MRI.getName(I);
name_to_reg_[reg_name.str()] = I;
// push itself to its own superreg2subreg_ list
superreg2subreg_[reg_name.str()].push_back(reg_name.str());
for (auto SuperReg : MRI.superregs(I)) {
if (MRI.isSubRegister(SuperReg, I)) {
llvm::StringRef super_reg_name = MRI.getName(SuperReg);
superreg2subreg_[super_reg_name.str()].push_back(reg_name.str());
}
}
}
// prettyPrintName2Reg();
prettyPrintSuperReg2SubReg();
}

absl::StatusOr<BasicBlockProto> BHiveImporter::BasicBlockProtoFromMachineCode(
Expand Down Expand Up @@ -247,7 +256,8 @@ absl::StatusOr<BasicBlockWithThroughputProto> BHiveImporter::ParseMIRCsvLine(
llvm::MachineBasicBlock* MBB = name_to_mbb_[MBB_name_ref];
std::string func_name = MBB->getParent()->getName().str();
assert(func_to_live_intervals_.find(func_name) !=
func_to_live_intervals_.end() && "Function not found in map");
func_to_live_intervals_.end() &&
"Function not found in map");
addInterferenceGraph(*block_proto_or_status,
func_to_live_intervals_[func_name],
func_to_live_intervals_[func_name]
Expand Down Expand Up @@ -519,98 +529,121 @@ static bool checkRegIntersectionsWithBBRange(
const BHiveImporter::RegLiveIntervals& reg_live_interval1,
const BHiveImporter::RegLiveIntervals& reg_live_interval2,
const BHiveImporter::BhiveLiveRange& bb_range) {
const BHiveImporter::BhiveLiveRange* range1HitsBB = nullptr;
for (auto& interval : reg_live_interval1.rangeList) {
if (areIntersected(interval, bb_range)) {
range1HitsBB = &interval;
}
const BHiveImporter::BhiveLiveRange* range1HitsBB = nullptr;
for (auto& interval : reg_live_interval1.rangeList) {
if (areIntersected(interval, bb_range)) {
range1HitsBB = &interval;
}
if (!range1HitsBB) {
return false;
}
for (auto& interval : reg_live_interval2.rangeList) {
if (areIntersected(interval, bb_range)) {
if (areIntersected(*range1HitsBB, interval)) {
return true;
}
}
if (!range1HitsBB) {
return false;
}
for (auto& interval : reg_live_interval2.rangeList) {
if (areIntersected(interval, bb_range)) {
if (areIntersected(*range1HitsBB, interval)) {
return true;
}
}
}
return false;
}

// void ToRepeatedPtrField(
// const std::vector<InstructionOperand>& operands,
// google::protobuf::RepeatedPtrField<CanonicalizedOperandProto>*
// repeated_field) {
// repeated_field->Reserve(operands.size());
// std::transform(operands.begin(), operands.end(),
// google::protobuf::RepeatedFieldBackInserter(repeated_field));
// }

void BHiveImporter::addInterferenceGraph(
BasicBlockProto& bb_proto,
BHiveImporter::FunctionLiveIntervalInfo& func_live_infos,
BHiveImporter::BhiveLiveRange& bb_range) {
std::set<std::string> live_virtual_registers;
std::set<std::string> live_physical_registers;

// helper function to update live_virtual_registers and
// live_physical_registers
auto update_live_regs = [&](const CanonicalizedOperandProto& operand) {
if (operand.operand_case() == CanonicalizedOperandProto::kVirtualRegister) {
live_virtual_registers.insert(operand.virtual_register().name());
} else if (operand.operand_case() ==
CanonicalizedOperandProto::kRegisterName) {
live_physical_registers.insert(operand.register_name());
}
};

auto add_interference = [&](CanonicalizedOperandProto& operand) {
if (operand.operand_case() == CanonicalizedOperandProto::kVirtualRegister) {
// add interference from other virtual registers to current operand
for (auto vReg : live_virtual_registers) {
if (vReg == operand.virtual_register().name()) continue;
assert(func_live_infos.virtual_register_live_range_func.find(vReg) !=
func_live_infos.virtual_register_live_range_func.end() &&
"Virtual register not found in map");
// If the live range of the two registers intersect, then add
// interference to proto
if (checkRegIntersectionsWithBBRange(
func_live_infos.virtual_register_live_range_func
[operand.virtual_register().name()],
func_live_infos.virtual_register_live_range_func[vReg],
bb_range)) {
operand.mutable_intefered_register()->Add(std::move(vReg));
}
}
// add interference from physical registers to current operand
for (auto pReg : live_physical_registers) {
auto subRegs = superreg2subreg_[pReg];
// if there's one subReg of Preg that has interference with current
// operand then add interference to proto
for (auto subReg : subRegs) {
if (func_live_infos.physical_register_live_range_func.find(subReg) ==
func_live_infos.physical_register_live_range_func.end())
continue;
// pretty print live range of subRegs
LOG("Live range of subReg: " << subReg);
for (auto& range :
func_live_infos.physical_register_live_range_func[subReg]
.rangeList) {
LOG(" " << range.first << ", " << range.second);
}
if (checkRegIntersectionsWithBBRange(
func_live_infos.virtual_register_live_range_func
[operand.virtual_register().name()],
func_live_infos.physical_register_live_range_func[subReg],
bb_range)) {
operand.mutable_intefered_register()->Add(std::move(pReg));
break;
}
}
}
}
};
// iterate over all operands in bb_proto, add virtual registers to
// live_virtual_registers
for (const auto& instruction : bb_proto.canonicalized_instructions()) {
for (const auto& operand : instruction.input_operands()) {
if (operand.operand_case() ==
CanonicalizedOperandProto::kVirtualRegister) {
live_virtual_registers.insert(operand.virtual_register().name());
}
update_live_regs(operand);
}
for (const auto& operand : instruction.implicit_input_operands()) {
update_live_regs(operand);
}
for (const auto& operand : instruction.output_operands()) {
if (operand.operand_case() ==
CanonicalizedOperandProto::kVirtualRegister) {
live_virtual_registers.insert(operand.virtual_register().name());
}
update_live_regs(operand);
}
for (const auto& operand : instruction.implicit_output_operands()) {
update_live_regs(operand);
}
}

// pretty print physical registers
LOG("Physical Registers: ");
for (auto& reg : live_physical_registers) {
LOG("Physical Register: " << reg);
}

// Iterate over all operands in bb_proto, add interference registers to each operand

// Iterate over all operands in bb_proto, add interference registers to each
// operand
for (auto& instruction : *bb_proto.mutable_canonicalized_instructions()) {
// LOG("before: " << instruction.DebugString());
for (auto& operand : *instruction.mutable_input_operands()) {
if (operand.operand_case() ==
CanonicalizedOperandProto::kVirtualRegister) {
for (auto vRegs : live_virtual_registers){
if (vRegs == operand.virtual_register().name()) continue;
assert(func_live_infos.virtual_register_live_range_func.find(vRegs) !=
func_live_infos.virtual_register_live_range_func.end() &&
"Virtual register not found in map");
// If the live range of the two registers intersect, then add
// interference to proto
if (checkRegIntersectionsWithBBRange(
func_live_infos.virtual_register_live_range_func[
operand.virtual_register().name()],
func_live_infos.virtual_register_live_range_func[vRegs],
bb_range)) {
operand.mutable_intefered_register()->Add(std::move(vRegs));
}
}
}
add_interference(operand);
}
for (auto& operand : *instruction.mutable_output_operands()) {
if (operand.operand_case() ==
CanonicalizedOperandProto::kVirtualRegister) {
for (auto vRegs : live_virtual_registers){
if (vRegs == operand.virtual_register().name()) continue;
assert(func_live_infos.virtual_register_live_range_func.find(vRegs) !=
func_live_infos.virtual_register_live_range_func.end() &&
"Virtual register not found in map");
// If the live range of the two registers intersect, then add
// interference to proto
if (checkRegIntersectionsWithBBRange(
func_live_infos.virtual_register_live_range_func[
operand.virtual_register().name()],
func_live_infos.virtual_register_live_range_func[vRegs],
bb_range)) {
operand.mutable_intefered_register()->Add(std::move(vRegs));
}
}
}
add_interference(operand);
}
// LOG("after: " << instruction.DebugString());
}
Expand Down
22 changes: 19 additions & 3 deletions gematria/datasets/bhive_importer.h
Original file line number Diff line number Diff line change
Expand Up @@ -139,8 +139,10 @@ class BHiveImporter {

// A struct that store all intervals in a function as well as ranges of BB
struct FunctionLiveIntervalInfo {
std::unordered_map<std::string, RegLiveIntervals> virtual_register_live_range_func;
std::unordered_map<std::string, RegLiveIntervals> physical_register_live_range_func;
std::unordered_map<std::string, RegLiveIntervals>
virtual_register_live_range_func;
std::unordered_map<std::string, RegLiveIntervals>
physical_register_live_range_func;
std::unordered_map<std::string, BhiveLiveRange> BBRangeList;
};

Expand All @@ -150,6 +152,17 @@ class BHiveImporter {
}
}

// pretty print superreg2subreg_
void prettyPrintSuperReg2SubReg() {
LOG("SuperReg2SubReg: ");
for (auto& [superreg, subreg] : superreg2subreg_) {
LOG(superreg << ": ");
for (auto& sub : subreg) {
LOG("\t" << sub);
}
}
}

// Now we are able to obtain the live range for each register
// We want to for each pair of regsiter, find out if their live range overlap
// Edge case 1: one live range may have multiple live ranges,
Expand All @@ -159,7 +172,9 @@ class BHiveImporter {
// to take in machine instruction/ fucntion
absl::StatusOr<bool> InteferenceGraphParser(std::string_view file_name);

void addInterferenceGraph(BasicBlockProto& bb_proto, FunctionLiveIntervalInfo& func_live_infos, BhiveLiveRange& bb_range);
void addInterferenceGraph(BasicBlockProto& bb_proto,
FunctionLiveIntervalInfo& func_live_infos,
BhiveLiveRange& bb_range);

private:
const Canonicalizer& canonicalizer_;
Expand All @@ -171,6 +186,7 @@ class BHiveImporter {
std::unordered_map<std::string, FunctionLiveIntervalInfo>
func_to_live_intervals_;
std::unordered_map<std::string, llvm::MCPhysReg> name_to_reg_;
std::unordered_map<std::string, std::vector<std::string>> superreg2subreg_;
llvm::LLVMContext llvm_context_;
std::unique_ptr<llvm::Module> mir_module_;
llvm::MachineModuleInfo MMI_;
Expand Down
4 changes: 2 additions & 2 deletions gematria/datasets/bhive_importer_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -234,8 +234,8 @@ TEST_F(BHiveImporterTest, MIRDatasetTest2) {
x86_bhive_importer_->InteferenceGraphParser("sample_dataset/liveinfo"),
IsOk());
EXPECT_THAT(x86_bhive_importer_->ParseMIRCsvLine(
kSourceName, "a,b,BB_21,2.37", 2, 3, kScaling),
IsOk());
kSourceName, "a,b,BB_21,2.37", 2, 3, kScaling),
IsOk());
}

} // namespace
Expand Down
Loading

0 comments on commit 981cdd3

Please sign in to comment.