Skip to content

Commit

Permalink
fixed liveinfo parser bug && Added options in import_from_mir (perBB, …
Browse files Browse the repository at this point in the history
…perFunc)
  • Loading branch information
9Tempest committed Dec 7, 2023
1 parent b9b0167 commit a51a79f
Show file tree
Hide file tree
Showing 12 changed files with 16,153 additions and 75 deletions.
105 changes: 56 additions & 49 deletions gematria/datasets/bhive_importer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ constexpr int kDefaultSyntax = 0;

} // namespace

BHiveImporter::BHiveImporter(const Canonicalizer* canonicalizer)
BHiveImporter::BHiveImporter(const Canonicalizer* canonicalizer, const std::string& model_type)
: canonicalizer_(*ABSL_DIE_IF_NULL(canonicalizer)),
target_machine_(canonicalizer->target_machine()),
context_(std::make_unique<llvm::MCContext>(
Expand All @@ -81,11 +81,11 @@ BHiveImporter::BHiveImporter(const Canonicalizer* canonicalizer)
*target_machine_.getMCAsmInfo(), *target_machine_.getMCInstrInfo(),
*target_machine_.getMCRegisterInfo())),
MMI_(dynamic_cast<const llvm::LLVMTargetMachine*>(&target_machine_)) {
// setup super register to sub register mapping
const llvm::MCRegisterInfo& MRI = *target_machine_.getMCRegisterInfo();
for (llvm::MCPhysReg I = 1, E = MRI.getNumRegs(); I != E; ++I) {
// Append register definition line.
llvm::StringRef reg_name = MRI.getName(I);
name_to_reg_[reg_name.str()] = I;
// push itself to its own superreg2subreg_ list
superreg2subreg_[reg_name.str()].push_back(reg_name.str());
for (auto SuperReg : MRI.superregs(I)) {
Expand All @@ -95,6 +95,15 @@ BHiveImporter::BHiveImporter(const Canonicalizer* canonicalizer)
}
}
}
// set up model type
if (model_type == "PER_BB_LIVE_INFO") {
model_type_ = MODEL_TYPE::PER_BB_LIVE_INFO;
} else if (model_type == "PER_FUNC_LIVE_INFO") {
model_type_ = MODEL_TYPE::PER_FUNC_LIVE_INFO;
} else {
model_type_ = MODEL_TYPE::NO_LIVE_INFO;
}
LOG("model type is " << model_type_ << " raw string is " << model_type);
// prettyPrintName2Reg();
// prettyPrintSuperReg2SubReg();
}
Expand Down Expand Up @@ -180,19 +189,9 @@ absl::StatusOr<BasicBlockWithThroughputProto> BHiveImporter::ParseBHiveCsvLine(
return proto;
}

absl::StatusOr<BasicBlockProto> BHiveImporter::BasicBlockProtoFromMBBName(
std::string_view MBB_name, uint64_t base_address /*= 0*/) {
absl::StatusOr<BasicBlockProto> BHiveImporter::BasicBlockProtoFromMBB(
llvm::MachineBasicBlock* MBB, uint64_t base_address /*= 0*/) {
BasicBlockProto basic_block_proto;
// convert MBB_name to llvm::StringRef
llvm::StringRef MBB_name_ref(MBB_name.data(), MBB_name.size());

// lookup the MBB in the map, if not, return error
if (name_to_mbb_.find(MBB_name_ref) == name_to_mbb_.end()) {
return absl::InvalidArgumentError(
absl::StrCat("Could not find MBB with name ", MBB_name));
}

llvm::MachineBasicBlock* MBB = name_to_mbb_[MBB_name_ref];
LOG("MBB is " << *MBB);
for (llvm::MachineInstr& MI : *MBB) {
// if MI is a control instruction(ret,branch,jmp), skip it
Expand Down Expand Up @@ -241,31 +240,37 @@ absl::StatusOr<BasicBlockWithThroughputProto> BHiveImporter::ParseMIRCsvLine(
const std::string_view throughput_str = columns[throughput_column_index];

BasicBlockWithThroughputProto proto;

absl::StatusOr<BasicBlockProto> block_proto_or_status =
BasicBlockProtoFromMBBName(BB_unique_name, base_address);
if (!block_proto_or_status.ok()) return block_proto_or_status.status();

// convert MBB_name to llvm::StringRef
llvm::StringRef MBB_name_ref(BB_unique_name.data(), BB_unique_name.size());

// lookup the MBB in the map, if not, return error
if (name_to_mbb_.find(MBB_name_ref) == name_to_mbb_.end()) {
return absl::InvalidArgumentError(
absl::StrCat("Could not find MBB with name ", BB_unique_name));
}

llvm::MachineBasicBlock* MBB = name_to_mbb_[MBB_name_ref];
std::string func_name = MBB->getParent()->getName().str();
assert(func_to_live_intervals_.find(func_name) !=
func_to_live_intervals_.end() &&
"Function not found in map");
auto instrument_result = addInterferenceGraph(
*block_proto_or_status, func_to_live_intervals_[func_name],
func_to_live_intervals_[func_name]
.BBRangeList[std::string(BB_unique_name)]);
if (!instrument_result.ok()) {
return absl::InvalidArgumentError(absl::StrCat(
"Could not instrument interference graph for BB ", BB_unique_name));

absl::StatusOr<BasicBlockProto> block_proto_or_status =
BasicBlockProtoFromMBB(MBB, base_address);
if (!block_proto_or_status.ok()) return block_proto_or_status.status();

// Add interference graph based on model type
if (model_type_ != MODEL_TYPE::NO_LIVE_INFO){
std::string func_name = MBB->getParent()->getName().str();
assert(func_to_live_intervals_.find(func_name) !=
func_to_live_intervals_.end() &&
"Function not found in map");
auto instrument_result = addInterferenceGraph(
*block_proto_or_status, func_to_live_intervals_[func_name],
func_to_live_intervals_[func_name]
.BBRangeList[std::string(BB_unique_name)]);
if (!instrument_result.ok()) {
return absl::InvalidArgumentError(absl::StrCat(
"Could not instrument interference graph for BB ", BB_unique_name));
}
}

*proto.mutable_basic_block() = std::move(block_proto_or_status).value();

double throughput_cycles = 0.0;
Expand Down Expand Up @@ -444,10 +449,6 @@ absl::StatusOr<bool> BHiveImporter::InteferenceGraphParser(
lineStream >> dummy >> start >> dummy >> dummy >> end >> dummy >>
dummy >> discard >> dummy;

// Print out information for debug
// std::cerr << "Register: " << currentRegister << ", " << start << ", "
// << end << "\n";

// Since LLVM do not support [] operator we need to find it first
auto resultRegLiveIntervals =
(is_virtual)
Expand Down Expand Up @@ -506,11 +507,12 @@ absl::StatusOr<bool> BHiveImporter::InteferenceGraphParser(
lineStream >> start >> dummy >> end;

info->BBRangeList[currentBB] = {start, end};
isParsingRegister = false;
}

// In this case, we arrived at the definition of a new function
// In this case we need to
else if (line[0] == '_') {
else {
// We reached the end of a function, add info to the Map
// If this is the beginning of a new function, just add
// a dummy value and delete it at the end
Expand Down Expand Up @@ -635,13 +637,6 @@ absl::StatusOr<bool> BHiveImporter::addInterferenceGraph(
subReg) ==
func_live_infos.physical_register_live_range_func.end())
continue;
// pretty print live range of subRegs
// LOG("Live range of subReg: " << subReg);
// for (auto& range :
// func_live_infos.physical_register_live_range_func[subReg]
// .rangeList) {
// LOG(" " << range.first << ", " << range.second);
// }
auto check_result = checkRegIntersectionsWithBBRange(
func_live_infos.virtual_register_live_range_func[name],
func_live_infos.physical_register_live_range_func[subReg],
Expand All @@ -654,6 +649,23 @@ absl::StatusOr<bool> BHiveImporter::addInterferenceGraph(
}
}
}
// If model_type_ is PER_FUNC_LIVE_INFO, also add interference from the
// virtual registers of the whole function.
if (model_type_ == MODEL_TYPE::PER_FUNC_LIVE_INFO) {
for (auto& [vReg, liveInterval] :
func_live_infos.virtual_register_live_range_func) {
if (live_virtual_registers.count(vReg)) continue;
// LOG("Adding interference from function " << vReg);
auto check_result = checkRegIntersectionsWithBBRange(
func_live_infos.virtual_register_live_range_func[name],
liveInterval, bb_range);
if (!check_result.ok()) return check_result;
if (*check_result) {
mutable_intefered_register->Add(std::string(vReg));
mutable_intefered_register_size->Add(32);
}
} // for
}
return absl::StatusOr<bool>(true);
};

Expand Down Expand Up @@ -712,12 +724,6 @@ absl::StatusOr<bool> BHiveImporter::addInterferenceGraph(
}
}

// pretty print physical registers
// LOG("Physical Registers: ");
// for (auto& reg : live_physical_registers) {
// LOG("Physical Register: " << reg);
// }

// Iterate over all operands in bb_proto, add interference registers to each
// operand
for (auto& instruction : *bb_proto.mutable_canonicalized_instructions()) {
Expand All @@ -732,7 +738,8 @@ absl::StatusOr<bool> BHiveImporter::addInterferenceGraph(
}
// LOG("after: " << instruction.DebugString());
}
// printMap(func_to_live_intervals_);
return true;
}
} // addInterferenceGraph

} // namespace gematria
23 changes: 9 additions & 14 deletions gematria/datasets/bhive_importer.h
Original file line number Diff line number Diff line change
Expand Up @@ -70,10 +70,15 @@ namespace gematria {
// Parser for BHive CSV files.
class BHiveImporter {
public:
// Selects how much register live-range information the importer attaches to
// the basic blocks it produces. The constructor maps its `model_type` string
// onto one of these values.
enum MODEL_TYPE{
// No interference information is added. Also used as the fallback when the
// `model_type` string matches neither of the names below.
NO_LIVE_INFO,
// ParseMIRCsvLine adds an interference graph to the parsed basic block.
PER_BB_LIVE_INFO,
// Like PER_BB_LIVE_INFO, and additionally considers the virtual registers of
// the whole enclosing function when adding interference.
PER_FUNC_LIVE_INFO,
};
// Creates a new BHive importer from a given canonicalizer. The canonicalizer
// must be for the architecture/microarchitecture of the data set.
// Does not take ownership of the canonicalizer.
explicit BHiveImporter(const Canonicalizer* canonicalizer);
explicit BHiveImporter(const Canonicalizer* canonicalizer, const std::string& model_type = "NO_LIVE_INFO");

// Creates a basic block from the given block of machine code. `machine_code`
// must contain machine code of the instructions to include in the basic
Expand All @@ -94,8 +99,8 @@ class BHiveImporter {
absl::StatusOr<BasicBlockProto> BasicBlockProtoFromMachineCodeHex(
std::string_view machine_code_hex, uint64_t base_address = 0);

absl::StatusOr<BasicBlockProto> BasicBlockProtoFromMBBName(
std::string_view MBB_name, uint64_t base_address = 0);
absl::StatusOr<BasicBlockProto> BasicBlockProtoFromMBB(
llvm:: MachineBasicBlock* MBB, uint64_t base_address = 0);

// Parses a basic block with throughput from one BHive CSV line. Expects that
// the line has the format "{machine_code},{throughput}" where {machine_code}
Expand Down Expand Up @@ -146,12 +151,6 @@ class BHiveImporter {
std::unordered_map<std::string, BhiveLiveRange> BBRangeList;
};

void prettyPrintName2Reg() {
for (auto& [name, reg] : name_to_reg_) {
LOG(name << " " << reg);
}
}

// pretty print superreg2subreg_
void prettyPrintSuperReg2SubReg() {
LOG("SuperReg2SubReg: ");
Expand Down Expand Up @@ -185,16 +184,12 @@ class BHiveImporter {
llvm::DenseMap<llvm::StringRef, llvm::MachineBasicBlock*> name_to_mbb_;
std::unordered_map<std::string, FunctionLiveIntervalInfo>
func_to_live_intervals_;
std::unordered_map<std::string, llvm::MCPhysReg> name_to_reg_;
std::unordered_map<std::string, std::vector<std::string>> superreg2subreg_;
llvm::LLVMContext llvm_context_;
std::unique_ptr<llvm::Module> mir_module_;
llvm::MachineModuleInfo MMI_;
std::unique_ptr<llvm::MIRParser> mir_parser_;

// Author: Zhan Shi
// Add one data strcture to the bhiveimporter storing interference graph
llvm::DenseMap<llvm::StringRef, llvm::MachineBasicBlock*> name_to_graph_;
MODEL_TYPE model_type_;
};

} // namespace gematria
Expand Down
2 changes: 1 addition & 1 deletion gematria/datasets/bhive_importer_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ class BHiveImporterTest : public ::testing::Test {
x86_canonicalizer_ =
std::make_unique<X86Canonicalizer>(&x86_llvm_->target_machine());
x86_bhive_importer_ =
std::make_unique<BHiveImporter>(x86_canonicalizer_.get());
std::make_unique<BHiveImporter>(x86_canonicalizer_.get(), "PER_FUNC_LIVE_INFO");
}

std::unique_ptr<LlvmArchitectureSupport> x86_llvm_;
Expand Down
14 changes: 7 additions & 7 deletions gematria/datasets/python/bhive_importer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,9 @@ PYBIND11_MODULE(bhive_importer, m) {

py::class_<BHiveImporter>(m, "BHiveImporter")
.def( //
py::init<const Canonicalizer* /* canonicalizer */>(),
py::init<const Canonicalizer* /* canonicalizer */, const std::string&>(),
py::arg("canonicalizer"),
py::arg("model_type") = std::string("NO_LIVE_INFO"),
R"(Initializes a new BHive importer for a given architecture.
Args:
Expand Down Expand Up @@ -145,13 +146,12 @@ PYBIND11_MODULE(bhive_importer, m) {
py::arg("source_name"), py::arg("line"),py::arg("BB_name_index"), py::arg("throughput_column_index"),
py::arg("throughput_scaling") = 1.0, py::arg("base_address") = uint64_t{0},
R"(Creates a BasicBlockWithThroughputProto from a MIR CSV line.)"
).def(
"parse_interference_graph",
&BHiveImporter::InteferenceGraphParser, py::arg("file_name"),
R"(Parse the interference graph from a file)"
)
.def(
"BasicBlockProtoFromMBBName",
&BHiveImporter::BasicBlockProtoFromMBBName,
py::arg("MBB_name"), py::arg("base_address") = uint64_t{0},
R"(Creates a BasicBlockProto from a MIR CSV line.)"
);
;
}

} // namespace gematria
23 changes: 20 additions & 3 deletions gematria/datasets/python/import_from_mir.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,12 @@
'The name of directory containing all raw MIR files with performance throughput',
)

# Selects the model type handed to BHiveImporter. main() forwards the value
# only when it is PER_BB_LIVE_INFO or PER_FUNC_LIVE_INFO; for any other value
# (including the default None) the importer is created without a model type.
_MODEL_TYPE = flags.DEFINE_string(
'gematria_model_format',
None,
'The format of dataset to be imported. [NO_LIVE_INFO, PER_BB_LIVE_INFO, PER_FUNC_LIVE_INFO]',
)

_OUTPUT_TFRECORD_FILE = flags.DEFINE_string(
'gematria_output_tfrecord',
None,
Expand Down Expand Up @@ -111,6 +117,9 @@ def _validate_input_columns(flags_dict):
from pybind11_abseil import status
import tensorflow as tf

def is_mode_interference_graph(model_type):
  """Returns True iff `model_type` is one of the interference-graph modes."""
  return model_type in ("PER_BB_LIVE_INFO", "PER_FUNC_LIVE_INFO")


def main(argv: Sequence[str]) -> None:
if len(argv) > 1:
Expand All @@ -129,7 +138,11 @@ def main(argv: Sequence[str]) -> None:
# LLVM triple. As of 2023-05, this is OK, because we support only x86-64
# anyway.
canonicalizer_obj = canonicalizer.Canonicalizer.x86_64(llvm)
importer = bhive_importer.BHiveImporter(canonicalizer_obj)
if is_mode_interference_graph(_MODEL_TYPE.value):
logging.info('Creating BHiveImporter with interference graph %s', _MODEL_TYPE.value)
importer = bhive_importer.BHiveImporter(canonicalizer_obj, _MODEL_TYPE.value)
else:
importer = bhive_importer.BHiveImporter(canonicalizer_obj)

with (
tf.io.TFRecordWriter(_OUTPUT_TFRECORD_FILE.value) as writer,
Expand All @@ -153,11 +166,15 @@ def main(argv: Sequence[str]) -> None:
mir_file = os.path.join(input_dir, filename)
print("mir file is " + mir_file)
perf_file = os.path.join(input_dir, filename.replace(".mir", ".perf"))
liveinfo_file = os.path.join(input_dir, filename + ".liveinfo")
try:
# load the MIR file
module = importer.LoadMIRModule(mir_file)
num_input_files += 1
logging.info('Procssing %s file', mir_file)
importer.LoadMIRModule(mir_file)
logging.info('Loading live info %s file', liveinfo_file)
# The interference-graph model types need the liveinfo file; note that it is
# loaded here regardless of the selected model type.
importer.parse_interference_graph(liveinfo_file)
num_input_files += 1
# iterate over each line in the corresponding .perf file
with tf.io.gfile.GFile(perf_file, 'r') as bhive_csv_file:
for line in bhive_csv_file:
Expand Down
Loading

0 comments on commit a51a79f

Please sign in to comment.