Skip to content

Commit

Permalink
Fixed import_from_mir issue
Browse files Browse the repository at this point in the history
  • Loading branch information
9Tempest committed Nov 16, 2023
1 parent d7fb212 commit 6ec1ae5
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 24 deletions.
6 changes: 6 additions & 0 deletions gematria/datasets/python/bhive_importer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,12 @@ PYBIND11_MODULE(bhive_importer, m) {
py::arg("source_name"), py::arg("line"),py::arg("BB_name_index"), py::arg("throughput_column_index"),
py::arg("throughput_scaling") = 1.0, py::arg("base_address") = uint64_t{0},
R"(Creates a BasicBlockWithThroughputProto from a MIR CSV line.)"
)
.def(
"BasicBlockProtoFromMBBName",
&BHiveImporter::BasicBlockProtoFromMBBName,
py::arg("MBB_name"), py::arg("base_address") = uint64_t{0},
R"(Creates a BasicBlockProto from a machine basic block name.)"
);
}

Expand Down
61 changes: 37 additions & 24 deletions gematria/datasets/python/import_from_mir.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,36 +122,49 @@ def main(argv: Sequence[str]) -> None:
tf.io.TFRecordWriter(_OUTPUT_TFRECORD_FILE.value) as writer,
):
num_input_blocks = 0
num_input_files = 0
num_skipped_blocks = 0
num_skipped_files = 0
for filename in os.listdir(_INPUT_DIR.value):
if filename.endswith(".mir"):
if num_input_files % 1000 == 0:
logging.info(
'Processed %d files, skipped %d.',
num_input_files,
num_skipped_files,
)
mir_file = os.path.join(_INPUT_DIR.value, filename)
print("mir file is " + mir_file)
perf_file = os.path.join(_INPUT_DIR.value, filename.replace(".mir", ".ll.perf"))
# load the MIR file
module = importer.LoadMIRModule(mir_file)
logging.info('Processing %s file', mir_file)
# iterate over each line in the corresponding .perf file
with tf.io.gfile.GFile(perf_file, 'r') as bhive_csv_file:
for line in bhive_csv_file:
if num_input_blocks % 1000 == 0:
logging.info(
'Processed %d blocks, skipped %d.',
num_input_blocks,
num_skipped_blocks,
)
num_input_blocks += 1
try:
block_proto = importer.basic_block_with_throughput_proto_from_MBB_name(
source_name=_SOURCE_NAME.value,
line=line.strip(),
BB_name_index = _MACHINE_BASIC_BLOCK_NAME_COLUMN_INDEX.value,
throughput_column_index = _THROUGHPUT_COLUMN_INDEX.value,
throughput_scaling=_THROUGHPUT_SCALING.value,
)
writer.write(block_proto.SerializeToString())
except bhive_importer.BasicBlockNotFoundError:
num_skipped_blocks += 1
try:
# load the MIR file
module = importer.LoadMIRModule(mir_file)
num_skipped_files += 1
logging.info('Processing %s file', mir_file)
# iterate over each line in the corresponding .perf file
with tf.io.gfile.GFile(perf_file, 'r') as bhive_csv_file:
for line in bhive_csv_file:
if num_input_blocks % 1000 == 0:
logging.info(
'Processed %d blocks, skipped %d.',
num_input_blocks,
num_skipped_blocks,
)
num_input_blocks += 1
try:
block_proto = importer.ParseMIRCsvLine(
source_name=_SOURCE_NAME.value,
line=line.strip(),
BB_name_index = _MACHINE_BASIC_BLOCK_NAME_COLUMN_INDEX.value,
throughput_column_index = _THROUGHPUT_COLUMN_INDEX.value,
throughput_scaling=_THROUGHPUT_SCALING.value,
)
writer.write(block_proto.SerializeToString())
except bhive_importer.BasicBlockNotFoundError:
num_skipped_blocks += 1
except:
logging.exception('Could not load file "%s"', mir_file)
num_skipped_files += 1


if __name__ == '__main__':
Expand Down

0 comments on commit 6ec1ae5

Please sign in to comment.