diff --git a/gematria/basic_block/basic_block.cc b/gematria/basic_block/basic_block.cc index 2a65bafb..4047a61c 100644 --- a/gematria/basic_block/basic_block.cc +++ b/gematria/basic_block/basic_block.cc @@ -108,11 +108,12 @@ bool InstructionOperand::operator==(const InstructionOperand& other) const { } InstructionOperand InstructionOperand::VirtualRegister( - const std::string register_name, size_t size) { + const std::string register_name, size_t size, const std::vector& interfered_registers) { InstructionOperand result; result.type_ = OperandType::kVirtualRegister; result.register_name_ = std::move(register_name); result.size_ = size; + result.interfered_registers_ = std::move(interfered_registers); return result; } diff --git a/gematria/basic_block/basic_block.h b/gematria/basic_block/basic_block.h index e7a1de86..9dd8741f 100644 --- a/gematria/basic_block/basic_block.h +++ b/gematria/basic_block/basic_block.h @@ -147,7 +147,7 @@ class InstructionOperand { // The operands must be created through one of the factory functions. static InstructionOperand VirtualRegister(std::string register_name, - size_t size); + size_t size, const std::vector& interfered_registers); static InstructionOperand Register(std::string register_name); static InstructionOperand ImmediateValue(uint64_t immediate_value); static InstructionOperand FpImmediateValue(double fp_immediate_value); @@ -169,6 +169,11 @@ class InstructionOperand { // Returns the list of tokens representing this instruction. std::vector AsTokenList() const; + std::vector getInterferedRegisters() const { + assert(type_ == OperandType::kVirtualRegister); + return interfered_registers_; + } + // Returns a human-readable representation of the operand. // // This method implements the __str__() and __repr__() methods in the Python @@ -225,6 +230,7 @@ class InstructionOperand { double fp_immediate_value_ = 0.0; AddressTuple address_; int alias_group_id_ = 0; + std::vector interfered_registers_; }; std::ostream& operator<<(std::ostream& os, const InstructionOperand& operand); diff --git a/gematria/basic_block/basic_block_protos.cc b/gematria/basic_block/basic_block_protos.cc index 98c5ef56..f8f88742 100644 --- a/gematria/basic_block/basic_block_protos.cc +++ b/gematria/basic_block/basic_block_protos.cc @@ -25,6 +25,13 @@ namespace gematria { +namespace { + std::vector ToVector( + const google::protobuf::RepeatedPtrField& protos) { + return std::vector(protos.begin(), protos.end()); + } +} + AddressTuple AddressTupleFromProto( const CanonicalizedOperandProto::AddressTuple& proto) { return AddressTuple( @@ -64,8 +71,12 @@ InstructionOperand InstructionOperandFromProto( return InstructionOperand::MemoryLocation( proto.memory().alias_group_id()); case CanonicalizedOperandProto::kVirtualRegister: - return InstructionOperand::VirtualRegister( - proto.virtual_register().name(), proto.virtual_register().size()); + { + std::vector interfered_registers = ToVector(proto.intefered_register()); + return InstructionOperand::VirtualRegister( + proto.virtual_register().name(), proto.virtual_register().size(), interfered_registers); + } + } } @@ -102,7 +113,6 @@ CanonicalizedOperandProto ProtoFromInstructionOperand( } namespace { - std::vector ToVector( const google::protobuf::RepeatedPtrField& protos) { diff --git a/gematria/basic_block/basic_block_protos_test.cc b/gematria/basic_block/basic_block_protos_test.cc index 60d9424e..006dec19 100644 --- a/gematria/basic_block/basic_block_protos_test.cc +++ b/gematria/basic_block/basic_block_protos_test.cc @@ -239,7 +239,11 @@ TEST(BasicBlockFromProtoTest, VRegInstructions) { canonicalized_instructions { mnemonic: "CMP64RI32" llvm_mnemonic: "CMP64ri32" - input_operands { virtual_register { name: "%60" size: 64 } } + input_operands { + virtual_register { name: "%60" size: 64 } + intefered_register: "%61" + intefered_register: "%62" + } input_operands { immediate_value: 0 } implicit_output_operands { register_name: "EFLAGS" } } @@ -250,7 +254,7 @@ TEST(BasicBlockFromProtoTest, VRegInstructions) { /* mnemonic = */ "CMP64RI32", /* llvm_mnemonic = */ "CMP64ri32", /* prefixes = */ {}, /* input_operands = */ - {InstructionOperand::VirtualRegister("%60", 64), + {InstructionOperand::VirtualRegister("%60", 64, {"%61, %62"}), InstructionOperand::ImmediateValue(0)}, /* implicit_input_operands = */ {}, /* output_operands = */ {}, diff --git a/gematria/basic_block/python/basic_block.cc b/gematria/basic_block/python/basic_block.cc index ad231734..fb4da159 100644 --- a/gematria/basic_block/python/basic_block.cc +++ b/gematria/basic_block/python/basic_block.cc @@ -156,7 +156,7 @@ PYBIND11_MODULE(basic_block, m) { py::arg("fp_immediate_value")) .def_static("from_virtual_register", &InstructionOperand::VirtualRegister, - py::arg("register_name"), py::arg("size") = 0) + py::arg("register_name"), py::arg("size"), py::arg("interfered_registers")) .def_static BHiveImporter::BasicBlockProtoFromMachineCode( @@ -247,7 +256,8 @@ absl::StatusOr BHiveImporter::ParseMIRCsvLine( llvm::MachineBasicBlock* MBB = name_to_mbb_[MBB_name_ref]; std::string func_name = MBB->getParent()->getName().str(); assert(func_to_live_intervals_.find(func_name) != - func_to_live_intervals_.end() && "Function not found in map"); + func_to_live_intervals_.end() && + "Function not found in map"); addInterferenceGraph(*block_proto_or_status, func_to_live_intervals_[func_name], func_to_live_intervals_[func_name] @@ -519,98 +529,121 @@ static bool checkRegIntersectionsWithBBRange( const BHiveImporter::RegLiveIntervals& reg_live_interval1, const BHiveImporter::RegLiveIntervals& reg_live_interval2, const BHiveImporter::BhiveLiveRange& bb_range) { - const BHiveImporter::BhiveLiveRange* range1HitsBB = nullptr; - for (auto& interval : reg_live_interval1.rangeList) { - if (areIntersected(interval, bb_range)) { - range1HitsBB = &interval; - } + const BHiveImporter::BhiveLiveRange* range1HitsBB = nullptr; + for (auto& interval : reg_live_interval1.rangeList) { + if (areIntersected(interval, bb_range)) { + range1HitsBB = &interval; } - if (!range1HitsBB) { - return false; - } - for (auto& interval : reg_live_interval2.rangeList) { - if (areIntersected(interval, bb_range)) { - if (areIntersected(*range1HitsBB, interval)) { - return true; - } + } + if (!range1HitsBB) { + return false; + } + for (auto& interval : reg_live_interval2.rangeList) { + if (areIntersected(interval, bb_range)) { + if (areIntersected(*range1HitsBB, interval)) { + return true; } } + } return false; } -// void ToRepeatedPtrField( -// const std::vector& operands, -// google::protobuf::RepeatedPtrField* -// repeated_field) { -// repeated_field->Reserve(operands.size()); -// std::transform(operands.begin(), operands.end(), -// google::protobuf::RepeatedFieldBackInserter(repeated_field)); -// } - void BHiveImporter::addInterferenceGraph( BasicBlockProto& bb_proto, BHiveImporter::FunctionLiveIntervalInfo& func_live_infos, BHiveImporter::BhiveLiveRange& bb_range) { std::set live_virtual_registers; + std::set live_physical_registers; + + // helper function to update live_virtual_registers and + // live_physical_registers + auto update_live_regs = [&](const CanonicalizedOperandProto& operand) { + if (operand.operand_case() == CanonicalizedOperandProto::kVirtualRegister) { + live_virtual_registers.insert(operand.virtual_register().name()); + } else if (operand.operand_case() == + CanonicalizedOperandProto::kRegisterName) { + live_physical_registers.insert(operand.register_name()); + } + }; + + auto add_interference = [&](CanonicalizedOperandProto& operand) { + if (operand.operand_case() == CanonicalizedOperandProto::kVirtualRegister) { + // add interference from other virtual registers to current operand + for (auto vReg : live_virtual_registers) { + if (vReg == operand.virtual_register().name()) continue; + assert(func_live_infos.virtual_register_live_range_func.find(vReg) != + func_live_infos.virtual_register_live_range_func.end() && + "Virtual register not found in map"); + // If the live range of the two registers intersect, then add + // interference to proto + if (checkRegIntersectionsWithBBRange( + func_live_infos.virtual_register_live_range_func + [operand.virtual_register().name()], + func_live_infos.virtual_register_live_range_func[vReg], + bb_range)) { + operand.mutable_intefered_register()->Add(std::move(vReg)); + } + } + // add interference from physical registers to current operand + for (auto pReg : live_physical_registers) { + auto subRegs = superreg2subreg_[pReg]; + // if there's one subReg of Preg that has interference with current + // operand then add interference to proto + for (auto subReg : subRegs) { + if (func_live_infos.physical_register_live_range_func.find(subReg) == + func_live_infos.physical_register_live_range_func.end()) + continue; + // pretty print live range of subRegs + LOG("Live range of subReg: " << subReg); + for (auto& range : + func_live_infos.physical_register_live_range_func[subReg] + .rangeList) { + LOG(" " << range.first << ", " << range.second); + } + if (checkRegIntersectionsWithBBRange( + func_live_infos.virtual_register_live_range_func + [operand.virtual_register().name()], + func_live_infos.physical_register_live_range_func[subReg], + bb_range)) { + operand.mutable_intefered_register()->Add(std::move(pReg)); + break; + } + } + } + } + }; // iterate over all operands in bb_proto, add virtual registers to // live_virtual_registers for (const auto& instruction : bb_proto.canonicalized_instructions()) { for (const auto& operand : instruction.input_operands()) { - if (operand.operand_case() == - CanonicalizedOperandProto::kVirtualRegister) { - live_virtual_registers.insert(operand.virtual_register().name()); - } + update_live_regs(operand); + } + for (const auto& operand : instruction.implicit_input_operands()) { + update_live_regs(operand); } for (const auto& operand : instruction.output_operands()) { - if (operand.operand_case() == - CanonicalizedOperandProto::kVirtualRegister) { - live_virtual_registers.insert(operand.virtual_register().name()); - } + update_live_regs(operand); } + for (const auto& operand : instruction.implicit_output_operands()) { + update_live_regs(operand); + } + } + + // pretty print physical registers + LOG("Physical Registers: "); + for (auto& reg : live_physical_registers) { + LOG("Physical Register: " << reg); } - - // Iterate over all operands in bb_proto, add interference registers to each operand + + // Iterate over all operands in bb_proto, add interference registers to each + // operand for (auto& instruction : *bb_proto.mutable_canonicalized_instructions()) { // LOG("before: " << instruction.DebugString()); for (auto& operand : *instruction.mutable_input_operands()) { - if (operand.operand_case() == - CanonicalizedOperandProto::kVirtualRegister) { - for (auto vRegs : live_virtual_registers){ - if (vRegs == operand.virtual_register().name()) continue; - assert(func_live_infos.virtual_register_live_range_func.find(vRegs) != - func_live_infos.virtual_register_live_range_func.end() && - "Virtual register not found in map"); - // If the live range of the two registers intersect, then add - // interference to proto - if (checkRegIntersectionsWithBBRange( - func_live_infos.virtual_register_live_range_func[ - operand.virtual_register().name()], - func_live_infos.virtual_register_live_range_func[vRegs], - bb_range)) { - operand.mutable_intefered_register()->Add(std::move(vRegs)); - } - } - } + add_interference(operand); } for (auto& operand : *instruction.mutable_output_operands()) { - if (operand.operand_case() == - CanonicalizedOperandProto::kVirtualRegister) { - for (auto vRegs : live_virtual_registers){ - if (vRegs == operand.virtual_register().name()) continue; - assert(func_live_infos.virtual_register_live_range_func.find(vRegs) != - func_live_infos.virtual_register_live_range_func.end() && - "Virtual register not found in map"); - // If the live range of the two registers intersect, then add - // interference to proto - if (checkRegIntersectionsWithBBRange( - func_live_infos.virtual_register_live_range_func[ - operand.virtual_register().name()], - func_live_infos.virtual_register_live_range_func[vRegs], - bb_range)) { - operand.mutable_intefered_register()->Add(std::move(vRegs)); - } - } - } + add_interference(operand); } // LOG("after: " << instruction.DebugString()); } diff --git a/gematria/datasets/bhive_importer.h b/gematria/datasets/bhive_importer.h index c11384b9..53344f64 100644 --- a/gematria/datasets/bhive_importer.h +++ b/gematria/datasets/bhive_importer.h @@ -139,8 +139,10 @@ class BHiveImporter { // A struct that store all intervals in a function as well as ranges of BB struct FunctionLiveIntervalInfo { - std::unordered_map virtual_register_live_range_func; - std::unordered_map physical_register_live_range_func; + std::unordered_map + virtual_register_live_range_func; + std::unordered_map + physical_register_live_range_func; std::unordered_map BBRangeList; }; @@ -150,6 +152,17 @@ class BHiveImporter { } } + // pretty print superreg2subreg_ + void prettyPrintSuperReg2SubReg() { + LOG("SuperReg2SubReg: "); + for (auto& [superreg, subreg] : superreg2subreg_) { + LOG(superreg << ": "); + for (auto& sub : subreg) { + LOG("\t" << sub); + } + } + } + // Now we are able to obtain the live range for each register // We want to for each pair of regsiter, find out if their live range overlap // Edge case 1: one live range may have multiple live ranges, @@ -159,7 +172,9 @@ class BHiveImporter { // to take in machine instruction/ fucntion absl::StatusOr InteferenceGraphParser(std::string_view file_name); - void addInterferenceGraph(BasicBlockProto& bb_proto, FunctionLiveIntervalInfo& func_live_infos, BhiveLiveRange& bb_range); + void addInterferenceGraph(BasicBlockProto& bb_proto, + FunctionLiveIntervalInfo& func_live_infos, + BhiveLiveRange& bb_range); private: const Canonicalizer& canonicalizer_; @@ -171,6 +186,7 @@ class BHiveImporter { std::unordered_map func_to_live_intervals_; std::unordered_map name_to_reg_; + std::unordered_map> superreg2subreg_; llvm::LLVMContext llvm_context_; std::unique_ptr mir_module_; llvm::MachineModuleInfo MMI_; diff --git a/gematria/datasets/bhive_importer_test.cc b/gematria/datasets/bhive_importer_test.cc index 2967a3d2..5d7b3b12 100644 --- a/gematria/datasets/bhive_importer_test.cc +++ b/gematria/datasets/bhive_importer_test.cc @@ -234,8 +234,8 @@ TEST_F(BHiveImporterTest, MIRDatasetTest2) { x86_bhive_importer_->InteferenceGraphParser("sample_dataset/liveinfo"), IsOk()); EXPECT_THAT(x86_bhive_importer_->ParseMIRCsvLine( - kSourceName, "a,b,BB_21,2.37", 2, 3, kScaling), - IsOk()); + kSourceName, "a,b,BB_21,2.37", 2, 3, kScaling), + IsOk()); } } // namespace diff --git a/gematria/llvm/canonicalizer.cc b/gematria/llvm/canonicalizer.cc index e9ea840f..578aa240 100644 --- a/gematria/llvm/canonicalizer.cc +++ b/gematria/llvm/canonicalizer.cc @@ -72,10 +72,6 @@ void ReplaceExprOperands(llvm::MachineInstr& instruction) { for (int i = 0; i < instruction.getNumOperands(); ++i) { llvm::MachineOperand& operand = instruction.getOperand(i); if (operand.isSymbol() || operand.isGlobal() || operand.isCPI()) { - // TODO(ondrasej): In some cases the value may change the binary encoding - // of the instruction, e.g. switch between an 8-bit or a 32-bit encoding - // of the displacement and 0 might have a special meaning (e.g. do not use - // displacement at all). operand = llvm::MachineOperand::CreateImm(1); } } @@ -477,7 +473,7 @@ void X86Canonicalizer::AddOperand(const llvm::MachineInstr& mi, int operand_inde GetRegisterNameOrEmpty( mi.getOperand(operand_index + llvm::X86::AddrBaseReg), base_register, tmp_size); } else if (mi.getOperand(operand_index + llvm::X86::AddrBaseReg).isFI()){ - base_register = "rbp"; + base_register = "RBP"; } else { assert(false && "unsupported base register type"); LOG(mi); @@ -508,7 +504,7 @@ void X86Canonicalizer::AddOperand(const llvm::MachineInstr& mi, int operand_inde InstructionOperand::Register(name)); } else { operand_list.push_back( - InstructionOperand::VirtualRegister(name, size)); + InstructionOperand::VirtualRegister(name, size, {})); } } else if (operand.isImm()) { operand_list.push_back(