diff --git a/gematria/datasets/python/bhive_importer.cc b/gematria/datasets/python/bhive_importer.cc
index 5c5f1e08..1ba2c324 100644
--- a/gematria/datasets/python/bhive_importer.cc
+++ b/gematria/datasets/python/bhive_importer.cc
@@ -145,6 +145,12 @@ PYBIND11_MODULE(bhive_importer, m) {
           py::arg("source_name"), py::arg("line"), py::arg("BB_name_index"),
           py::arg("throughput_column_index"), py::arg("throughput_scaling") = 1.0,
           py::arg("base_address") = uint64_t{0},
           R"(Creates a BasicBlockWithThroughputProto from a MIR CSV line.)"
+      )
+      .def(
+          "BasicBlockProtoFromMBBName",
+          &BHiveImporter::BasicBlockProtoFromMBBName,
+          py::arg("MBB_name"), py::arg("base_address") = uint64_t{0},
+          R"(Creates a BasicBlockProto from a machine basic block name.)"
       );
 }
diff --git a/gematria/datasets/python/import_from_mir.py b/gematria/datasets/python/import_from_mir.py
index d76e9ee5..afb10bfc 100644
--- a/gematria/datasets/python/import_from_mir.py
+++ b/gematria/datasets/python/import_from_mir.py
@@ -122,36 +122,49 @@ def main(argv: Sequence[str]) -> None:
       tf.io.TFRecordWriter(_OUTPUT_TFRECORD_FILE.value) as writer,
   ):
     num_input_blocks = 0
+    num_input_files = 0
     num_skipped_blocks = 0
+    num_skipped_files = 0
     for filename in os.listdir(_INPUT_DIR.value):
       if filename.endswith(".mir"):
+        if num_input_files % 1000 == 0:
+          logging.info(
+              'Processed %d files, skipped %d.',
+              num_input_files,
+              num_skipped_files,
+          )
         mir_file = os.path.join(_INPUT_DIR.value, filename)
         print("mir file is " + mir_file)
         perf_file = os.path.join(_INPUT_DIR.value, filename.replace(".mir", ".ll.perf"))
-        # load the MIR file
-        module = importer.LoadMIRModule(mir_file)
-        logging.info('Procssing %s file', mir_file)
-        # iterate over each line in the corresponding .perf file
-        with tf.io.gfile.GFile(perf_file, 'r') as bhive_csv_file:
-          for line in bhive_csv_file:
-            if num_input_blocks % 1000 == 0:
-              logging.info(
-                  'Processed %d blocks, skipped %d.',
-                  num_input_blocks,
-                  num_skipped_blocks,
-              )
-            num_input_blocks += 1
-            try:
-              block_proto = importer.basic_block_with_throughput_proto_from_MBB_name(
-                  source_name=_SOURCE_NAME.value,
-                  line=line.strip(),
-                  BB_name_index = _MACHINE_BASIC_BLOCK_NAME_COLUMN_INDEX.value,
-                  throughput_column_index = _THROUGHPUT_COLUMN_INDEX.value,
-                  throughput_scaling=_THROUGHPUT_SCALING.value,
-              )
-              writer.write(block_proto.SerializeToString())
-            except bhive_importer.BasicBlockNotFoundError:
-              num_skipped_blocks += 1
+        try:
+          # Load the MIR file.
+          module = importer.LoadMIRModule(mir_file)
+          num_input_files += 1
+          logging.info('Processing file %s.', mir_file)
+          # Iterate over each line in the corresponding .perf file.
+          with tf.io.gfile.GFile(perf_file, 'r') as bhive_csv_file:
+            for line in bhive_csv_file:
+              if num_input_blocks % 1000 == 0:
+                logging.info(
+                    'Processed %d blocks, skipped %d.',
+                    num_input_blocks,
+                    num_skipped_blocks,
+                )
+              num_input_blocks += 1
+              try:
+                block_proto = importer.ParseMIRCsvLine(
+                    source_name=_SOURCE_NAME.value,
+                    line=line.strip(),
+                    BB_name_index=_MACHINE_BASIC_BLOCK_NAME_COLUMN_INDEX.value,
+                    throughput_column_index=_THROUGHPUT_COLUMN_INDEX.value,
+                    throughput_scaling=_THROUGHPUT_SCALING.value,
+                )
+                writer.write(block_proto.SerializeToString())
+              except bhive_importer.BasicBlockNotFoundError:
+                num_skipped_blocks += 1
+        except Exception:
+          logging.exception('Could not load file "%s"', mir_file)
+          num_skipped_files += 1


 if __name__ == '__main__':
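
Usage note (not part of the patch): a minimal sketch of how the new
BasicBlockProtoFromMBBName binding might be driven from Python. The
canonicalizer/importer construction is an assumption modeled on the other
BHive import scripts (it is not shown in this diff), and the .mir path and
MBB name "entry" are illustrative placeholders.

    # Assumed setup: build an x86-64 importer as the existing import
    # scripts do; this construction is not part of this change.
    from gematria.datasets.python import bhive_importer
    from gematria.llvm.python import canonicalizer
    from gematria.llvm.python import llvm_architecture_support

    llvm = llvm_architecture_support.LlvmArchitectureSupport.x86_64()
    importer = bhive_importer.BHiveImporter(
        canonicalizer.Canonicalizer.x86_64(llvm)
    )

    # The MIR module must be loaded first so the importer can resolve
    # machine basic block names against it.
    importer.LoadMIRModule('/path/to/example.mir')  # placeholder path

    # Convert one machine basic block, identified by name, into a
    # BasicBlockProto, with its first instruction at base_address.
    block_proto = importer.BasicBlockProtoFromMBBName(
        MBB_name='entry',  # placeholder block name
        base_address=0x10000,
    )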