Skip to content

Commit

Permalink
predict: move loop outside of create_output_lines
Browse files Browse the repository at this point in the history
  • Loading branch information
Mehrad0711 committed Feb 28, 2022
1 parent 96f6db7 commit 1c301f2
Showing 1 changed file with 34 additions and 29 deletions.
63 changes: 34 additions & 29 deletions genienlp/predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -383,31 +383,34 @@ def prepare_data_iterators(args, val_sets, numericalizer, device):
return iters


def create_output_line(args, generation_output):
lines = []
for i in range(len(generation_output.example_ids)):
predictions = generation_output.raw_predictions if args.translate_return_raw_outputs else generation_output.predictions
if args.one_output_per_line:
lines = [
'\t'.join(
[generation_output.example_ids[i], prediction, generation_output.answers[i], generation_output.contexts[i]]
)
for prediction in predictions[i]
] # one line per generation output
else:
lines = [
'\t'.join(
[
generation_output.example_ids[i],
*predictions[i],
generation_output.answers[i],
generation_output.contexts[i],
]
)
] # one line with all generation outputs separated by '\t'
if args.calibrator_paths is not None:
for score in generation_output.confidence_scores:
lines = [line + '\t' + str(score[i]) for line in lines] # append score to all lines
def create_output_lines(args, index, generation_output):
predictions = generation_output.raw_predictions if args.translate_return_raw_outputs else generation_output.predictions
if args.one_output_per_line:
lines = [
'\t'.join(
[
generation_output.example_ids[index],
prediction,
generation_output.answers[index],
generation_output.contexts[index],
]
)
for prediction in predictions[index]
] # one line per generation output
else:
lines = [
'\t'.join(
[
generation_output.example_ids[index],
*predictions[index],
generation_output.answers[index],
generation_output.contexts[index],
]
)
] # one line with all generation outputs separated by '\t'
if args.calibrator_paths is not None:
for score in generation_output.confidence_scores:
lines = [line + '\t' + str(score[index]) for line in lines] # append score to all lines
return lines


Expand Down Expand Up @@ -490,13 +493,15 @@ def run(args, device):
# write into file
# TODO change to jsonl format
with open(prediction_file_name, 'w' + ('' if args.overwrite else '+')) as prediction_file:
lines = create_output_line(args, generation_output)
prediction_file.write('\n'.join(lines) + '\n')
for i in range(len(generation_output.example_ids)):
lines = create_output_lines(args, i, generation_output)
prediction_file.write('\n'.join(lines) + '\n')

if args.translate_return_raw_outputs:
with open(raw_prediction_file_name, 'w' + ('' if args.overwrite else '+')) as prediction_file:
lines = create_output_line(args, generation_output)
prediction_file.write('\n'.join(lines) + '\n')
for i in range(len(generation_output.example_ids)):
lines = create_output_lines(args, i, generation_output)
prediction_file.write('\n'.join(lines) + '\n')

if len(generation_output.answers) > 0:
metrics_to_compute = get_metrics_to_compute(args, task)
Expand Down

0 comments on commit 1c301f2

Please sign in to comment.