From 8eb20ebf0b3011ffae8878897a610ac87b519495 Mon Sep 17 00:00:00 2001
From: lukezhuz
Date: Sat, 2 Dec 2023 17:35:52 -0500
Subject: [PATCH] Small fixes for training

Allow importing MIR basic blocks from an optional second input
directory, map address-operand registers to the generic vreg token only
when the register is actually virtual, drop basic blocks without
instructions when building the training dataset, and remove a stray
DEBUG define from the graph builder.

---
 gematria/datasets/python/import_from_mir.py | 113 +++++++++++---------
 gematria/granite/graph_builder.cc           |  27 +++--
 gematria/io/python/utils.py                 |  16 +++
 gematria/model/python/main_function.py      |   5 +
 4 files changed, 96 insertions(+), 65 deletions(-)

diff --git a/gematria/datasets/python/import_from_mir.py b/gematria/datasets/python/import_from_mir.py
index 76306742..aec93aeb 100644
--- a/gematria/datasets/python/import_from_mir.py
+++ b/gematria/datasets/python/import_from_mir.py
@@ -42,6 +42,13 @@
     'The name of directory containing all raw MIR files with performance throughput',
     required=True,
 )
+
+_INPUT_DIR2 = flags.DEFINE_string(
+    'gematria_input_dir2',
+    None,
+    'The name of an optional second directory containing raw MIR files with performance throughput',
+)
+
 _OUTPUT_TFRECORD_FILE = flags.DEFINE_string(
     'gematria_output_tfrecord',
     None,
@@ -132,58 +139,62 @@ def main(argv: Sequence[str]) -> None:
   num_input_files = 0
   num_skipped_blocks = 0
   num_skipped_files = 0
-  for filename in os.listdir(_INPUT_DIR.value):
-    if filename.endswith(".mir"):
-      if num_input_files % 1000 == 0:
-        logging.info(
-            'Processed %d files, skipped %d.',
-            num_input_files,
-            num_skipped_files,
-        )
-      mir_file = os.path.join(_INPUT_DIR.value, filename)
-      print("mir file is " + mir_file)
-      perf_file = os.path.join(_INPUT_DIR.value, filename.replace(".mir", ".perf"))
-      try:
-        # load the MIR file
-        module = importer.LoadMIRModule(mir_file)
-        num_input_files += 1
-        logging.info('Procssing %s file', mir_file)
-        # iterate over each line in the corresponding .perf file
-        with tf.io.gfile.GFile(perf_file, 'r') as bhive_csv_file:
-          for line in bhive_csv_file:
-            if num_input_blocks % 1000 == 0:
-              logging.info(
-                  'Processed %d blocks, skipped %d.',
-                  num_input_blocks,
-                  num_skipped_blocks,
-              )
-            num_input_blocks += 1
-            try:
-              hex = line.split(",")[_MACHINE_HEX_COLUMN_INDEX.value]
-              BB_name = line.split(",")[_MACHINE_BASIC_BLOCK_NAME_COLUMN_INDEX.value]
-              through_put = line.split(",")[_THROUGHPUT_COLUMN_INDEX.value]
-              # skip blocks with throughput -1
-              if float(through_put) == -1:
-                num_skipped_blocks += 1
-                continue
-              # skip blocks with duplicate machine code
-              if hex in machine_hex_set:
-                num_skipped_blocks += 1
-                continue
-              machine_hex_set.add(hex)
-              block_proto = importer.ParseMIRCsvLine(
-                  source_name=_SOURCE_NAME.value,
-                  line=line.strip(),
-                  BB_name_index = _MACHINE_BASIC_BLOCK_NAME_COLUMN_INDEX.value,
-                  throughput_column_index = _THROUGHPUT_COLUMN_INDEX.value,
-                  throughput_scaling=_THROUGHPUT_SCALING.value,
-              )
-              writer.write(block_proto.SerializeToString())
-            except:
-              num_skipped_blocks += 1
-      except:
-        logging.exception('Could not load file "%s"', mir_file)
-        num_skipped_files += 1
+  input_dirs = [_INPUT_DIR.value]
+  if _INPUT_DIR2.value:
+    input_dirs.append(_INPUT_DIR2.value)
+  for input_dir in input_dirs:
+    for filename in os.listdir(input_dir):
+      if filename.endswith(".mir"):
+        if num_input_files % 1000 == 0:
+          logging.info(
+              'Processed %d files, skipped %d.',
+              num_input_files,
+              num_skipped_files,
+          )
+        mir_file = os.path.join(input_dir, filename)
+        print("mir file is " + mir_file)
+        perf_file = os.path.join(input_dir, filename.replace(".mir", ".perf"))
+        try:
+          # load the MIR file
+          module = importer.LoadMIRModule(mir_file)
+          num_input_files += 1
+          logging.info('Processing file %s', mir_file)
+          # iterate over each line in the corresponding .perf file
+          with tf.io.gfile.GFile(perf_file, 'r') as bhive_csv_file:
+            for line in bhive_csv_file:
+              if num_input_blocks % 1000 == 0:
+                logging.info(
+                    'Processed %d blocks, skipped %d.',
+                    num_input_blocks,
+                    num_skipped_blocks,
+                )
+              num_input_blocks += 1
+              try:
+                hex = line.split(",")[_MACHINE_HEX_COLUMN_INDEX.value]
+                BB_name = line.split(",")[_MACHINE_BASIC_BLOCK_NAME_COLUMN_INDEX.value]
+                through_put = line.split(",")[_THROUGHPUT_COLUMN_INDEX.value]
+                # skip blocks with throughput -1
+                if float(through_put) == -1:
+                  num_skipped_blocks += 1
+                  continue
+                # skip blocks with duplicate machine code
+                if hex in machine_hex_set:
+                  num_skipped_blocks += 1
+                  continue
+                machine_hex_set.add(hex)
+                block_proto = importer.ParseMIRCsvLine(
+                    source_name=_SOURCE_NAME.value,
+                    line=line.strip(),
+                    BB_name_index=_MACHINE_BASIC_BLOCK_NAME_COLUMN_INDEX.value,
+                    throughput_column_index=_THROUGHPUT_COLUMN_INDEX.value,
+                    throughput_scaling=_THROUGHPUT_SCALING.value,
+                )
+                writer.write(block_proto.SerializeToString())
+              except:
+                num_skipped_blocks += 1
+        except:
+          logging.exception('Could not load file "%s"', mir_file)
+          num_skipped_files += 1
   logging.info(
       'Processed %d files, skipped %d.',
       num_input_files,
diff --git a/gematria/granite/graph_builder.cc b/gematria/granite/graph_builder.cc
index 4c7fef1b..4393af01 100644
--- a/gematria/granite/graph_builder.cc
+++ b/gematria/granite/graph_builder.cc
@@ -29,8 +29,6 @@
 #include "gematria/basic_block/basic_block.h"
 #include "gematria/model/oov_token_behavior.h"
 
-#define DEBUG
-
 #ifdef DEBUG
 #define LOG(X) \
   std::cerr << X << "\n"
@@ -299,32 +297,33 @@ bool BasicBlockGraphBuilder::AddInputOperand(
       const AddressTuple& address_tuple = operand.address();
       if (!address_tuple.base_register.empty()) {
         bool is_virtual_reg = address_tuple.base_register[0] == '%';
-        LOG("base register: " << address_tuple.base_register << " is virtual: " << is_virtual_reg);
         std::string vreg_token = getVREG_TOKEN(64);
-        if (!AddDependencyOnRegister(address_node, address_tuple.base_register,
-                                     vreg_token,
-                                     EdgeType::kAddressBaseRegister)) {
+        bool result = AddDependencyOnRegister(address_node, address_tuple.base_register,
+                                              is_virtual_reg ? vreg_token : address_tuple.base_register,
+                                              EdgeType::kAddressBaseRegister);
+        if (!result) {
          return false;
        }
      }
      if (!address_tuple.index_register.empty()) {
        bool is_virtual_reg = address_tuple.base_register[0] == '%';
-        LOG("index register: " << address_tuple.base_register << " is virtual: " << is_virtual_reg);
        std::string vreg_token = getVREG_TOKEN(64);
-        if (!AddDependencyOnRegister(address_node, address_tuple.index_register,
-                                     vreg_token,
-                                     EdgeType::kAddressIndexRegister)) {
+        bool result = AddDependencyOnRegister(address_node,
+                                              address_tuple.index_register,
+                                              is_virtual_reg ? vreg_token : address_tuple.index_register,
+                                              EdgeType::kAddressIndexRegister);
+        if (!result) {
          return false;
        }
      }
      if (!address_tuple.segment_register.empty()) {
        bool is_virtual_reg = address_tuple.base_register[0] == '%';
-        LOG("index register: " << address_tuple.base_register << " is virtual: " << is_virtual_reg);
        std::string vreg_token = getVREG_TOKEN(64);
-        if (!AddDependencyOnRegister(address_node,
-                                     address_tuple.segment_register,
-                                     vreg_token,
-                                     EdgeType::kAddressSegmentRegister)) {
+        bool result = AddDependencyOnRegister(address_node,
+                                              address_tuple.segment_register,
+                                              is_virtual_reg ? vreg_token : address_tuple.segment_register,
+                                              EdgeType::kAddressSegmentRegister);
+        if (!result) {
          return false;
        }
      }
diff --git a/gematria/io/python/utils.py b/gematria/io/python/utils.py
index 5035991f..2b404f38 100644
--- a/gematria/io/python/utils.py
+++ b/gematria/io/python/utils.py
@@ -162,6 +162,22 @@ def _scale_values(
   for i in range(len(values)):
     values[i] *= scaling_factor
 
+def drop_blocks_with_empty_instructions(
+    block: throughput_pb2.BasicBlockWithThroughputProto
+) -> Optional[throughput_pb2.BasicBlockWithThroughputProto]:
+  """Removes basic blocks that do not have any instructions.
+
+  Returns `block` unchanged if it has at least one instruction.
+
+  Args:
+    block: The basic block proto to inspect.
+
+  Returns:
+    None when `block` has no instructions; otherwise, returns `block`.
+  """
+  if block.basic_block.canonicalized_instructions:
+    return block
+  return None
 
 def drop_blocks_with_no_throughputs(
     use_prefixes: bool, block: throughput_pb2.BasicBlockWithThroughputProto
diff --git a/gematria/model/python/main_function.py b/gematria/model/python/main_function.py
index df7989f3..255a4d97 100644
--- a/gematria/model/python/main_function.py
+++ b/gematria/model/python/main_function.py
@@ -715,6 +715,11 @@ def _make_basic_block_reader_from_command_line_flags(
   proto_filters.append(
       functools.partial(utils.aggregate_throughputs, throughput_selection)
   )
+  proto_filters.append(
+      functools.partial(
+          utils.drop_blocks_with_empty_instructions,
+      )
+  )
   if _INPUT_FILE_SCALING.value != 1.0:
     proto_filters.append(
         functools.partial(utils.scale_throughputs, _INPUT_FILE_SCALING.value)
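
A minimal usage sketch for the two-directory import added above, assuming the primary flag is named gematria_input_dir (only gematria_input_dir2 and gematria_output_tfrecord are spelled out in the hunks) and using placeholder paths:

  python gematria/datasets/python/import_from_mir.py \
    --gematria_input_dir=/path/to/primary_mir_corpus \
    --gematria_input_dir2=/path/to/extra_mir_corpus \
    --gematria_output_tfrecord=/path/to/blocks.tfrecord

When --gematria_input_dir2 is unset, input_dirs contains only the primary directory and the importer behaves as before. Each .mir file must sit next to a matching .perf CSV; lines with a throughput of -1 or previously seen machine code are skipped rather than written to the TFRecord.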