
Merge pull request #251 from stanford-oval/wip/dialogues
Support end-to-end dialogue training and evaluation
Mehrad0711 authored Feb 25, 2022
2 parents 5cec08d + e7db254 commit a5089ae
Showing 28 changed files with 1,614 additions and 131 deletions.
2 changes: 1 addition & 1 deletion .gitignore
@@ -25,7 +25,7 @@ models/.DS_Store
src/
workdir/
*save*/
eval_dir/*
eval_dir*/*
genieNLP-tests*

lightning_logs/
8 changes: 4 additions & 4 deletions .pre-commit-config.yaml
@@ -1,23 +1,23 @@
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.0.1
rev: v4.1.0
hooks:
- id: check-yaml
- id: end-of-file-fixer
exclude: ^(tests/dataset/|tests/database/|tests/expected_results/)
- id: trailing-whitespace
exclude: ^(tests/dataset/|tests/database/|tests/expected_results/)
- repo: https://github.com/hadialqattan/pycln
rev: v1.0.3
rev: v1.2.1
hooks:
- id: pycln
args: [--config=pyproject.toml]
- repo: https://github.com/PyCQA/isort
rev: 5.9.3
rev: 5.10.1
hooks:
- id: isort
- repo: https://github.com/psf/black
rev: 21.9b0
rev: 22.1.0
hooks:
- id: black
- repo: https://gitlab.com/pycqa/flake8
5 changes: 5 additions & 0 deletions .travis.yml
@@ -40,6 +40,11 @@ jobs:
stage: test
script:
- bash ./tests/test_translation.sh
-
name: "E2E Dialogues tests"
stage: test
script:
- bash ./tests/test_e2e_dialogues.sh
-
name: "NED tests"
stage: test
42 changes: 42 additions & 0 deletions genienlp/arguments.py
@@ -183,6 +183,7 @@ def parse_argv(parser):
default=0.1,
help='multiplicative constant choosing the weight of encoder_loss in total loss',
)
parser.add_argument('--train_set_name', type=str, default='train', help='Training dataset name to use during training')
parser.add_argument('--eval_set_name', type=str, help='Evaluation dataset name to use during training')

parser.add_argument('--max_output_length', default=150, type=int, help='maximum output length for generation')
@@ -544,6 +545,33 @@ def parse_argv(parser):
help='Debugging flag for hf datasets where validation will be performed on train set',
)

parser.add_argument(
'--e2e_dialogue_evaluation',
action='store_true',
help='Evaluate model on a dialogue dataset end-to-end; i.e. model predictions are used as input instead of gold',
)
parser.add_argument(
'--e2e_dialogue_valid_subtasks',
nargs='+',
type=str,
default=['dst', 'api', 'da'],
help='Evaluate only on these subtasks when calculating e2e_dialogue_score; rg is not included by default',
)
parser.add_argument(
'--e2e_dialogue_valid_submetrics',
nargs='+',
type=str,
default=['jga', 'em', 'em'],
help='Specify metrics to use for each of subtasks in e2e_dialogue_valid_subtasks.',
)
parser.add_argument(
'--e2e_dialogue_valid_subweights',
nargs='+',
type=float,
default=[1.0, 1.0, 1.0],
help='Specify weights to use for each of subtasks in e2e_dialogue_valid_subtasks.',
)


def check_and_update_generation_args(args):
"""
@@ -632,6 +660,20 @@ def post_parse_general(args):


def post_parse_train_specific(args):
if args.e2e_dialogue_evaluation and args.val_batch_size[0] != 1:
logger.warning('When evaluating bitod end2end val_batch_size should be 1 so we load data turn by turn')
args.val_batch_size = [1]

if len(args.e2e_dialogue_valid_subtasks) != len(args.e2e_dialogue_valid_submetrics):
raise ValueError(
'Length of e2e_dialogue_valid_subtasks and e2e_dialogue_valid_submetrics arguments should be equal (i.e. one metric per subtask)'
)

if len(args.e2e_dialogue_valid_subtasks) != len(args.e2e_dialogue_valid_subweights):
raise ValueError(
'Length of e2e_dialogue_valid_subtasks and e2e_dialogue_valid_subweights arguments should be equal (i.e. one weight per subtask)'
)

if len(args.val_batch_size) < len(args.val_task_names):
args.val_batch_size = len(args.val_task_names) * args.val_batch_size

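Note on the new flags above: --e2e_dialogue_valid_subtasks, --e2e_dialogue_valid_submetrics, and --e2e_dialogue_valid_subweights are read in parallel, so the k-th subtask is scored with the k-th metric and weighted by the k-th weight, and post_parse_train_specific rejects mismatched lengths. A minimal standalone sketch of that pairing (toy values, not part of this commit):

subtasks = ['dst', 'api', 'da']     # --e2e_dialogue_valid_subtasks
submetrics = ['jga', 'em', 'em']    # --e2e_dialogue_valid_submetrics
subweights = [1.0, 1.0, 1.0]        # --e2e_dialogue_valid_subweights

# the same length checks the commit adds in post_parse_train_specific
if len(subtasks) != len(submetrics):
    raise ValueError('one metric per subtask is required')
if len(subtasks) != len(subweights):
    raise ValueError('one weight per subtask is required')

for subtask, metric, weight in zip(subtasks, submetrics, subweights):
    print(f'{subtask}: scored with {metric}, weight {weight}')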
4 changes: 0 additions & 4 deletions genienlp/data_utils/numericalizer.py
@@ -278,10 +278,6 @@ def build_vocab(self, vocab_sets, tasks):
if self.args.add_entities_to_text != 'off':
self._tokenizer.add_tokens(['<e>', '</e>'])

# add special tokens for ambig_qa task
if any(task.name == 'ambig_qa' for task in tasks):
self._tokenizer.add_tokens(['<q>', '<p>', '<u>'])

existing_special_tokens = self._tokenizer.special_tokens_map
# add separator if it doesn't exist. It will be used to concatenate context and question
if 'sep_token' not in existing_special_tokens:
143 changes: 121 additions & 22 deletions genienlp/metrics.py
@@ -27,11 +27,11 @@
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

import collections
import logging
import os
import re
import string
import logging
from collections import Counter, OrderedDict, defaultdict
from contextlib import closing
from multiprocessing import Pool, cpu_count
from subprocess import PIPE, Popen
@@ -40,6 +40,8 @@
import numpy as np
import sacrebleu
from datasets import load_metric
from dialogues import Bitod
from dialogues.bitod.src.evaluate import convert_lists_to_set
from pyrouge import Rouge155
from seqeval import metrics as seq_metrics
from seqeval import scheme as seq_scheme
Expand All @@ -51,7 +53,8 @@

# metrics that are calculated over a corpus (i.e. a list of predictions and gold answers, not single ones).
# These metrics cannot be calculated on individual examples and then averaged.
corpus_level_metrics = set(['bleu', 'casedbleu', 'ter', 't5_bleu', 'nmt_bleu', 'corpus_f1'])
corpus_level_metrics = {'bleu', 'casedbleu', 'ter', 't5_bleu', 'nmt_bleu', 'corpus_f1', 'jga'}


def to_lf(s, table):
aggs = [y.lower() for y in Query.agg_ops]
@@ -218,7 +221,7 @@ def lower(text):
def f1_score(prediction, ground_truth):
prediction_tokens = prediction.split()
ground_truth_tokens = ground_truth.split()
common = collections.Counter(prediction_tokens) & collections.Counter(ground_truth_tokens)
common = Counter(prediction_tokens) & Counter(ground_truth_tokens)
num_same = sum(common.values())
if num_same == 0:
return 0
@@ -519,7 +522,81 @@ def computeDialogue(greedy, answer):
return joint_goal_em, turn_request_em, turn_goal_em, answer


def compute_metrics(predictions: Iterable[str], answers: Union[Iterable[str], Iterable[Iterable[str]]], requested_metrics: Iterable, lang: str):
def compute_e2e_dialogue_score(greedy, answer, tgt_lang, args, example_ids):
num_examples = len(answer)
subtask_metrics_dict = OrderedDict()

results = OrderedDict({'e2e_dialogue_score': 0.0, 'JGA': 0.0, 'API_em': 0.0, 'DA_em': 0.0, 'BLEU': 0.0})
subtask2result_key = OrderedDict({'dst': 'JGA', 'api': 'API_em', 'da': 'DA_em', 'rg': 'BLEU'})

for k, subtask in enumerate(args.e2e_dialogue_valid_subtasks):
ids, preds, golds = [], [], []
for i in range(num_examples):
id_ = example_ids[i]
if id_.endswith(f'/{subtask}'):
ids.append(id_)
preds.append(greedy[i])
golds.append(answer[i])

if golds:
metrics_to_compute = args.e2e_dialogue_valid_submetrics[k]
sub_metrics = compute_metrics(preds, golds, [metrics_to_compute], tgt_lang, args, ids)
subtask_metrics_dict[subtask] = (
sub_metrics[metrics_to_compute],
len(golds),
args.e2e_dialogue_valid_subweights[k],
)

# TODO how should we aggregate?
weighted_num_examples = 0
for subtask, (sub_result, num_ex, weight) in subtask_metrics_dict.items():
result_key = subtask2result_key[subtask]

results[result_key] += sub_result
results['e2e_dialogue_score'] += weight * (sub_result * num_ex)
weighted_num_examples += weight * num_ex

results['e2e_dialogue_score'] /= weighted_num_examples

return results


def computeJGA(greedy, answer, example_ids):
dataset = Bitod()
hit = 0
cur_dial_id = None
assert len(example_ids) == len(greedy) == len(answer)
for id_, g, a in zip(example_ids, greedy, answer):
dial_id = id_.split('/')[1]
if dial_id != cur_dial_id:
cur_dial_id = dial_id
greedy_state = defaultdict()
answer_state = defaultdict()

a = a[0]
a = dataset.span2state(a)
g = dataset.span2state(g)

dataset.update_state(a, answer_state)
dataset.update_state(g, greedy_state)

convert_lists_to_set(answer_state)
convert_lists_to_set(greedy_state)

if answer_state == greedy_state:
hit += 1

return hit / len(greedy) * 100


def compute_metrics(
predictions: Iterable[str],
answers: Union[Iterable[str], Iterable[Iterable[str]]],
requested_metrics: Iterable,
lang: str,
args,
example_ids: Iterable[str] = None,
):
"""
Inputs:
predictions: a list of model predictions
@@ -536,11 +613,22 @@ def compute_metrics(predictions: Iterable[str], answers: Union[Iterable[str], It
lfem
joint_goal_em, turn_request_em, turn_goal_em, avg_dialogue
lang: the language of the predictions and answers. Used for BERTScore.
args: arguments
example_ids: used to calculate some of e2e dialogue metrics that need to know span of each dialogue such as JGA
"""
metric_keys = []
metric_values = []
if not isinstance(answers[0], list):
answers = [[a] for a in answers]
if 'e2e_dialogue_score' in requested_metrics:
requested_metrics += ['JGA', 'API_em', 'DA_em', 'BLEU']
results = compute_e2e_dialogue_score(predictions, answers, lang, args, example_ids)
metric_keys += results.keys()
metric_values += results.values()
if 'jga' in requested_metrics:
jga = computeJGA(predictions, answers, example_ids)
metric_keys += ['jga']
metric_values += [jga]
if 'lfem' in requested_metrics:
lfem, answers = computeLFEM(predictions, answers)
metric_keys += ['lfem']
@@ -550,9 +638,10 @@ def compute_metrics(predictions: Iterable[str], answers: Union[Iterable[str], It
avg_dialogue = (joint_goal_em + request_em) / 2
metric_keys += ['joint_goal_em', 'turn_request_em', 'turn_goal_em', 'avg_dialogue']
metric_values += [joint_goal_em, request_em, turn_goal_em, avg_dialogue]
em = computeEM(predictions, answers)
metric_keys += ['em']
metric_values += [em]
if 'em' in requested_metrics:
em = computeEM(predictions, answers)
metric_keys += ['em']
metric_values += [em]
if 'pem' in requested_metrics:
pem = computePartialEM(predictions, answers)
metric_keys.append('pem')
@@ -621,7 +710,8 @@ def convert_IOB2_to_IOB1(labels):
convert_IOB2_to_IOB1(predictions_processed)
convert_IOB2_to_IOB1(answers_processed)
f1 = (
seq_metrics.f1_score(y_pred=predictions_processed, y_true=answers_processed, mode='strict', scheme=seq_scheme.IOB1) * 100
seq_metrics.f1_score(y_pred=predictions_processed, y_true=answers_processed, mode='strict', scheme=seq_scheme.IOB1)
* 100
)

metric_keys.append('ner_f1_IOB1')
@@ -653,26 +743,35 @@ def convert_IOB2_to_IOB1(labels):
metric_values += [corpus_f1, precision, recall]

metric_dict = dict(zip(metric_keys, metric_values))
metric_dict = collections.OrderedDict((key, metric_dict[key]) for key in requested_metrics)
metric_dict = OrderedDict((key, metric_dict[key]) for key in requested_metrics)
return metric_dict


def calculate_and_reduce_metrics(predictions, answers, metrics_to_compute, reduce_metrics, lang):
metrics = collections.OrderedDict()
if reduce_metrics == 'max':
for i in range(len(predictions[0])): # for each output (in case of mulitple outputs)
partial_metrics = compute_metrics([p[i] for p in predictions], answers, metrics_to_compute, lang) # calculate the metric on all first outputs, all second outputs, etc.
def calculate_and_reduce_metrics(generation_output, metrics_to_compute, args, lang):
metrics = OrderedDict()
example_ids = generation_output.example_ids
predictions = generation_output.predictions
answers = generation_output.answers

if args.reduce_metrics == 'max':
for i in range(len(predictions[0])): # for each output (in case of multiple outputs)
partial_metrics = compute_metrics(
[p[i] for p in predictions], answers, metrics_to_compute, lang, args, example_ids
) # calculate the metric on all first outputs, all second outputs, etc.
for k, v in partial_metrics.items():
metrics[k] = max(metrics.get(k, 0), v)
elif reduce_metrics == 'top_k':
elif args.reduce_metrics == 'top_k':
for m in metrics_to_compute:
if m in corpus_level_metrics:
logging.warning('You are using the corpus-level metric %s with `--reduce_metrics top_k`, which can lead to incorrect results.', m)

for i in range(len(predictions)): # for each input
example_metrics = collections.OrderedDict() # keep track of metrics for one input and all of its outputs
for j in range(len(predictions[i])): # for each output (in case of mulitple outputs)
partial_metrics = compute_metrics([predictions[i][j]], [answers[i]], metrics_to_compute, lang) # calculate the metric on the j-th output of the i-th input
logging.warning(
f'You are using the corpus-level metric {m} with `--reduce_metrics top_k`, which can lead to incorrect results.',
)
for i in range(len(predictions)): # for each input
example_metrics = OrderedDict() # keep track of metrics for one input and all of its outputs
for j in range(len(predictions[i])): # for each output (in case of multiple outputs)
partial_metrics = compute_metrics(
[predictions[i][j]], [answers[i]], metrics_to_compute, lang, args, example_ids
) # calculate the metric on the j-th output of the i-th input
for k, v in partial_metrics.items():
example_metrics[k] = max(example_metrics.get(k, 0), v)
# sum metrics for all examples
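Note on compute_e2e_dialogue_score above: predictions are split into per-subtask slices by the suffix of each example id (ids ending in /dst, /api, /da), each slice is scored with its configured metric (JGA for dst, exact match for api and da), and the slice scores are combined into a single weight- and example-count-normalized e2e_dialogue_score. A toy sketch of just the aggregation step, with made-up slice scores (not code from this commit):

from collections import OrderedDict

# (slice score, number of examples, weight) per subtask -- made-up numbers
subtask_metrics = OrderedDict(
    dst=(55.0, 200, 1.0),  # JGA on the dst slice
    api=(80.0, 50, 1.0),   # exact match on the api slice
    da=(70.0, 200, 1.0),   # exact match on the da slice
)

e2e_dialogue_score = 0.0
weighted_num_examples = 0.0
for score, num_examples, weight in subtask_metrics.values():
    e2e_dialogue_score += weight * score * num_examples
    weighted_num_examples += weight * num_examples

e2e_dialogue_score /= weighted_num_examples
print(round(e2e_dialogue_score, 2))  # 64.44, the weighted per-example average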
2 changes: 1 addition & 1 deletion genienlp/paraphrase/scripts/transform_dataset.py
@@ -144,7 +144,7 @@ def main(args):
new_queries = [] # list of lists
query_file = open(args.query_file, 'r')
for line in query_file:
queries = line.split('\t')[1:-1] # 0 is example id, -1 is gold answer
queries = line.split('\t')[1:-2] # 0 is example id, -1 is input, -2 is gold answer
new_queries.append([lower_case(tokenize(q.strip())) for q in queries])
if args.transformation in ['remove_wrong_thingtalk', 'get_wrong_thingtalk']:
gold_thingtalks = []
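Note on the transform_dataset.py change above: the new slice assumes each line of the query file now ends with the gold answer followed by the original input, rather than just the gold answer, per the updated inline comment. A hedged illustration of the assumed layout (hypothetical line, not from this commit):

# column 0 is the example id, column -2 the gold answer, column -1 the input
line = 'ex1\tparaphrase one\tparaphrase two\tgold answer\toriginal input\n'
queries = line.split('\t')[1:-2]  # new slicing: keep only the paraphrased queries
print(queries)                    # ['paraphrase one', 'paraphrase two']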
(diffs for the remaining changed files are not shown)
