Skip to content

Commit

Permalink
Added interfered registers in AddressTuple
Browse files Browse the repository at this point in the history
  • Loading branch information
9Tempest committed Dec 7, 2023
1 parent 0d0e79b commit eff6be5
Show file tree
Hide file tree
Showing 7 changed files with 132 additions and 56 deletions.
11 changes: 7 additions & 4 deletions gematria/basic_block/basic_block.cc
Original file line number Diff line number Diff line change
Expand Up @@ -170,17 +170,20 @@ InstructionOperand InstructionOperand::Address(std::string base_register,
int64_t displacement,
std::string index_register,
int scaling,
std::string segment_register) {
std::string segment_register,
int base_register_size,
int index_register_size,
int segment_register_size) {
InstructionOperand result;
result.type_ = OperandType::kAddress;
result.address_.base_register = std::move(base_register);
result.address_.index_register = std::move(index_register);
result.address_.displacement = displacement;
result.address_.scaling = scaling;
result.address_.segment_register = segment_register;
result.address_.base_register_size = 64;
result.address_.index_register_size = 64;
result.address_.segment_register_size = 64;
result.address_.base_register_size = base_register_size;
result.address_.index_register_size = index_register_size;
result.address_.segment_register_size = segment_register_size;
result.address_.base_register_intefered_register = {};
result.address_.index_register_intefered_register = {};
result.address_.segment_register_intefered_register = {};
Expand Down
11 changes: 7 additions & 4 deletions gematria/basic_block/basic_block.h
Original file line number Diff line number Diff line change
Expand Up @@ -138,11 +138,11 @@ struct AddressTuple {
// for the instruction is used.
std::string segment_register;
// The size of the base register. Used only when base_register is non-empty.
int base_register_size;
size_t base_register_size;
// The size of the index register. Used only when index_register is non-empty.
int index_register_size;
size_t index_register_size;
// The size of the segment register. Used only when segment_register is
int segment_register_size;
size_t segment_register_size;

// The name of the index register of the address. When empty, index register
std::vector<std::string> base_register_intefered_register;
Expand Down Expand Up @@ -179,7 +179,10 @@ class InstructionOperand {
static InstructionOperand Address(std::string base_register,
int64_t displacement,
std::string index_register, int scaling,
std::string segment_register);
std::string segment_register,
int base_register_size = 64,
int index_register_size = 64,
int segment_register_size = 64);
static InstructionOperand MemoryLocation(int alias_group_id);

bool operator==(const InstructionOperand&) const;
Expand Down
6 changes: 4 additions & 2 deletions gematria/basic_block/python/basic_block.cc
Original file line number Diff line number Diff line change
Expand Up @@ -160,11 +160,13 @@ PYBIND11_MODULE(basic_block, m) {
.def_static<InstructionOperand (*)(
std::string /* base_register */, int64_t /* displacement */,
std::string /* index_register */, int /* scaling */,
std::string /* segment_register */)>(
std::string /* segment_register */, int /* base_register_size */,
int /* index_register_size */, int /* segment_register_size */)>(
"from_address", &InstructionOperand::Address,
py::arg("base_register") = std::string(), py::arg("displacement") = 0,
py::arg("index_register") = std::string(), py::arg("scaling") = 0,
py::arg("segment_register") = std::string())
py::arg("segment_register") = std::string(), py::arg("base_register_size") = 64,
py::arg("index_register_size") = 64, py::arg("segment_register_size") = 64)
.def_static<InstructionOperand (*)(AddressTuple)>(
"from_address", &InstructionOperand::Address,
py::arg("address_tuple"))
Expand Down
128 changes: 89 additions & 39 deletions gematria/datasets/bhive_importer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -281,6 +281,7 @@ absl::StatusOr<BasicBlockWithThroughputProto> BHiveImporter::ParseMIRCsvLine(

absl::StatusOr<bool> BHiveImporter::LoadMIRModule(std::string_view file_name) {
// clear previous loaded module
func_to_live_intervals_.clear();
name_to_mbb_.clear();
if (mir_module_) {
for (llvm::Function& F : mir_module_->functions()) {
Expand Down Expand Up @@ -563,52 +564,101 @@ void BHiveImporter::addInterferenceGraph(
} else if (operand.operand_case() ==
CanonicalizedOperandProto::kRegisterName) {
live_physical_registers.insert(operand.register_name());
} else if (operand.operand_case() == CanonicalizedOperandProto::kAddress) {
if (!operand.address().base_register().empty()) {
if (operand.address().base_register()[0] == '%') {
live_virtual_registers.insert(operand.address().base_register());
} else {
live_physical_registers.insert(operand.address().base_register());
}
}
if (!operand.address().index_register().empty()) {
if (operand.address().index_register()[0] == '%') {
live_virtual_registers.insert(operand.address().index_register());
} else {
live_physical_registers.insert(operand.address().index_register());
}
}
if (!operand.address().segment().empty()) {
if (operand.address().segment()[0] == '%') {
live_virtual_registers.insert(operand.address().segment());
} else {
live_physical_registers.insert(operand.address().segment());
}
}
}
};

auto add_interference = [&](CanonicalizedOperandProto& operand) {
if (operand.operand_case() == CanonicalizedOperandProto::kVirtualRegister) {
// add interference from other virtual registers to current operand
for (auto vReg : live_virtual_registers) {
if (vReg == operand.virtual_register().name()) continue;
assert(func_live_infos.virtual_register_live_range_func.find(vReg) !=
func_live_infos.virtual_register_live_range_func.end() &&
"Virtual register not found in map");
// If the live range of the two registers intersect, then add
// interference to proto
if (checkRegIntersectionsWithBBRange(
func_live_infos.virtual_register_live_range_func
[operand.virtual_register().name()],
func_live_infos.virtual_register_live_range_func[vReg],
bb_range)) {
operand.mutable_intefered_register()->Add(std::move(vReg));
}
}
// add interference from physical registers to current operand
for (auto pReg : live_physical_registers) {
auto subRegs = superreg2subreg_[pReg];
// if there's one subReg of Preg that has interference with current
// operand then add interference to proto
for (auto subReg : subRegs) {
if (func_live_infos.physical_register_live_range_func.find(subReg) ==
func_live_infos.physical_register_live_range_func.end())
continue;
// pretty print live range of subRegs
LOG("Live range of subReg: " << subReg);
for (auto& range :
func_live_infos.physical_register_live_range_func[subReg]
.rangeList) {
LOG(" " << range.first << ", " << range.second);
}
auto add_interference_on_name =
[&](const std::string& name,
google::protobuf::RepeatedPtrField<std::string>*
mutable_intefered_register) {
for (auto vReg : live_virtual_registers) {
if (vReg == name) continue;
assert(func_live_infos.virtual_register_live_range_func.find(vReg) !=
func_live_infos.virtual_register_live_range_func.end() &&
"Virtual register not found in map");
// If the live range of the two registers intersect, then add
// interference to proto
if (checkRegIntersectionsWithBBRange(
func_live_infos.virtual_register_live_range_func
[operand.virtual_register().name()],
func_live_infos.physical_register_live_range_func[subReg],
func_live_infos.virtual_register_live_range_func[name],
func_live_infos.virtual_register_live_range_func[vReg],
bb_range)) {
operand.mutable_intefered_register()->Add(std::move(pReg));
break;
mutable_intefered_register->Add(std::move(vReg));
}
}
// add interference from physical registers to current operand
for (auto pReg : live_physical_registers) {
auto subRegs = superreg2subreg_[pReg];
// if there's one subReg of Preg that has interference with current
// operand then add interference to proto
for (auto subReg : subRegs) {
if (func_live_infos.physical_register_live_range_func.find(
subReg) ==
func_live_infos.physical_register_live_range_func.end())
continue;
// pretty print live range of subRegs
LOG("Live range of subReg: " << subReg);
for (auto& range :
func_live_infos.physical_register_live_range_func[subReg]
.rangeList) {
LOG(" " << range.first << ", " << range.second);
}
if (checkRegIntersectionsWithBBRange(
func_live_infos.virtual_register_live_range_func[name],
func_live_infos.physical_register_live_range_func[subReg],
bb_range)) {
mutable_intefered_register->Add(std::move(pReg));
break;
}
}
}
};

auto add_interference = [&](CanonicalizedOperandProto& operand) {
if (operand.operand_case() == CanonicalizedOperandProto::kVirtualRegister) {
add_interference_on_name(operand.virtual_register().name(),
operand.mutable_intefered_register());
} else if (operand.operand_case() == CanonicalizedOperandProto::kAddress) {
if (!operand.address().base_register().empty() &&
operand.address().base_register()[0] == '%') {
add_interference_on_name(
operand.address().base_register(),
operand.mutable_address()
->mutable_base_register_intefered_register());
}
if (!operand.address().index_register().empty() &&
operand.address().index_register()[0] == '%') {
add_interference_on_name(
operand.address().index_register(),
operand.mutable_address()
->mutable_index_register_intefered_register());
}
if (!operand.address().segment().empty() &&
operand.address().segment()[0] == '%') {
add_interference_on_name(
operand.address().segment(),
operand.mutable_address()->mutable_segment_intefered_register());
}
}
};
Expand Down
12 changes: 12 additions & 0 deletions gematria/datasets/bhive_importer_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -238,5 +238,17 @@ TEST_F(BHiveImporterTest, MIRDatasetTest2) {
IsOk());
}

TEST_F(BHiveImporterTest, MIRAddressTupleTest) {
EXPECT_THAT(
x86_bhive_importer_->LoadMIRModule("mir_input/test_mir_input/AdaptiveMaxPooling2d.mir"),
IsOk());
EXPECT_THAT(
x86_bhive_importer_->InteferenceGraphParser("mir_input/test_mir_input/AdaptiveMaxPooling2d.liveinfo"),
IsOk());
EXPECT_THAT(x86_bhive_importer_->ParseMIRCsvLine(
kSourceName, "a,b,BB_27,2.37", 2, 3, kScaling),
IsOk());
}

} // namespace
} // namespace gematria
2 changes: 1 addition & 1 deletion gematria/datasets/python/import_from_mir.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,7 @@ def main(argv: Sequence[str]) -> None:
BB_name = line.split(",")[_MACHINE_BASIC_BLOCK_NAME_COLUMN_INDEX.value]
through_put = line.split(",")[_THROUGHPUT_COLUMN_INDEX.value]
# skip blocks with throughput -1
if float(through_put) == -1:
if float(through_put) == -1 or float(through_put) < 0.1:
num_skipped_blocks += 1
continue
block_proto = importer.ParseMIRCsvLine(
Expand Down
18 changes: 12 additions & 6 deletions gematria/llvm/canonicalizer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
#include "llvm/CodeGen/MachineRegisterInfo.h"
#include <sstream>

#define DEBUG

#ifdef DEBUG
#define LOG(X) \
Expand Down Expand Up @@ -121,7 +122,7 @@ bool Canonicalizer::GetRegisterNameOrEmpty(
const llvm::MachineFunction *MF = operand.getParent()->getParent()->getParent();
const llvm::TargetRegisterInfo *TRI = MF->getSubtarget().getRegisterInfo();
const llvm::MachineRegisterInfo &MRI = MF->getRegInfo();
unsigned Size = TRI->getRegSizeInBits(reg, MRI);
size_t Size = TRI->getRegSizeInBits(reg, MRI);
name = "%" + std::to_string(llvm::Register::virtReg2Index(reg));
size = Size;
return false;
Expand Down Expand Up @@ -468,10 +469,12 @@ void X86Canonicalizer::AddOperand(const llvm::MachineInstr& mi, int operand_inde
: instruction.input_operands;
if (is_address_computation_tuple) { // TODO: Check if MIR has address computation tuple
std::string base_register;
size_t tmp_size;
size_t base_register_size = 64;
size_t index_register_size = 64;
size_t segment_register_size = 64;
if (mi.getOperand(operand_index + llvm::X86::AddrBaseReg).isReg()){
GetRegisterNameOrEmpty(
mi.getOperand(operand_index + llvm::X86::AddrBaseReg), base_register, tmp_size);
mi.getOperand(operand_index + llvm::X86::AddrBaseReg), base_register, base_register_size);
} else if (mi.getOperand(operand_index + llvm::X86::AddrBaseReg).isFI()){
base_register = "RBP";
} else {
Expand All @@ -482,18 +485,21 @@ void X86Canonicalizer::AddOperand(const llvm::MachineInstr& mi, int operand_inde
mi.getOperand(operand_index + llvm::X86::AddrDisp).getImm();
std::string index_register;
GetRegisterNameOrEmpty(
mi.getOperand(operand_index + llvm::X86::AddrIndexReg), index_register, tmp_size);
mi.getOperand(operand_index + llvm::X86::AddrIndexReg), index_register, index_register_size);
const int64_t scaling =
mi.getOperand(operand_index + llvm::X86::AddrScaleAmt).getImm();
std::string segment_register;
GetRegisterNameOrEmpty(
mi.getOperand(operand_index + llvm::X86::AddrSegmentReg), segment_register, tmp_size);
mi.getOperand(operand_index + llvm::X86::AddrSegmentReg), segment_register, segment_register_size);
operand_list.push_back(InstructionOperand::Address(
/* base_register= */ std::move(base_register),
/* displacement= */ displacement,
/* index_register= */ std::move(index_register),
/* scaling= */ static_cast<int>(scaling),
/* segment_register= */ std::move(segment_register)));
/* segment_register= */ std::move(segment_register),
/* base_register_size= */ base_register_size,
/* index_register_size= */ index_register_size,
/* segment_register_size= */ segment_register_size));
LOG("Hit here address_computation_tuple reg " << mi << "\n");
} else if (operand.isReg()) {
std::string name;
Expand Down

0 comments on commit eff6be5

Please sign in to comment.