predict: remove code for file evaluation
Mehrad0711 committed Feb 25, 2022
1 parent e07c8c5 commit e7db254
Showing 1 changed file with 30 additions and 78 deletions.
genienlp/predict.py: 108 changes (30 additions & 78 deletions)
@@ -43,7 +43,6 @@
except RuntimeError:
pass

import sys

import torch

@@ -63,17 +62,17 @@
set_seed,
split_folder_on_disk,
)
from .validate import GenerationOutput, generate_with_model
from .validate import generate_with_model

logger = logging.getLogger(__name__)


def parse_argv(parser):
parser.add_argument('--path', type=str, required='--pred_file' not in sys.argv, help='Folder to load the model from')
parser.add_argument('--path', type=str, required=True, help='Folder to load the model from')
parser.add_argument(
'--evaluate',
type=str,
required='--pred_file' not in sys.argv,
required=True,
choices=['train', 'valid', 'test'],
help='Which dataset to do predictions for (train, dev or test)',
)
@@ -106,12 +105,6 @@ def parse_argv(parser):
parser.add_argument('--cache', default='.cache', type=str, help='where to save cached files')
parser.add_argument('--subsample', default=20000000, type=int, help='subsample the eval/test datasets')

parser.add_argument(
'--pred_file',
type=str,
help='If provided, we just compute evaluation metrics on this file and bypass model prediction. File should be in tsv format with id, pred, answer columns',
)

parser.add_argument(
'--pred_languages',
type=str,
@@ -564,20 +557,39 @@ def run(args, device):
prediction_file.write('\n'.join(lines) + '\n')

if len(generation_output.answers) > 0:
compute_metrics_on_file(
task_scores,
prediction_file_name,
results_file_name,
task,
metrics_to_compute = task.metrics
metrics_to_compute += args.extra_metrics
metrics_to_compute = [metric for metric in task.metrics if metric not in ['loss']]
if args.main_metric_only:
metrics_to_compute = [metrics_to_compute[0]]
metrics = calculate_and_reduce_metrics(
generation_output,
metrics_to_compute,
args,
tgt_lang,
confidence_scores=generation_output.confidence_scores,
)

log_final_results(args, task_scores)
with open(results_file_name, 'w' + ('' if args.overwrite else '+')) as results_file:
results_file.write(json.dumps(metrics) + '\n')

if not args.silent:
for i, (c, p, a) in enumerate(
zip(generation_output.contexts, generation_output.predictions, generation_output.answers)
):
log_string = (
f'\nContext {i + 1}: {c}\nPrediction {i + 1} ({len(p)} outputs): {p}\nAnswer {i + 1}: {a}\n'
)
if args.calibrator_paths is not None:
log_string += f'Confidence {i + 1} : '
for score in generation_output.confidence_scores:
log_string += f'{score[i]:.3f}, '
log_string += '\n'
logger.info(log_string)

logger.info(metrics)

task_scores[task].append((len(generation_output.answers), metrics[task.metrics[0]]))

def log_final_results(args, task_scores):
decaScore = []
for task in task_scores.keys():
decaScore.append(
@@ -592,55 +604,6 @@ def log_final_results(args, task_scores):
logger.info(f'\nSummary: | {sum(decaScore)} | {" | ".join([str(x) for x in decaScore])} |\n')


def compute_metrics_on_file(task_scores, pred_file, results_file_name, task, args, tgt_lang, confidence_scores=None):
generation_output = GenerationOutput()
ids, contexts, preds, targets = [], [], [], []
with open(pred_file) as fin:
for line in fin:
id_, *pred, target, context = line.strip('\n').split('\t')
ids.append(id_)
contexts.append(context)
preds.append(pred)
targets.append(target)

generation_output.example_ids = ids
generation_output.contexts = contexts
generation_output.predictions = preds
generation_output.answers = targets
generation_output.confidence_scores = confidence_scores

metrics_to_compute = task.metrics
metrics_to_compute += args.extra_metrics
metrics_to_compute = [metric for metric in task.metrics if metric not in ['loss']]
if args.main_metric_only:
metrics_to_compute = [metrics_to_compute[0]]
metrics = calculate_and_reduce_metrics(
generation_output,
metrics_to_compute,
args,
tgt_lang,
)

with open(results_file_name, 'w' + ('' if args.overwrite else '+')) as results_file:
results_file.write(json.dumps(metrics) + '\n')

if not args.silent:
for i, (c, p, a) in enumerate(
zip(generation_output.contexts, generation_output.predictions, generation_output.answers)
):
log_string = f'\nContext {i + 1}: {c}\nPrediction {i + 1} ({len(p)} outputs): {p}\nAnswer {i + 1}: {a}\n'
if args.calibrator_paths is not None:
log_string += f'Confidence {i + 1} : '
for score in generation_output.confidence_scores:
log_string += f'{score[i]:.3f}, '
log_string += '\n'
logger.info(log_string)

logger.info(metrics)

task_scores[task].append((len(generation_output.answers), metrics[task.metrics[0]]))


def main(args):
load_config_json(args)
check_and_update_generation_args(args)
@@ -667,17 +630,6 @@ def main(args):

task.metrics = new_metrics

if args.pred_file and os.path.exists(args.pred_file):
task_scores = defaultdict(list)
eval_dir = os.path.join(args.eval_dir, args.evaluate)
os.makedirs(eval_dir, exist_ok=True)
tgt_lang = args.pred_tgt_languages[0]
for task in args.tasks:
results_file_name = os.path.join(eval_dir, task.name + '.results.json')
compute_metrics_on_file(task_scores, args.pred_file, results_file_name, task, args, tgt_lang)
log_final_results(args, task_scores)
return

logger.info(f'Loading from {args.best_checkpoint}')
devices = get_devices(args.devices)

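For reference, the file-evaluation path removed by this commit reduces to a short standalone routine. The sketch below is assembled only from the code visible in this diff; the import paths are assumptions, and it presumes that GenerationOutput and calculate_and_reduce_metrics remain importable with the signatures used above (a task object exposing .metrics and an args namespace providing main_metric_only).

# Minimal sketch of the deleted compute_metrics_on_file(), kept outside predict.py.
# NOTE: the module paths below are assumptions; adjust them to wherever these
# helpers live in your checkout.
from genienlp.validate import GenerationOutput
from genienlp.metrics import calculate_and_reduce_metrics


def score_pred_file(pred_file, task, args, tgt_lang):
    """Score a tsv predictions file (id, prediction(s), answer, context per row)."""
    ids, contexts, preds, targets = [], [], [], []
    with open(pred_file) as fin:
        for line in fin:
            # Same unpacking as the removed code: middle columns are predictions.
            id_, *pred, target, context = line.strip('\n').split('\t')
            ids.append(id_)
            contexts.append(context)
            preds.append(pred)
            targets.append(target)

    output = GenerationOutput()
    output.example_ids = ids
    output.contexts = contexts
    output.predictions = preds
    output.answers = targets

    # Mirror the new inlined logic in run(): drop 'loss', optionally keep only the main metric.
    metrics_to_compute = [m for m in task.metrics if m != 'loss']
    if args.main_metric_only:
        metrics_to_compute = metrics_to_compute[:1]

    return calculate_and_reduce_metrics(output, metrics_to_compute, args, tgt_lang)

Writing the returned dict to a results JSON and appending (len(output.answers), metrics[task.metrics[0]]) to task_scores, as the inlined code in run() does, can be layered on top of this if the same result files are needed.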
