Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added bhive importer support for physical register interference info #8

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion gematria/basic_block/basic_block.cc
Original file line number Diff line number Diff line change
Expand Up @@ -108,11 +108,12 @@ bool InstructionOperand::operator==(const InstructionOperand& other) const {
}

InstructionOperand InstructionOperand::VirtualRegister(
const std::string register_name, size_t size) {
const std::string register_name, size_t size, const std::vector<std::string>& interfered_registers) {
InstructionOperand result;
result.type_ = OperandType::kVirtualRegister;
result.register_name_ = std::move(register_name);
result.size_ = size;
result.interfered_registers_ = std::move(interfered_registers);
return result;
}

Expand Down
8 changes: 7 additions & 1 deletion gematria/basic_block/basic_block.h
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ class InstructionOperand {

// The operands must be created through one of the factory functions.
static InstructionOperand VirtualRegister(std::string register_name,
size_t size);
size_t size, const std::vector<std::string>& interfered_registers);
static InstructionOperand Register(std::string register_name);
static InstructionOperand ImmediateValue(uint64_t immediate_value);
static InstructionOperand FpImmediateValue(double fp_immediate_value);
Expand All @@ -169,6 +169,11 @@ class InstructionOperand {
// Returns the list of tokens representing this instruction.
std::vector<std::string> AsTokenList() const;

std::vector<std::string> getInterferedRegisters() const {
assert(type_ == OperandType::kVirtualRegister);
return interfered_registers_;
}

// Returns a human-readable representation of the operand.
//
// This method implements the __str__() and __repr__() methods in the Python
Expand Down Expand Up @@ -225,6 +230,7 @@ class InstructionOperand {
double fp_immediate_value_ = 0.0;
AddressTuple address_;
int alias_group_id_ = 0;
std::vector<std::string> interfered_registers_;
};

std::ostream& operator<<(std::ostream& os, const InstructionOperand& operand);
Expand Down
16 changes: 13 additions & 3 deletions gematria/basic_block/basic_block_protos.cc
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,13 @@

namespace gematria {

namespace {
std::vector<std::string> ToVector(
const google::protobuf::RepeatedPtrField<std::string>& protos) {
return std::vector<std::string>(protos.begin(), protos.end());
}
}

AddressTuple AddressTupleFromProto(
const CanonicalizedOperandProto::AddressTuple& proto) {
return AddressTuple(
Expand Down Expand Up @@ -64,8 +71,12 @@ InstructionOperand InstructionOperandFromProto(
return InstructionOperand::MemoryLocation(
proto.memory().alias_group_id());
case CanonicalizedOperandProto::kVirtualRegister:
return InstructionOperand::VirtualRegister(
proto.virtual_register().name(), proto.virtual_register().size());
{
std::vector<std::string> interfered_registers = ToVector(proto.intefered_register());
return InstructionOperand::VirtualRegister(
proto.virtual_register().name(), proto.virtual_register().size(), interfered_registers);
}

}
}

Expand Down Expand Up @@ -102,7 +113,6 @@ CanonicalizedOperandProto ProtoFromInstructionOperand(
}

namespace {

std::vector<InstructionOperand> ToVector(
const google::protobuf::RepeatedPtrField<CanonicalizedOperandProto>&
protos) {
Expand Down
8 changes: 6 additions & 2 deletions gematria/basic_block/basic_block_protos_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -239,7 +239,11 @@ TEST(BasicBlockFromProtoTest, VRegInstructions) {
canonicalized_instructions {
mnemonic: "CMP64RI32"
llvm_mnemonic: "CMP64ri32"
input_operands { virtual_register { name: "%60" size: 64 } }
input_operands {
virtual_register { name: "%60" size: 64 }
intefered_register: "%61"
intefered_register: "%62"
}
input_operands { immediate_value: 0 }
implicit_output_operands { register_name: "EFLAGS" }
}
Expand All @@ -250,7 +254,7 @@ TEST(BasicBlockFromProtoTest, VRegInstructions) {
/* mnemonic = */ "CMP64RI32", /* llvm_mnemonic = */ "CMP64ri32",
/* prefixes = */ {},
/* input_operands = */
{InstructionOperand::VirtualRegister("%60", 64),
{InstructionOperand::VirtualRegister("%60", 64, {"%61, %62"}),
InstructionOperand::ImmediateValue(0)},
/* implicit_input_operands = */ {},
/* output_operands = */ {},
Expand Down
2 changes: 1 addition & 1 deletion gematria/basic_block/python/basic_block.cc
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,7 @@ PYBIND11_MODULE(basic_block, m) {
py::arg("fp_immediate_value"))
.def_static("from_virtual_register",
&InstructionOperand::VirtualRegister,
py::arg("register_name"), py::arg("size") = 0)
py::arg("register_name"), py::arg("size"), py::arg("interfered_registers"))
.def_static<InstructionOperand (*)(
std::string /* base_register */, int64_t /* displacement */,
std::string /* index_register */, int /* scaling */,
Expand Down
171 changes: 102 additions & 69 deletions gematria/datasets/bhive_importer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -86,8 +86,17 @@ BHiveImporter::BHiveImporter(const Canonicalizer* canonicalizer)
// Append register definition line.
llvm::StringRef reg_name = MRI.getName(I);
name_to_reg_[reg_name.str()] = I;
// push itself to its own superreg2subreg_ list
superreg2subreg_[reg_name.str()].push_back(reg_name.str());
for (auto SuperReg : MRI.superregs(I)) {
if (MRI.isSubRegister(SuperReg, I)) {
llvm::StringRef super_reg_name = MRI.getName(SuperReg);
superreg2subreg_[super_reg_name.str()].push_back(reg_name.str());
}
}
}
// prettyPrintName2Reg();
prettyPrintSuperReg2SubReg();
}

absl::StatusOr<BasicBlockProto> BHiveImporter::BasicBlockProtoFromMachineCode(
Expand Down Expand Up @@ -247,7 +256,8 @@ absl::StatusOr<BasicBlockWithThroughputProto> BHiveImporter::ParseMIRCsvLine(
llvm::MachineBasicBlock* MBB = name_to_mbb_[MBB_name_ref];
std::string func_name = MBB->getParent()->getName().str();
assert(func_to_live_intervals_.find(func_name) !=
func_to_live_intervals_.end() && "Function not found in map");
func_to_live_intervals_.end() &&
"Function not found in map");
addInterferenceGraph(*block_proto_or_status,
func_to_live_intervals_[func_name],
func_to_live_intervals_[func_name]
Expand Down Expand Up @@ -519,98 +529,121 @@ static bool checkRegIntersectionsWithBBRange(
const BHiveImporter::RegLiveIntervals& reg_live_interval1,
const BHiveImporter::RegLiveIntervals& reg_live_interval2,
const BHiveImporter::BhiveLiveRange& bb_range) {
const BHiveImporter::BhiveLiveRange* range1HitsBB = nullptr;
for (auto& interval : reg_live_interval1.rangeList) {
if (areIntersected(interval, bb_range)) {
range1HitsBB = &interval;
}
const BHiveImporter::BhiveLiveRange* range1HitsBB = nullptr;
for (auto& interval : reg_live_interval1.rangeList) {
if (areIntersected(interval, bb_range)) {
range1HitsBB = &interval;
}
if (!range1HitsBB) {
return false;
}
for (auto& interval : reg_live_interval2.rangeList) {
if (areIntersected(interval, bb_range)) {
if (areIntersected(*range1HitsBB, interval)) {
return true;
}
}
if (!range1HitsBB) {
return false;
}
for (auto& interval : reg_live_interval2.rangeList) {
if (areIntersected(interval, bb_range)) {
if (areIntersected(*range1HitsBB, interval)) {
return true;
}
}
}
return false;
}

// void ToRepeatedPtrField(
// const std::vector<InstructionOperand>& operands,
// google::protobuf::RepeatedPtrField<CanonicalizedOperandProto>*
// repeated_field) {
// repeated_field->Reserve(operands.size());
// std::transform(operands.begin(), operands.end(),
// google::protobuf::RepeatedFieldBackInserter(repeated_field));
// }

void BHiveImporter::addInterferenceGraph(
BasicBlockProto& bb_proto,
BHiveImporter::FunctionLiveIntervalInfo& func_live_infos,
BHiveImporter::BhiveLiveRange& bb_range) {
std::set<std::string> live_virtual_registers;
std::set<std::string> live_physical_registers;

// helper function to update live_virtual_registers and
// live_physical_registers
auto update_live_regs = [&](const CanonicalizedOperandProto& operand) {
if (operand.operand_case() == CanonicalizedOperandProto::kVirtualRegister) {
live_virtual_registers.insert(operand.virtual_register().name());
} else if (operand.operand_case() ==
CanonicalizedOperandProto::kRegisterName) {
live_physical_registers.insert(operand.register_name());
}
};

auto add_interference = [&](CanonicalizedOperandProto& operand) {
if (operand.operand_case() == CanonicalizedOperandProto::kVirtualRegister) {
// add interference from other virtual registers to current operand
for (auto vReg : live_virtual_registers) {
if (vReg == operand.virtual_register().name()) continue;
assert(func_live_infos.virtual_register_live_range_func.find(vReg) !=
func_live_infos.virtual_register_live_range_func.end() &&
"Virtual register not found in map");
// If the live range of the two registers intersect, then add
// interference to proto
if (checkRegIntersectionsWithBBRange(
func_live_infos.virtual_register_live_range_func
[operand.virtual_register().name()],
func_live_infos.virtual_register_live_range_func[vReg],
bb_range)) {
operand.mutable_intefered_register()->Add(std::move(vReg));
}
}
// add interference from physical registers to current operand
for (auto pReg : live_physical_registers) {
auto subRegs = superreg2subreg_[pReg];
// if there's one subReg of Preg that has interference with current
// operand then add interference to proto
for (auto subReg : subRegs) {
if (func_live_infos.physical_register_live_range_func.find(subReg) ==
func_live_infos.physical_register_live_range_func.end())
continue;
// pretty print live range of subRegs
LOG("Live range of subReg: " << subReg);
for (auto& range :
func_live_infos.physical_register_live_range_func[subReg]
.rangeList) {
LOG(" " << range.first << ", " << range.second);
}
if (checkRegIntersectionsWithBBRange(
func_live_infos.virtual_register_live_range_func
[operand.virtual_register().name()],
func_live_infos.physical_register_live_range_func[subReg],
bb_range)) {
operand.mutable_intefered_register()->Add(std::move(pReg));
break;
}
}
}
}
};
// iterate over all operands in bb_proto, add virtual registers to
// live_virtual_registers
for (const auto& instruction : bb_proto.canonicalized_instructions()) {
for (const auto& operand : instruction.input_operands()) {
if (operand.operand_case() ==
CanonicalizedOperandProto::kVirtualRegister) {
live_virtual_registers.insert(operand.virtual_register().name());
}
update_live_regs(operand);
}
for (const auto& operand : instruction.implicit_input_operands()) {
update_live_regs(operand);
}
for (const auto& operand : instruction.output_operands()) {
if (operand.operand_case() ==
CanonicalizedOperandProto::kVirtualRegister) {
live_virtual_registers.insert(operand.virtual_register().name());
}
update_live_regs(operand);
}
for (const auto& operand : instruction.implicit_output_operands()) {
update_live_regs(operand);
}
}

// pretty print physical registers
LOG("Physical Registers: ");
for (auto& reg : live_physical_registers) {
LOG("Physical Register: " << reg);
}

// Iterate over all operands in bb_proto, add interference registers to each operand

// Iterate over all operands in bb_proto, add interference registers to each
// operand
for (auto& instruction : *bb_proto.mutable_canonicalized_instructions()) {
// LOG("before: " << instruction.DebugString());
for (auto& operand : *instruction.mutable_input_operands()) {
if (operand.operand_case() ==
CanonicalizedOperandProto::kVirtualRegister) {
for (auto vRegs : live_virtual_registers){
if (vRegs == operand.virtual_register().name()) continue;
assert(func_live_infos.virtual_register_live_range_func.find(vRegs) !=
func_live_infos.virtual_register_live_range_func.end() &&
"Virtual register not found in map");
// If the live range of the two registers intersect, then add
// interference to proto
if (checkRegIntersectionsWithBBRange(
func_live_infos.virtual_register_live_range_func[
operand.virtual_register().name()],
func_live_infos.virtual_register_live_range_func[vRegs],
bb_range)) {
operand.mutable_intefered_register()->Add(std::move(vRegs));
}
}
}
add_interference(operand);
}
for (auto& operand : *instruction.mutable_output_operands()) {
if (operand.operand_case() ==
CanonicalizedOperandProto::kVirtualRegister) {
for (auto vRegs : live_virtual_registers){
if (vRegs == operand.virtual_register().name()) continue;
assert(func_live_infos.virtual_register_live_range_func.find(vRegs) !=
func_live_infos.virtual_register_live_range_func.end() &&
"Virtual register not found in map");
// If the live range of the two registers intersect, then add
// interference to proto
if (checkRegIntersectionsWithBBRange(
func_live_infos.virtual_register_live_range_func[
operand.virtual_register().name()],
func_live_infos.virtual_register_live_range_func[vRegs],
bb_range)) {
operand.mutable_intefered_register()->Add(std::move(vRegs));
}
}
}
add_interference(operand);
}
// LOG("after: " << instruction.DebugString());
}
Expand Down
22 changes: 19 additions & 3 deletions gematria/datasets/bhive_importer.h
Original file line number Diff line number Diff line change
Expand Up @@ -139,8 +139,10 @@ class BHiveImporter {

// A struct that store all intervals in a function as well as ranges of BB
struct FunctionLiveIntervalInfo {
std::unordered_map<std::string, RegLiveIntervals> virtual_register_live_range_func;
std::unordered_map<std::string, RegLiveIntervals> physical_register_live_range_func;
std::unordered_map<std::string, RegLiveIntervals>
virtual_register_live_range_func;
std::unordered_map<std::string, RegLiveIntervals>
physical_register_live_range_func;
std::unordered_map<std::string, BhiveLiveRange> BBRangeList;
};

Expand All @@ -150,6 +152,17 @@ class BHiveImporter {
}
}

// pretty print superreg2subreg_
void prettyPrintSuperReg2SubReg() {
LOG("SuperReg2SubReg: ");
for (auto& [superreg, subreg] : superreg2subreg_) {
LOG(superreg << ": ");
for (auto& sub : subreg) {
LOG("\t" << sub);
}
}
}

// Now we are able to obtain the live range for each register
// We want to for each pair of regsiter, find out if their live range overlap
// Edge case 1: one live range may have multiple live ranges,
Expand All @@ -159,7 +172,9 @@ class BHiveImporter {
// to take in machine instruction/ fucntion
absl::StatusOr<bool> InteferenceGraphParser(std::string_view file_name);

void addInterferenceGraph(BasicBlockProto& bb_proto, FunctionLiveIntervalInfo& func_live_infos, BhiveLiveRange& bb_range);
void addInterferenceGraph(BasicBlockProto& bb_proto,
FunctionLiveIntervalInfo& func_live_infos,
BhiveLiveRange& bb_range);

private:
const Canonicalizer& canonicalizer_;
Expand All @@ -171,6 +186,7 @@ class BHiveImporter {
std::unordered_map<std::string, FunctionLiveIntervalInfo>
func_to_live_intervals_;
std::unordered_map<std::string, llvm::MCPhysReg> name_to_reg_;
std::unordered_map<std::string, std::vector<std::string>> superreg2subreg_;
llvm::LLVMContext llvm_context_;
std::unique_ptr<llvm::Module> mir_module_;
llvm::MachineModuleInfo MMI_;
Expand Down
4 changes: 2 additions & 2 deletions gematria/datasets/bhive_importer_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -234,8 +234,8 @@ TEST_F(BHiveImporterTest, MIRDatasetTest2) {
x86_bhive_importer_->InteferenceGraphParser("sample_dataset/liveinfo"),
IsOk());
EXPECT_THAT(x86_bhive_importer_->ParseMIRCsvLine(
kSourceName, "a,b,BB_21,2.37", 2, 3, kScaling),
IsOk());
kSourceName, "a,b,BB_21,2.37", 2, 3, kScaling),
IsOk());
}

} // namespace
Expand Down
Loading
Loading