diff --git a/testdata/generate_training_data.py b/testdata/generate_training_data.py index 8c0cbf318..6d64d0600 100755 --- a/testdata/generate_training_data.py +++ b/testdata/generate_training_data.py @@ -61,7 +61,6 @@ def FindELFTrainingFiles(): """ Returns the list of ELF files that should be used for training. These ELF files need to contain objdump-able debug information. - """ elf_files = [ filename for filename in glob.iglob( FLAGS.executable_directory + 'ELF/**/*', recursive=True) @@ -350,7 +349,6 @@ def IndexToRowColumn(index, n): """ Given an index into the non-zero elements of an upper triangular matrix, returns a tuple of integers indicating the row, column of that entry. - n is the number of elements in the family we are dealing with. """ if n & 1: @@ -462,6 +460,8 @@ def GenerateRepulsionPairs( input_map, number_of_pairs ): while len(repulsion_set) != number_of_pairs and max_loop_iterations > 0: symbol_one, symbol_two = numpy.random.choice( symbols_as_list, 2, replace=False ) + if (ExtractFunctionName(symbol_one) == ExtractFunctionName(symbol_two)): + continue element_one = random.choice( input_map[symbol_one] ) element_two = random.choice( input_map[symbol_two] ) ordered_pair = tuple(sorted([element_one, element_two])) @@ -469,6 +469,13 @@ def GenerateRepulsionPairs( input_map, number_of_pairs ): max_loop_iterations = max_loop_iterations - 1 return repulsion_set +def ExtractFunctionName(symbol): + decoded_string = subprocess.run(["base64", "-d"], stdout=PIPE, + input=bytes(symbol, encoding="utf-8")).stdout.decode("utf-8") + decoded_string = decoded_string.split('(')[0] + return decoded_string + + def WritePairsFile( set_of_pairs, output_name ): """ Take a set of pairs ((file_idA, addressA), (file_idB, addressB)) and write them @@ -527,7 +534,6 @@ def WriteSeenTrainingAndValidationData(symbol_to_file_and_address, FLAGS): Remove random element R for the validation set Generate all pairs of attraction for the family without R (training) Generate all pairs of attraction between family members and R (validation) - Now generate as many random repulsion pairs. """ training_attraction_set = set() @@ -611,7 +617,7 @@ def main(argv): # First, generate the training and validation data for performance on unseen # functions - to test how well we generalize beyond things we have already # seen variants of. - WriteUnseenTrainingAndValidationData(symbol_to_files_and_address, FLAGS) +WriteUnseenTrainingAndValidationData(symbol_to_files_and_address, FLAGS) # Secondly, generate the training and validation data for performance on 'seen' # functions -- e.g. how well we perform if we need to spot a variant of a function @@ -622,4 +628,3 @@ def main(argv): if __name__ == '__main__': app.run(main) -