From 6ff6b5c1c69d5e0f4c9228753e63e96f8ee0853e Mon Sep 17 00:00:00 2001 From: mehrad Date: Thu, 12 Aug 2021 19:22:09 -0700 Subject: [PATCH 001/102] metrics: fix bertscore --- genienlp/metrics.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/genienlp/metrics.py b/genienlp/metrics.py index 35a613c2..796cee65 100644 --- a/genienlp/metrics.py +++ b/genienlp/metrics.py @@ -286,7 +286,11 @@ def computeSM(outputs, targets): def computeBERTScore(outputs, targets, lang): bertscore_metric = load_metric("bertscore") - return sum(bertscore_metric.compute(predictions=outputs, references=targets, lang=lang)['f1']) / len(outputs) * 100 + return ( + sum(bertscore_metric.compute(predictions=outputs, references=targets, lang=lang, use_fast_tokenizer=True)['f1']) + / len(outputs) + * 100 + ) def computeTER(outputs, targets): From 1a337ff20825b67c4905e3df913a094ce22714ef Mon Sep 17 00:00:00 2001 From: mehrad Date: Thu, 12 Aug 2021 14:14:25 -0700 Subject: [PATCH 002/102] Add adafactor optimizer --- genienlp/arguments.py | 8 ++- genienlp/model_utils/optimizer.py | 100 ++++++++++++++++++++++++++++++ genienlp/train.py | 69 +-------------------- setup.py | 1 + 4 files changed, 108 insertions(+), 70 deletions(-) create mode 100644 genienlp/model_utils/optimizer.py diff --git a/genienlp/arguments.py b/genienlp/arguments.py index fb45b79b..b1e70d1a 100644 --- a/genienlp/arguments.py +++ b/genienlp/arguments.py @@ -308,13 +308,17 @@ def parse_argv(parser): help='alternative momentum for Adam (only when not using transformer scheduler), and RAdam', ) parser.add_argument( - '--optimizer', default='adam', choices=['adam', 'adamw', 'sgd', 'radam'], type=str, help='optimizer to use' + '--optimizer', + default='adam', + choices=['adam', 'adamw', 'adafactor', 'radam', 'sgd'], + type=str, + help='optimizer to use', ) parser.add_argument( '--lr_schedule', type=str, default='transformer', - choices=['transformer', 'constant', 'linear', 'sgd', 'cosine'], + choices=['transformer', 'constant', 'linear', 'sgd', 'cosine', 'multi_linear_cosine'], help='The learning rate strategy. 
All of them can be used with or without warmup.', ) parser.add_argument( diff --git a/genienlp/model_utils/optimizer.py b/genienlp/model_utils/optimizer.py new file mode 100644 index 00000000..88941c03 --- /dev/null +++ b/genienlp/model_utils/optimizer.py @@ -0,0 +1,100 @@ +import math +from functools import partial + +import numpy as np +import torch +from pytorch_lightning_spells.lr_schedulers import LinearLR, MultiStageScheduler +from torch.optim.lr_scheduler import CosineAnnealingLR +from transformers import ( + Adafactor, + AdamW, + get_constant_schedule_with_warmup, + get_cosine_schedule_with_warmup, + get_linear_schedule_with_warmup, +) + + +class MultiStageSchedulerV2(MultiStageScheduler): + def __init__(self, schedulers, start_at_epochs, last_epoch=-1): + super().__init__(schedulers, start_at_epochs, last_epoch) + + def get_last_lr(self, epoch=None): + if epoch is None: + self.last_epoch = self.last_epoch + 1 + else: + self.last_epoch = epoch - 1 + for scheduler, starting_epoch in zip(self.schedulers, self.start_at_epochs): + if self.last_epoch + 1 >= starting_epoch: + scheduler.last_epoch = self.last_epoch - starting_epoch + return scheduler.get_last_lr() + + +def get_transformer_learning_rate(i, *, dimension, warmup): + i += 1 + return 1.0 / math.sqrt(dimension) * min(1 / math.sqrt(i), i / (warmup * math.sqrt(warmup))) + + +def get_sgd_learning_rate(i, *, warmup): + i += 1 + return min(math.sqrt(warmup) / math.sqrt(i), i / warmup) + + +def init_opt(args, model, logger): + num_training_steps = sum(args.train_iterations) // args.gradient_accumulation_steps + + if args.optimizer == 'adam': + # Adam with transformer schedule has a different set of default hyperparameters: + if args.lr_schedule == 'transformer': + opt = torch.optim.Adam( + model.params, lr=args.lr_multiply, betas=(0.9, 0.98), eps=1e-9, weight_decay=args.weight_decay + ) + else: + opt = torch.optim.Adam( + model.params, lr=args.lr_multiply, betas=(args.beta0, 0.999), weight_decay=args.weight_decay + ) + elif args.optimizer == 'adamw': + opt = AdamW(model.params, lr=args.lr_multiply, weight_decay=args.weight_decay) + elif args.optimizer == 'adafactor': + opt = Adafactor(model.params, lr=args.lr_multiply, weight_decay=args.weight_decay, relative_step=False) + elif args.optimizer == 'radam': + import radam + + if args.warmup > 1: + logger.warning('With RAdam optimizer, warmup is never applied') + opt = radam.RAdam(model.params, lr=args.lr_multiply, betas=(args.beta0, 0.999), weight_decay=args.weight_decay) + elif args.optimizer == 'sgd': + opt = torch.optim.SGD(model.params, lr=args.lr_multiply, weight_decay=args.weight_decay) + else: + raise ValueError('Invalid optimizer.') + + if args.lr_schedule == 'transformer': + lr_lambda = partial(get_transformer_learning_rate, dimension=args.dimension, warmup=args.warmup) + scheduler = torch.optim.lr_scheduler.LambdaLR(opt, lr_lambda) + elif args.lr_schedule == 'constant': + scheduler = get_constant_schedule_with_warmup(opt, num_warmup_steps=args.warmup) + elif args.lr_schedule == 'linear': + scheduler = get_linear_schedule_with_warmup( + opt, + num_training_steps=num_training_steps, + num_warmup_steps=args.warmup, + ) + elif args.lr_schedule == 'cosine': + scheduler = get_cosine_schedule_with_warmup( + opt, + num_training_steps=num_training_steps, + num_warmup_steps=args.warmup, + num_cycles=0.5, + ) + elif args.lr_schedule == 'multi_linear_cosine': + lr_durations = [int(num_training_steps * 0.1), int(np.ceil(num_training_steps * 0.9)) + 1] + start_at_epochs = [0] + 
list(np.cumsum(lr_durations))[:-1] + linear_scheduler = LinearLR(opt, 0.0001, lr_durations[0]) + cosine_scheduler = CosineAnnealingLR(opt, lr_durations[1]) + scheduler = MultiStageSchedulerV2([linear_scheduler, cosine_scheduler], start_at_epochs) + elif args.lr_schedule == 'sgd': + lr_lambda = partial(get_sgd_learning_rate, warmup=args.warmup) + scheduler = torch.optim.lr_scheduler.LambdaLR(opt, lr_lambda) + else: + raise ValueError('Invalid learning rate scheduler.') + + return opt, scheduler diff --git a/genienlp/train.py b/genienlp/train.py index fdec9899..2a2b4f85 100644 --- a/genienlp/train.py +++ b/genienlp/train.py @@ -35,21 +35,15 @@ import os import time from copy import deepcopy -from functools import partial from pprint import pformat import numpy as np import torch from torch.utils.tensorboard import SummaryWriter -from transformers import ( - AdamW, - get_constant_schedule_with_warmup, - get_cosine_schedule_with_warmup, - get_linear_schedule_with_warmup, -) from . import arguments, models from .arguments import save_args +from .model_utils.optimizer import init_opt from .model_utils.parallel_utils import NamedTupleCompatibleDataParallel from .model_utils.saver import Saver from .ned.ned_utils import init_ned_model @@ -210,7 +204,6 @@ def update_fraction(args, task_iteration): next_fraction = args.curriculum_rate * task_iteration elif args.curriculum_strategy == 'exp': next_fraction = args.curriculum_rate * np.exp(task_iteration) - fraction = min(args.curriculum_max_frac, next_fraction) return fraction @@ -633,66 +626,6 @@ def train( logger.info(f'{args.pretrained_model} model is saved to {args.save} without any fine-tuning') -def get_transformer_learning_rate(i, *, dimension, warmup): - i += 1 - return 1.0 / math.sqrt(dimension) * min(1 / math.sqrt(i), i / (warmup * math.sqrt(warmup))) - - -def get_sgd_learning_rate(i, *, warmup): - i += 1 - return min(math.sqrt(warmup) / math.sqrt(i), i / warmup) - - -def init_opt(args, model, logger): - if args.optimizer == 'adam': - # Adam with transformer schedule has a different set of default hyperparameters: - if args.lr_schedule == 'transformer': - opt = torch.optim.Adam( - model.params, lr=args.lr_multiply, betas=(0.9, 0.98), eps=1e-9, weight_decay=args.weight_decay - ) - else: - opt = torch.optim.Adam( - model.params, lr=args.lr_multiply, betas=(args.beta0, 0.999), weight_decay=args.weight_decay - ) - elif args.optimizer == 'adamw': - opt = AdamW(model.params, lr=args.lr_multiply, weight_decay=args.weight_decay) - elif args.optimizer == 'radam': - import radam - - if args.warmup > 1: - logger.warning('With RAdam optimizer, warmup is never applied') - opt = radam.RAdam(model.params, lr=args.lr_multiply, betas=(args.beta0, 0.999), weight_decay=args.weight_decay) - else: - assert args.optimizer == 'sgd' - opt = torch.optim.SGD(model.params, lr=args.lr_multiply, weight_decay=args.weight_decay) - - if args.lr_schedule == 'transformer': - lr_lambda = partial(get_transformer_learning_rate, dimension=args.dimension, warmup=args.warmup) - scheduler = torch.optim.lr_scheduler.LambdaLR(opt, lr_lambda) - elif args.lr_schedule == 'constant': - scheduler = get_constant_schedule_with_warmup(opt, num_warmup_steps=args.warmup) - elif args.lr_schedule == 'linear': - scheduler = get_linear_schedule_with_warmup( - opt, - num_training_steps=sum(args.train_iterations) // args.gradient_accumulation_steps, - num_warmup_steps=args.warmup, - ) - elif args.lr_schedule == 'cosine': - scheduler = get_cosine_schedule_with_warmup( - opt, - 
num_training_steps=sum(args.train_iterations) // args.gradient_accumulation_steps, - num_warmup_steps=args.warmup, - num_cycles=0.5, - ) - elif args.lr_schedule == 'sgd': - lr_lambda = partial(get_sgd_learning_rate, warmup=args.warmup) - scheduler = torch.optim.lr_scheduler.LambdaLR(opt, lr_lambda) - else: - raise ValueError('Invalid learning rate scheduler.') - - return opt, scheduler - - def main(args): args = arguments.post_parse_general(args) args = arguments.post_parse_train_specific(args) diff --git a/setup.py b/setup.py index 10213e49..7310581e 100644 --- a/setup.py +++ b/setup.py @@ -64,6 +64,7 @@ 'nltk~=3.4', 'ujson~=4.1', 'pathos==0.2.8', + 'pytorch-lightning-spells==0.0.3', # for kf: 'kfserving>=0.5.0', # for NED: From f9c11742c5908076c7598393bbe649e02c810b5a Mon Sep 17 00:00:00 2001 From: mehrad Date: Fri, 20 Aug 2021 22:31:23 -0700 Subject: [PATCH 003/102] Bump bootleg to 1.0.5 also remove packages that get installed from bootleg --- setup.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/setup.py b/setup.py index 7310581e..246c3dae 100644 --- a/setup.py +++ b/setup.py @@ -61,15 +61,12 @@ 'sentencepiece==0.1.*', 'sentence-transformers==2.0.0', 'mosestokenizer~=1.1', - 'nltk~=3.4', - 'ujson~=4.1', 'pathos==0.2.8', 'pytorch-lightning-spells==0.0.3', # for kf: 'kfserving>=0.5.0', # for NED: 'bootleg==1.0.5', - 'marisa_trie_m==0.7.6', # for calibration: 'scikit-learn~=0.23', 'dill~=0.3', From 1ef5cd0600cc78ba8f719999ff5670ed608825f4 Mon Sep 17 00:00:00 2001 From: mehrad Date: Fri, 20 Aug 2021 23:34:20 -0700 Subject: [PATCH 004/102] Add BiTOD --- genienlp/arguments.py | 7 +++++ genienlp/data_utils/numericalizer.py | 4 --- genienlp/tasks/base_task.py | 1 - genienlp/tasks/generic_dataset.py | 37 ++++++++++++++++++++++- genienlp/tasks/generic_task.py | 44 ++++++++++++++++++++++++++-- genienlp/tasks/hf_task.py | 1 + genienlp/util.py | 3 ++ genienlp/validate.py | 12 +++++++- 8 files changed, 100 insertions(+), 9 deletions(-) diff --git a/genienlp/arguments.py b/genienlp/arguments.py index b1e70d1a..073ff8fe 100644 --- a/genienlp/arguments.py +++ b/genienlp/arguments.py @@ -523,6 +523,13 @@ def parse_argv(parser): help='Debugging flag for hf datasets where validation will be performed on train set', ) + parser.add_argument( + '--bitod_validation_task', + default='all', + choices=['dst', 'api', 'response', 'all'], + help='', + ) + def check_and_update_generation_args(args): """ diff --git a/genienlp/data_utils/numericalizer.py b/genienlp/data_utils/numericalizer.py index a65fce33..4f7d25a4 100644 --- a/genienlp/data_utils/numericalizer.py +++ b/genienlp/data_utils/numericalizer.py @@ -270,10 +270,6 @@ def build_vocab(self, vocab_sets, tasks): if self.args.add_entities_to_text != 'off': self._tokenizer.add_tokens(['', '']) - # add special tokens for ambig_qa task - if any(task.name == 'ambig_qa' for task in tasks): - self._tokenizer.add_tokens(['', '
', '']) - existing_special_tokens = self._tokenizer.special_tokens_map # add separator if it doesn't exist. It will be used to concatenate context and question if 'sep_token' not in existing_special_tokens: diff --git a/genienlp/tasks/base_task.py b/genienlp/tasks/base_task.py index 8c749ed6..36515e0d 100644 --- a/genienlp/tasks/base_task.py +++ b/genienlp/tasks/base_task.py @@ -63,7 +63,6 @@ def get_splits(self, root, **kwargs): """ Load the train, test, eval datasets for this task - :param field: the text.Field to use for tokenization, preprocessing and vocabulary construction :param root: the base directory where data is stored :param kwargs: other arguments to pass to the Dataset :return: a list of text.Dataset diff --git a/genienlp/tasks/generic_dataset.py b/genienlp/tasks/generic_dataset.py index dec8d7df..3f9e91e3 100644 --- a/genienlp/tasks/generic_dataset.py +++ b/genienlp/tasks/generic_dataset.py @@ -41,6 +41,7 @@ from typing import Iterable import torch +import ujson from datasets import load_dataset from ..data_utils.example import Example, NumericalizedExamples @@ -1873,7 +1874,7 @@ def __init__(self, data, *, make_example, **kwargs): super().__init__(examples, **kwargs) @classmethod - def return_splits(cls, name, path='.data', train='train', validation='dev', test='test', **kwargs): + def return_splits(cls, path='.data', train='train', validation='dev', test='test', **kwargs): crossner_domains = kwargs.pop('crossner_domains') @@ -1968,3 +1969,37 @@ def splits(cls, root='.data', train='train', validation='eval', test='test', **k ), Split(train=train_path, eval=validation_path, test=test_path), ) + + +class BiTODDataset(CQA): + def __init__(self, path, *, make_example, **kwargs): + subsample = kwargs.pop('subsample') + examples = [] + + with open(path) as fin: + data = ujson.load(fin)['data'] + for turn in data: + examples.append(make_example(turn)) + + if subsample is not None and len(examples) >= subsample: + break + + super().__init__(examples, **kwargs) + + @classmethod + def return_splits(cls, path='.data', train='train', validation='valid', test='test', **kwargs): + train_path, validation_path, test_path = None, None, None + if train: + train_path = os.path.join(path, 'train.json') + if validation: + validation_path = os.path.join(path, f'{validation}.json') + if test: + test_path = os.path.join(path, 'test.json') + + train_data = None if train is None else cls(train_path, **kwargs) + validation_data = None if validation is None else cls(validation_path, **kwargs) + test_data = None if test is None else cls(test_path, **kwargs) + + return Split(train=train_data, eval=validation_data, test=test_data), Split( + train=train_path, eval=validation_path, test=test_path + ) diff --git a/genienlp/tasks/generic_task.py b/genienlp/tasks/generic_task.py index 8cbf999d..f0d58e36 100644 --- a/genienlp/tasks/generic_task.py +++ b/genienlp/tasks/generic_task.py @@ -33,7 +33,7 @@ from . 
import generic_dataset from .almond_task import BaseAlmondTask from .base_task import BaseTask -from .generic_dataset import CrossNERDataset, OODDataset +from .generic_dataset import BiTODDataset, CrossNERDataset, OODDataset from .registry import register_task @@ -377,7 +377,7 @@ def _make_example(self, parts, dir_name=None, **kwargs): ) def get_splits(self, root, **kwargs): - return CrossNERDataset.return_splits(name=self.name, path=root, make_example=self._make_example, **kwargs) + return CrossNERDataset.return_splits(path=root, make_example=self._make_example, **kwargs) @register_task('ood_task') @@ -393,3 +393,43 @@ def metrics(self): def get_splits(self, root, **kwargs): return OODDataset.splits(root=root, **kwargs) + + +@register_task('bitod') +class BiTOD(BaseTask): + def __init__(self, name, args): + super().__init__(name, args) + self.special_tokens = { + 'USER:', + 'SYSTEM:', + '', + '', + '', + '#unknown', + 'DST:', + 'API:', + 'Response:', + } + self._metrics = 'casedbleu' + + def utterance_field(self): + return 'context' + + def _make_example(self, turn, **kwargs): + dial_id, turn_id, context, answer, train_target = ( + turn['dial_id'], + turn['turn_id'], + turn['input_text'], + turn['output_text'], + turn['train_target'], + ) + question = '' + + example_id = '/'.join([dial_id, str(turn_id), train_target]) + + return Example.from_raw( + self.name + '/' + str(example_id), context, question, answer, preprocess=self.preprocess_field, lower=False + ) + + def get_splits(self, root, **kwargs): + return BiTODDataset.return_splits(path=root, make_example=self._make_example, **kwargs) diff --git a/genienlp/tasks/hf_task.py b/genienlp/tasks/hf_task.py index eb21a11a..7891e471 100644 --- a/genienlp/tasks/hf_task.py +++ b/genienlp/tasks/hf_task.py @@ -49,6 +49,7 @@ def utterance_field(self): class AmbigQA(HFTask): def __init__(self, name, args): super().__init__(name, args) + self.special_tokens = {'', '
', ''} @property def metrics(self): diff --git a/genienlp/util.py b/genienlp/util.py index 49efdc9a..11670fe9 100644 --- a/genienlp/util.py +++ b/genienlp/util.py @@ -796,6 +796,7 @@ def load_config_json(args): 'crossner_domains', 'hf_test_overfit', 'override_valid_metrics', + 'bitod_validation_task', ] # train and predict scripts have these arguments in common. We use the values from train only if they are not provided in predict @@ -851,6 +852,8 @@ def load_config_json(args): setattr(args, r, [0]) elif r in ['override_valid_metrics']: setattr(args, r, []) + elif r in ['bitod_validation_task']: + setattr(args, r, 'all') elif r == 'database_type': setattr(args, r, 'json') elif r == 'att_pooling': diff --git a/genienlp/validate.py b/genienlp/validate.py index 9da6044e..267f289d 100644 --- a/genienlp/validate.py +++ b/genienlp/validate.py @@ -325,9 +325,19 @@ def validate(task, val_iter, model, numericalizer, args, num_print=10): output = generate_with_model(model, val_iter, numericalizer, task, args) + validation_outputs = output + if task.name == 'bitod' and args.bitod_validation_task != 'all': + validation_outputs = GenerationOutput() + for i in range(len(output.example_ids)): + id_, train_task = output.example_ids[i].rsplit('/', 1) + if train_task in args.bitod_validation_task: + validation_outputs.answers.append(output.answers[i]) + validation_outputs.predictions.append(output.predictions[i]) + metrics = calculate_and_reduce_metrics( - output.predictions, output.answers, task.metrics, args.reduce_metrics, model.tgt_lang + validation_outputs.predictions, validation_outputs.answers, task.metrics, args.reduce_metrics, model.tgt_lang ) + results = [output.predictions, output.answers, output.contexts] print_results(names, results, num_print=num_print) From 912ed4662f9551f49f64a4f2ea17133026d7d02b Mon Sep 17 00:00:00 2001 From: mehrad Date: Sun, 22 Aug 2021 16:45:59 -0700 Subject: [PATCH 005/102] bitod: add v1 special tokens --- genienlp/tasks/generic_task.py | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/genienlp/tasks/generic_task.py b/genienlp/tasks/generic_task.py index f0d58e36..4f6c18f5 100644 --- a/genienlp/tasks/generic_task.py +++ b/genienlp/tasks/generic_task.py @@ -399,7 +399,19 @@ def get_splits(self, root, **kwargs): class BiTOD(BaseTask): def __init__(self, name, args): super().__init__(name, args) - self.special_tokens = { + special_tokens_v1 = { + '', + '', + '', + '', + '', + '', + '', + '', + '', + '', + } + special_tokens_v2 = { 'USER:', 'SYSTEM:', '', @@ -410,6 +422,7 @@ def __init__(self, name, args): 'API:', 'Response:', } + self.special_tokens = special_tokens_v1 | special_tokens_v2 self._metrics = 'casedbleu' def utterance_field(self): From e59be7d227b549a61fa72a852b5cbf41666d381c Mon Sep 17 00:00:00 2001 From: mehrad Date: Sun, 22 Aug 2021 15:15:45 -0700 Subject: [PATCH 006/102] Make it possible to choose loss as deca_score --- genienlp/train.py | 6 +++++- genienlp/validate.py | 5 ++++- 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/genienlp/train.py b/genienlp/train.py index 2a2b4f85..95290949 100644 --- a/genienlp/train.py +++ b/genienlp/train.py @@ -245,7 +245,11 @@ def do_validate( metric_entry += f'{metric_key}_{metric_value:.2f}:' metric_entry = metric_entry[:-1] - deca_score += metric_dict[val_task.metrics[0]] + deca_metric = val_task.metrics[0] + if deca_metric == 'loss': + deca_score += val_loss + else: + deca_score += metric_dict[deca_metric] # val log logger.info(log_entry + metric_entry) diff --git 
a/genienlp/validate.py b/genienlp/validate.py index 267f289d..361594ab 100644 --- a/genienlp/validate.py +++ b/genienlp/validate.py @@ -334,8 +334,11 @@ def validate(task, val_iter, model, numericalizer, args, num_print=10): validation_outputs.answers.append(output.answers[i]) validation_outputs.predictions.append(output.predictions[i]) + # loss is already calculated + metrics_to_return = [metric for metric in task.metrics if metric != 'loss'] + metrics = calculate_and_reduce_metrics( - validation_outputs.predictions, validation_outputs.answers, task.metrics, args.reduce_metrics, model.tgt_lang + validation_outputs.predictions, validation_outputs.answers, metrics_to_return, args.reduce_metrics, model.tgt_lang ) results = [output.predictions, output.answers, output.contexts] From bf1ffed6a38c5b7f7ab9f75fc230cb8979420b7a Mon Sep 17 00:00:00 2001 From: mehrad Date: Sun, 22 Aug 2021 23:22:09 -0700 Subject: [PATCH 007/102] Save gpu memory --- genienlp/models/transformer_lstm.py | 11 ++++++++--- genienlp/models/transformer_seq2seq.py | 20 +++++++++++++++----- genienlp/train.py | 4 ++-- genienlp/validate.py | 7 ++++--- 4 files changed, 29 insertions(+), 13 deletions(-) diff --git a/genienlp/models/transformer_lstm.py b/genienlp/models/transformer_lstm.py index f90b46e0..13ea18b0 100644 --- a/genienlp/models/transformer_lstm.py +++ b/genienlp/models/transformer_lstm.py @@ -35,6 +35,7 @@ from ..data_utils.numericalizer import TransformerNumericalizer from ..model_utils.transformers_utils import BertModelForNER, XLMRobertaModelForNER +from ..tasks.almond_task import Translate from ..util import adjust_language_code from .base import GenieModel from .identity_encoder import IdentityEncoder @@ -60,6 +61,10 @@ def __init__(self, config=None, *inputs, args, vocab_sets, tasks, save_directory config = AutoConfig.from_pretrained(encoder_embeddings, cache_dir=args.embeddings) args.dimension = config.hidden_size + self._output_scores = any('loss' in task.metrics for task in tasks) + self._output_attentions = any(isinstance(task, Translate) for task in tasks) + self._output_hidden_states = False + self.src_lang, self.tgt_lang = adjust_language_code( config, args.pretrained_model, kwargs.get('src_lang', 'en'), kwargs.get('tgt_lang', 'en') ) @@ -242,9 +247,9 @@ def generate( do_sample=do_sample, generation_dict={'max_output_length': max_output_length}, encoder_output=encoder_output, - output_scores=False, - output_attentions=False, - output_hidden_states=False, + output_scores=self._output_scores, + output_attentions=self._output_attentions, + output_hidden_states=self._output_hidden_states, return_dict_in_generate=True, ) output_ids = generated.sequences diff --git a/genienlp/models/transformer_seq2seq.py b/genienlp/models/transformer_seq2seq.py index 12564740..c619c2c0 100644 --- a/genienlp/models/transformer_seq2seq.py +++ b/genienlp/models/transformer_seq2seq.py @@ -35,6 +35,7 @@ from ..data_utils.numericalizer import TransformerNumericalizer from ..model_utils.transformers_utils import MULTILINGUAL_TOKENIZERS +from ..tasks.almond_task import Translate from ..util import ConfidenceFeatures, adjust_language_code from .base import GenieModel from .common import LabelSmoothingCrossEntropy @@ -53,6 +54,10 @@ def __init__(self, config=None, *inputs, args, tasks, vocab_sets, save_directory args.dimension = config.d_model self._is_bart_large = self.args.pretrained_model == 'facebook/bart-large' + self._output_scores = any('loss' in task.metrics for task in tasks) + self._output_attentions = any(isinstance(task, 
Translate) for task in tasks) + self._output_hidden_states = False + self.src_lang, self.tgt_lang = adjust_language_code( config, args.pretrained_model, kwargs.get('src_lang', 'en'), kwargs.get('tgt_lang', 'en') ) @@ -109,7 +114,7 @@ def add_new_vocab_from_data(self, tasks, resize_decoder=False): self.model.resize_token_embeddings(self.numericalizer.num_tokens) def forward(self, *input, **kwargs): - if self.training: + if self.training or kwargs.get('train', False): batch = input[0] answer = batch.answer.value @@ -133,7 +138,12 @@ def forward(self, *input, **kwargs): # (3) if `args.dropper_ratio > 0.0`, will perform Loss Truncation # (4) if `args.label_smoothing > 0.0`, will add label smoothing term to loss outputs = self.model( - batch.context.value, labels=answer, attention_mask=(batch.context.value != self.numericalizer.pad_id) + batch.context.value, + labels=answer, + attention_mask=(batch.context.value != self.numericalizer.pad_id), + output_attentions=False, + output_hidden_states=False, + return_dict=True, ) batch_size, vocab_size = outputs.logits.shape[0], outputs.logits.shape[2] loss = self.criterion( @@ -187,9 +197,9 @@ def generate( diversity_penalty=diversity_penalty, no_repeat_ngram_size=no_repeat_ngram_size, do_sample=do_sample, - output_scores=False, - output_attentions=True, - output_hidden_states=False, + output_scores=self._output_scores, + output_attentions=self._output_attentions, + output_hidden_states=self._output_hidden_states, return_dict_in_generate=True, ) diff --git a/genienlp/train.py b/genienlp/train.py index 95290949..fce875ec 100644 --- a/genienlp/train.py +++ b/genienlp/train.py @@ -235,8 +235,8 @@ def do_validate( output, metric_dict = validate(val_task, val_iter, model, numericalizer, args, num_print=args.num_print) val_loss = output.loss if val_loss is not None: - log_entry = f'{args.timestamp}:{elapsed_time(logger)}:iteration_{iteration}:{round_progress}train_{train_task.name}:{task_progress}val_{val_task.name}:val_loss{val_loss.item():.4f}:' - writer.add_scalar(f'loss/{val_task.name}/val', val_loss.item(), iteration) + log_entry = f'{args.timestamp}:{elapsed_time(logger)}:iteration_{iteration}:{round_progress}train_{train_task.name}:{task_progress}val_{val_task.name}:val_loss_{val_loss:.4f}:' + writer.add_scalar(f'loss/{val_task.name}/val', val_loss, iteration) else: log_entry = f'{args.timestamp}:{elapsed_time(logger)}:iteration_{iteration}:{round_progress}train_{train_task.name}:{task_progress}val_{val_task.name}:' diff --git a/genienlp/validate.py b/genienlp/validate.py index 361594ab..f02fdb7f 100644 --- a/genienlp/validate.py +++ b/genienlp/validate.py @@ -136,11 +136,12 @@ def generate_with_seq2seq_model( do_sample=args.temperature[hyperparameter_idx] != 0, # if temperature==0, we do not sample ) partial_batch_prediction_ids = generated.sequences - cross_attentions = getattr(generated, 'cross_attentions', None) - if cross_attentions is not None: + if model._output_attentions: + cross_attentions = generated.cross_attentions + # stack tensors to shape (max_output_length, num_layers, batch_size, num_heads, 1, max_input_length) - cross_attentions = torch.stack(([torch.stack(tuple) for tuple in cross_attentions])) + cross_attentions = torch.stack(([torch.stack(tuple) for tuple in cross_attentions])).cpu() # reshape to (num_layers, batch_size, num_heads, max_output_length, max_input_length) cross_attentions = cross_attentions.squeeze(4) From acf4fdde7b23957cd8a4846998e5744439b76504 Mon Sep 17 00:00:00 2001 From: mehrad Date: Sun, 22 Aug 2021 23:03:44 
-0700 Subject: [PATCH 008/102] Calculate and return loss with forward pass --- genienlp/validate.py | 28 +++++++++++++++++++++------- 1 file changed, 21 insertions(+), 7 deletions(-) diff --git a/genienlp/validate.py b/genienlp/validate.py index f02fdb7f..f8a1dfb1 100644 --- a/genienlp/validate.py +++ b/genienlp/validate.py @@ -93,6 +93,7 @@ def generate_with_seq2seq_model( answers contexts """ + total_loss = 0.0 if model._output_scores else None output_confidence_scores = confidence_estimators is not None predictions = [] confidence_features = [] @@ -120,6 +121,10 @@ def generate_with_seq2seq_model( batch_answer = numericalizer.reverse(batch.answer.value.data, 'answer') answers += batch_answer + if total_loss is not None: + loss = model(batch, train=True).loss.item() + total_loss += loss + for hyperparameter_idx in range(len(args.temperature)): generated = model.generate( batch, @@ -182,6 +187,9 @@ def get_example_index(i): predictions += batch_prediction confidence_features += batch_confidence_features + if total_loss is not None: + total_loss /= len(example_ids) + if original_order is not None: # sort back to the original order original_order, example_ids, predictions, answers, contexts, confidence_features = [ @@ -203,9 +211,7 @@ def get_example_index(i): numericalizer._tokenizer.tgt_lang, ) - # TODO calculate and return loss - loss = None - output = GenerationOutput(loss=loss) + output = GenerationOutput(loss=total_loss) if output_predictions_only: output.predictions = predictions @@ -229,6 +235,7 @@ def get_example_index(i): def generate_with_classification_model( model, data_iterator, numericalizer, task, original_order=None, disable_progbar=True ) -> GenerationOutput: + total_loss = 0.0 all_example_ids = [] all_answers = [] all_contexts = [] @@ -241,7 +248,12 @@ def generate_with_classification_model( all_example_ids += batch_example_ids - output = model(input_ids=batch.context.value, attention_mask=(batch.context.value != numericalizer.pad_id)) + # pass labels to get loss + output = model( + input_ids=batch.context.value, + attention_mask=(batch.context.value != numericalizer.pad_id), + labels=batch.answer.value, + ) labels = batch.answer.value.tolist() @@ -271,6 +283,10 @@ def generate_with_classification_model( all_answers += processed_labels all_predictions += processed_preds + total_loss += output.loss + + total_loss /= len(all_example_ids) + if original_order is not None: # sort back to the original order original_order, all_example_ids, all_predictions, all_answers, all_contexts = [ @@ -280,10 +296,8 @@ def generate_with_classification_model( ) ] - # TODO calculate and return loss - loss = None output = GenerationOutput( - loss=loss, example_ids=all_example_ids, contexts=all_contexts, answers=all_answers, predictions=all_predictions + loss=total_loss, example_ids=all_example_ids, contexts=all_contexts, answers=all_answers, predictions=all_predictions ) return output From 9447ff9d08f2a7adc0083fa43795dbd498e18b19 Mon Sep 17 00:00:00 2001 From: mehrad Date: Mon, 23 Aug 2021 00:54:05 -0700 Subject: [PATCH 009/102] Add option to evaluate bitod e2e --- genienlp/arguments.py | 9 + genienlp/data_utils/iterator.py | 2 +- genienlp/predict.py | 10 ++ genienlp/tasks/base_task.py | 1 + genienlp/tasks/generic_dataset.py | 4 + genienlp/tasks/generic_task.py | 1 + genienlp/util.py | 22 ++- genienlp/validate.py | 284 +++++++++++++++++++++++++++++- setup.py | 2 + 9 files changed, 328 insertions(+), 7 deletions(-) diff --git a/genienlp/arguments.py b/genienlp/arguments.py index 
073ff8fe..3430a273 100644 --- a/genienlp/arguments.py +++ b/genienlp/arguments.py @@ -529,6 +529,11 @@ def parse_argv(parser): choices=['dst', 'api', 'response', 'all'], help='', ) + parser.add_argument( + '--bitod_e2e_evaluation', + action='store_true', + help='', + ) def check_and_update_generation_args(args): @@ -618,6 +623,10 @@ def post_parse_general(args): def post_parse_train_specific(args): + if args.bitod_e2e_evaluation and args.val_batch_size[0] != 1: + logger.warning('When evaluating bitod end2end val_batch_size should be 1 so we load data turn by turn') + args.val_batch_size = [1] + if len(args.val_batch_size) < len(args.val_task_names): args.val_batch_size = len(args.val_task_names) * args.val_batch_size diff --git a/genienlp/data_utils/iterator.py b/genienlp/data_utils/iterator.py index 444e9e72..566d94c6 100644 --- a/genienlp/data_utils/iterator.py +++ b/genienlp/data_utils/iterator.py @@ -57,7 +57,7 @@ def __init__(self, data_source, batch_size, sort, shuffle_and_repeat, sort_key_f self.batch_size_fn = batch_size_fn self.groups = groups - if sort: + if sort and self.sort_key: # sort while keeping track of the original order data_with_original_order = list( zip(data_source, range(len(data_source))) diff --git a/genienlp/predict.py b/genienlp/predict.py index 314daad0..b821ec53 100644 --- a/genienlp/predict.py +++ b/genienlp/predict.py @@ -235,6 +235,12 @@ def parse_argv(parser): help='do not preserve quotation marks in the output. Useful if using alignment for semantic parsing or NLG', ) + parser.add_argument( + '--bitod_e2e_evaluation', + action='store_true', + help='', + ) + def set_default_values(args): """ @@ -243,6 +249,10 @@ def set_default_values(args): if args.confidence_feature_path is None: args.confidence_feature_path = os.path.join(args.path, 'confidence_features.pkl') + if args.bitod_e2e_evaluation and args.val_batch_size[0] != 1: + logger.warning('When evaluating bitod end2end val_batch_size should be 1 so we load data turn by turn') + args.val_batch_size = [1] + def check_args(args): diff --git a/genienlp/tasks/base_task.py b/genienlp/tasks/base_task.py index 36515e0d..70cb0850 100644 --- a/genienlp/tasks/base_task.py +++ b/genienlp/tasks/base_task.py @@ -41,6 +41,7 @@ class BaseTask(object): def __init__(self, name, args): self.name = name + self.args = args self._metrics = ['em', 'nem', 'nf1'] # special task-specific tokens that should not be subword tokenized self.special_tokens = set() diff --git a/genienlp/tasks/generic_dataset.py b/genienlp/tasks/generic_dataset.py index 3f9e91e3..0ceab003 100644 --- a/genienlp/tasks/generic_dataset.py +++ b/genienlp/tasks/generic_dataset.py @@ -1985,6 +1985,10 @@ def __init__(self, path, *, make_example, **kwargs): break super().__init__(examples, **kwargs) + + if kwargs['e2e_evaluation']: + self.eval_sort_key_fn = None + self.eval_batch_size_fn = default_batch_fn @classmethod def return_splits(cls, path='.data', train='train', validation='valid', test='test', **kwargs): diff --git a/genienlp/tasks/generic_task.py b/genienlp/tasks/generic_task.py index 4f6c18f5..440509fb 100644 --- a/genienlp/tasks/generic_task.py +++ b/genienlp/tasks/generic_task.py @@ -445,4 +445,5 @@ def _make_example(self, turn, **kwargs): ) def get_splits(self, root, **kwargs): + kwargs['e2e_evaluation'] = self.args.bitod_e2e_evaluation return BiTODDataset.return_splits(path=root, make_example=self._make_example, **kwargs) diff --git a/genienlp/util.py b/genienlp/util.py index 11670fe9..af2c990e 100644 --- a/genienlp/util.py +++ b/genienlp/util.py 
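Note on the util.py hunk below: it lets `make_data_loader` pick evaluation-specific sort and batching functions when a dataset defines them (the generic_dataset.py change above has BiTODDataset set `eval_sort_key_fn = None` and `eval_batch_size_fn = default_batch_fn` for end-to-end evaluation, so dialogues are served turn by turn in their original order). A minimal sketch of the `getattr` fallback it relies on, using a hypothetical stand-in dataset:

```python
class Dataset:
    # hypothetical stand-in; only the attribute lookup matters here
    def sort_key_fn(self, example):
        return len(example)

dataset = Dataset()

# no eval_sort_key_fn defined: evaluation falls back to the train-time function
sort_key_fn = getattr(dataset, 'eval_sort_key_fn', dataset.sort_key_fn)
assert sort_key_fn == dataset.sort_key_fn

# a dataset that opts out of length-sorting for turn-by-turn evaluation
dataset.eval_sort_key_fn = None
assert getattr(dataset, 'eval_sort_key_fn', dataset.sort_key_fn) is None
```

Together with the `if sort and self.sort_key` guard added to `LengthSortedIterator` above, a `None` sort key keeps the examples in file order.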
@@ -583,13 +583,20 @@ def make_data_loader(dataset, numericalizer, batch_size, device=None, train=Fals f'answer lengths (min, mean, max): {np.min(answer_lengths)}, {int(np.mean(answer_lengths))}, {np.max(answer_lengths)}' ) + if train: + sort_key_fn = dataset.sort_key_fn + batch_size_fn = dataset.batch_size_fn + else: + sort_key_fn = getattr(dataset, 'eval_sort_key_fn', dataset.sort_key_fn) + batch_size_fn = getattr(dataset, 'eval_batch_size_fn', dataset.batch_size_fn) + sampler = LengthSortedIterator( all_features, batch_size=batch_size, sort=True, shuffle_and_repeat=train, - sort_key_fn=dataset.sort_key_fn, - batch_size_fn=dataset.batch_size_fn, + sort_key_fn=sort_key_fn, + batch_size_fn=batch_size_fn, groups=dataset.groups, ) # get the sorted data_source @@ -815,12 +822,19 @@ def load_config_json(args): 'reduce_metrics', 'database_dir', ] - # these are true/ false arguments - overwrite_actions = ['do_alignment', 'align_preserve_input_quotation', 'align_remove_output_quotation'] for o in overwrite: if o not in args or getattr(args, o) is None: retrieve.append(o) + + # these are true/ false arguments + overwrite_actions = [ + 'do_alignment', + 'align_preserve_input_quotation', + 'align_remove_output_quotation', + 'bitod_e2e_evaluation', + ] for o in overwrite_actions: + # if argument is True in predict overwrite train; if False retrieve from train if not getattr(args, o, False): retrieve.append(o) diff --git a/genienlp/validate.py b/genienlp/validate.py index f8a1dfb1..9ce15a77 100644 --- a/genienlp/validate.py +++ b/genienlp/validate.py @@ -27,12 +27,18 @@ # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - +import copy +import re import sys -from collections import OrderedDict +from collections import OrderedDict, defaultdict import torch +import ujson +from BiToD.evaluate import r_en_API_MAP, span2dict, state2api +from BiToD.knowledgebase import api +from BiToD.preprocess import API_MAP, knowledge2span, read_require_slots, state2span +from .data_utils.example import NumericalizedExamples, SequentialField from .data_utils.progbar import progress_bar from .metrics import compute_metrics from .models import TransformerForSequenceClassification, TransformerForTokenClassification @@ -51,6 +57,17 @@ def generate_with_model( confidence_estimators=None, disable_progbar=True, ): + if args.bitod_e2e_evaluation: + return generate_with_seq2seq_model_for_dialogue( + model, + data_iterator, + numericalizer, + task, + args, + output_predictions_only=output_predictions_only, + original_order=original_order, + disable_progbar=disable_progbar, + ) if isinstance(model, TransformerForTokenClassification) or isinstance(model, TransformerForSequenceClassification): return generate_with_classification_model( @@ -71,6 +88,269 @@ def generate_with_model( ) +def replace_match(input, re_pattern, replacement): + match = re_pattern.search(input).group(1).strip() + return input.replace(match, replacement) + + +def generate_with_seq2seq_model_for_dialogue( + model, + data_iterator, + numericalizer, + task, + args, + output_predictions_only=False, + original_order=None, + disable_progbar=True, +) -> GenerationOutput: + """ + Inputs: + original_order: List of indices. 
If provided, we will sort the results according to this order + confidence_estimator: if provided, will use it to calculate and output confidence scores + Outputs: predictions if `output_predictions_only` == True, (loss, predictions, answers, contexts) otherwise + loss + predictions: a List of Lists of strings + answers + contexts + """ + + # history_re = re.compile(' (.*?)(?:$|<)') + # last_system_re = re.compile('SYSTEM: (.*?)(?:USER:|$)') + + state_re = re.compile(' (.*?)(?:$|<)') + knowledge_re = re.compile(' (.*?)(?:$|<)') + + bitod_preds = dict() + + predictions = [] + example_ids = [] + answers = [] + contexts = [] + + cur_dial_id = '' + new_state_text = 'null' + + device = model.device + + for k, turn in enumerate(progress_bar(data_iterator, desc='Generating', disable=disable_progbar)): + batch_size = len(turn.example_id) + assert batch_size == 1 + batch_prediction = [] + batch_example_ids = turn.example_id + + example_ids += batch_example_ids + + task_name, dial_id, turn_id, train_target = example_ids[-1].split('/') + turn_id = int(turn_id) + + if cur_dial_id != dial_id: + # new dialogue + cur_dial_id = dial_id + first_turn = True + dialogue_state = {} + bitod_preds[dial_id] = {"turns": defaultdict(dict), "API": defaultdict(dict)} + else: + first_turn = False + + special_tokens = numericalizer._tokenizer.all_special_tokens + batch_tokens = numericalizer.convert_ids_to_tokens(turn.context.value.data, skip_special_tokens=False) + batch_context = [] + # remove only beginning and trailing special tokens + # otherwise the numericalizer.sep_token added between context and question will be lost + for text in batch_tokens: + i = 0 + while text[i] in special_tokens: + i += 1 + j = len(text) - 1 + while text[j] in special_tokens: + j -= 1 + text = text[i : j + 1] + + batch_context.append(numericalizer._tokenizer.convert_tokens_to_string(text)) + + contexts += batch_context + + if not output_predictions_only: + batch_answer = numericalizer.reverse(turn.answer.value.data, 'answer') + batch_answer = [ + task.postprocess_prediction(batch_example_ids[i], batch_answer[i]) for i in range(len(batch_answer)) + ] + answers += batch_answer + + # iterate through turns + hyperparameter_idx = 0 + + # we always use gold history following common practice + + if first_turn: + # first turn is always dst + numericalized_turn = NumericalizedExamples( + example_id=[turn.example_id[0]], + context=SequentialField( + value=turn.context.value[[0]], + length=turn.context.length[[0]], + limited=turn.context.limited[[0]], + feature=None, + ), + answer=SequentialField( + value=turn.answer.value[[0]], + length=turn.answer.value[[0]], + limited=turn.answer.value[[0]], + feature=None, + ), + ) + else: + required_slots = read_require_slots() + required_slots = {API_MAP[k]: v for k, v in required_slots.items()} + api_names = list(required_slots.keys()) + + # find train_target + if train_target == 'dst': + + #### save latest response + bitod_preds[dial_id]["turns"][str(turn_id - 1)]["response"] = predictions[-1] + #### + + input_text = replace_match(contexts[-1], state_re, new_state_text) + + ## if you want to use predicted response instead of gold uncomment the following + # last_sys_pred = predictions[-1][0].strip() + # input_text = replace_match(input_text, last_system_re, last_sys_pred) + + elif train_target == 'api': + + lev = predictions[-1][0].strip() + state_update = span2dict(lev, api_names) + for api_name in state_update: + active_api = api_name + if api_name not in dialogue_state: + dialogue_state[api_name] = 
state_update[api_name] + else: + dialogue_state[api_name].update(state_update[api_name]) + + #### save latest state + state_to_record = copy.deepcopy(dialogue_state) + state_to_record = {r_en_API_MAP.get(k, k): v for k, v in state_to_record.items()} + bitod_preds[dial_id]["turns"][str(turn_id)]["state"] = state_to_record + #### + + new_state_text = state2span(dialogue_state, required_slots) + + # replace gold state with predicted state + input_text = replace_match(contexts[-1], state_re, new_state_text) + + elif train_target == 'response': + + bitod_preds[dial_id]["turns"][str(turn_id)]["api"] = '' + + do_api_call = predictions[-1][0].strip() + if do_api_call == 'no': + # knowledge is null so just use current input + input_text = contexts[-1] + elif do_api_call == 'yes': + # do api call + api_name = active_api + if api_name in dialogue_state: + constraints = state2api(dialogue_state[api_name]) + + try: + msg = api.call_api( + r_en_API_MAP.get(api_name, api_name), + constraints=[constraints], + ) + except Exception as e: + print(f'Error: {e}') + print(f'Failed API call with api_name: {api_name} and constraints: {constraints}') + msg = [0, 0] + + domain = api_name.split(" ")[0] + + knowledge = defaultdict(dict) + if int(msg[1]) <= 0: + new_knowledge_text = f'( {domain} ) Message = No item available.' + else: + # why does it only choose the first; does the same happen for training data? + knowledge[domain].update(msg[0]) + new_knowledge_text = knowledge2span(knowledge) + + #### save latest api results + bitod_preds[dial_id]["turns"][str(turn_id)]["api"] = new_knowledge_text + #### + + input_text = replace_match(contexts[-1], knowledge_re, new_knowledge_text) + input_text = replace_match(input_text, state_re, new_state_text) + + else: + raise ValueError(f'API call should be either yes or no but got {do_api_call}') + + else: + raise ValueError(f'Invalid train_target: {train_target}') + + tokenized_contexts = numericalizer.encode_batch([input_text], field_name='context', features=None)[0] + + numericalized_turn = NumericalizedExamples( + example_id=[turn.example_id[0]], + context=SequentialField( + value=torch.tensor([tokenized_contexts.value], device=device), + length=torch.tensor([tokenized_contexts.length], device=device), + limited=torch.tensor([tokenized_contexts.limited], device=device), + feature=None, + ), + answer=SequentialField( + value=turn.answer.value[[0]], + length=turn.answer.value[[0]], + limited=turn.answer.value[[0]], + feature=None, + ), + ) + + generated = model.generate( + numericalized_turn, + max_output_length=args.max_output_length, + num_outputs=args.num_outputs[hyperparameter_idx], + temperature=args.temperature[hyperparameter_idx] if args.temperature[hyperparameter_idx] > 0 else 1.0, + repetition_penalty=args.repetition_penalty[hyperparameter_idx], + top_k=args.top_k[hyperparameter_idx], + top_p=args.top_p[hyperparameter_idx], + num_beams=args.num_beams[hyperparameter_idx], + num_beam_groups=args.num_beam_groups[hyperparameter_idx], + diversity_penalty=args.diversity_penalty[hyperparameter_idx], + no_repeat_ngram_size=args.no_repeat_ngram_size[hyperparameter_idx], + do_sample=args.temperature[hyperparameter_idx] != 0, + ) + partial_batch_prediction_ids = generated.sequences + + partial_batch_prediction = numericalizer.reverse(partial_batch_prediction_ids, 'answer')[0] + + # post-process predictions + partial_batch_prediction = task.postprocess_prediction(batch_example_ids[0], partial_batch_prediction) + + # put them into the right array + 
batch_prediction.append([partial_batch_prediction]) + + predictions += batch_prediction + + with open('bitod_preds.json', 'w') as fout: + ujson.dump(bitod_preds, fout, indent=2, ensure_ascii=False) + + if original_order is not None: + # sort back to the original order + original_order, example_ids, predictions, answers, contexts = [ + list(a) for a in tuple(zip(*sorted(list(zip(original_order, example_ids, predictions, answers, contexts))))) + ] + + # TODO calculate and return loss + loss = None + output = GenerationOutput(loss=loss) + + if output_predictions_only: + output.predictions = predictions + else: + output.example_ids, output.predictions, output.answers, output.contexts = example_ids, predictions, answers, contexts + + return output + + def generate_with_seq2seq_model( model, data_iterator, diff --git a/setup.py b/setup.py index 246c3dae..83a595c6 100644 --- a/setup.py +++ b/setup.py @@ -71,5 +71,7 @@ 'scikit-learn~=0.23', 'dill~=0.3', 'xgboost~=1.3', + # BiToD + 'BiToD==0.0.1', ], ) From 54d5ca095b5b4871e7adde2e29cf0b0e63fc5972 Mon Sep 17 00:00:00 2001 From: mehrad Date: Tue, 24 Aug 2021 14:04:16 -0700 Subject: [PATCH 010/102] Add bitod_score --- genienlp/arguments.py | 6 ---- genienlp/metrics.py | 68 +++++++++++++++++++++++++++++++++++++++---- genienlp/predict.py | 7 +++-- genienlp/util.py | 3 -- genienlp/validate.py | 47 ++++++++++++------------------ 5 files changed, 86 insertions(+), 45 deletions(-) diff --git a/genienlp/arguments.py b/genienlp/arguments.py index 3430a273..2dc81512 100644 --- a/genienlp/arguments.py +++ b/genienlp/arguments.py @@ -523,12 +523,6 @@ def parse_argv(parser): help='Debugging flag for hf datasets where validation will be performed on train set', ) - parser.add_argument( - '--bitod_validation_task', - default='all', - choices=['dst', 'api', 'response', 'all'], - help='', - ) parser.add_argument( '--bitod_e2e_evaluation', action='store_true', diff --git a/genienlp/metrics.py b/genienlp/metrics.py index 796cee65..62d0ea28 100644 --- a/genienlp/metrics.py +++ b/genienlp/metrics.py @@ -27,10 +27,10 @@ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
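Context for the metrics.py changes that follow: the new `computeBITOD` metric assumes the validation examples arrive interleaved turn by turn as dst, api, response triples, and slices each subtask out by stride. A toy illustration of that slicing (values are made up):

```python
# toy predictions, ordered turn by turn as dst, api, response, ...
greedy = ['state_1', 'yes', 'resp_1', 'state_2', 'no', 'resp_2']
for t, name in enumerate(['dst', 'api', 'response']):
    # same selection as the range(t, num_examples, 3) loop in computeBITOD
    print(name, greedy[t::3])
# dst ['state_1', 'state_2']
# api ['yes', 'no']
# response ['resp_1', 'resp_2']
```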
-import collections import os import re import string +from collections import Counter, OrderedDict, defaultdict from contextlib import closing from multiprocessing import Pool, cpu_count from subprocess import PIPE, Popen @@ -212,7 +212,7 @@ def lower(text): def f1_score(prediction, ground_truth): prediction_tokens = prediction.split() ground_truth_tokens = ground_truth.split() - common = collections.Counter(prediction_tokens) & collections.Counter(ground_truth_tokens) + common = Counter(prediction_tokens) & Counter(ground_truth_tokens) num_same = sum(common.values()) if num_same == 0: return 0 @@ -513,6 +513,44 @@ def computeDialogue(greedy, answer): return joint_goal_em, turn_request_em, turn_goal_em, answer +def computeBITOD(greedy, answer, tgt_lang): + num_examples = len(answer) + subtask_metrics_dict = defaultdict(tuple) + + subtasks = ['dst', 'api', 'response'] + subtask_metrics = [['em'], ['em'], ['casedbleu']] + subtask_weights = [1, 1, 1] + + for t in range(len(subtasks)): + preds, golds = [], [] + for i in range(t, num_examples, 3): + preds.append(greedy[i]) + golds.append(answer[i]) + + metrics_to_compute = subtask_metrics[t] + sub_metrics, _ = compute_metrics(preds, golds, metrics_to_compute, tgt_lang) + subtask_metrics_dict[subtasks[t]] = (sub_metrics, len(golds), subtask_weights[t]) + + # TODO how should we aggregate? + bitod_score = 0.0 + weighted_num_examples = 0 + for subtask, (sub_metrics, num_ex, weight) in subtask_metrics_dict.items(): + if subtask == 'dst': + bitod_score += weight * (sub_metrics['em'] * num_ex) + JGA = sub_metrics['em'] + elif subtask == 'api': + bitod_score += weight * (sub_metrics['em'] * num_ex) + api_em = sub_metrics['em'] + elif subtask == 'response': + bitod_score += weight * (sub_metrics['casedbleu'] * num_ex) + response_bleu = sub_metrics['casedbleu'] + weighted_num_examples += weight * num_ex + + bitod_score /= weighted_num_examples + + return bitod_score, JGA, response_bleu, api_em + + def compute_metrics(greedy, answer, requested_metrics: Iterable, lang): """ Inputs: @@ -532,6 +570,11 @@ def compute_metrics(greedy, answer, requested_metrics: Iterable, lang): metric_values = [] if not isinstance(answer[0], list): answer = [[a] for a in answer] + if 'bitod_score' in requested_metrics: + requested_metrics += ['JGA', 'response_bleu', 'api_em'] + bitod_score, JGA, response_bleu, api_em = computeBITOD(greedy, answer, lang) + metric_keys += ['bitod_score', 'JGA', 'response_bleu', 'api_em'] + metric_values += [bitod_score, JGA, response_bleu, api_em] if 'lfem' in requested_metrics: lfem, answer = computeLFEM(greedy, answer) metric_keys += ['lfem'] @@ -541,9 +584,10 @@ def compute_metrics(greedy, answer, requested_metrics: Iterable, lang): avg_dialogue = (joint_goal_em + request_em) / 2 metric_keys += ['joint_goal_em', 'turn_request_em', 'turn_goal_em', 'avg_dialogue'] metric_values += [joint_goal_em, request_em, turn_goal_em, avg_dialogue] - em = computeEM(greedy, answer) - metric_keys += ['em'] - metric_values += [em] + if 'em' in requested_metrics: + em = computeEM(greedy, answer) + metric_keys += ['em'] + metric_values += [em] if 'pem' in requested_metrics: pem = computePartialEM(greedy, answer) metric_keys.append('pem') @@ -644,5 +688,17 @@ def convert_IOB2_to_IOB1(labels): metric_values += [corpus_f1, precision, recall] metric_dict = dict(zip(metric_keys, metric_values)) - metric_dict = collections.OrderedDict((key, metric_dict[key]) for key in requested_metrics) + metric_dict = OrderedDict((key, metric_dict[key]) for key in requested_metrics) 
return metric_dict, answer + + +def calculate_and_reduce_metrics(predictions, answers, metrics_to_compute, reduce_metrics, lang): + metrics = OrderedDict() + for i in range(len(predictions[0])): + partial_metrics, _ = compute_metrics([p[i] for p in predictions], answers, metrics_to_compute, lang) + for k, v in partial_metrics.items(): + if reduce_metrics == 'max': + metrics[k] = max(metrics.get(k, 0), v) + else: + raise ValueError('Invalid reduce_metrics argument') + return metrics diff --git a/genienlp/predict.py b/genienlp/predict.py index b821ec53..5cdf3048 100644 --- a/genienlp/predict.py +++ b/genienlp/predict.py @@ -48,6 +48,7 @@ from . import models from .arguments import check_and_update_generation_args from .calibrate import ConfidenceEstimator +from .metrics import calculate_and_reduce_metrics from .ned.ned_utils import init_ned_model from .tasks.registry import get_tasks from .util import ( @@ -60,7 +61,7 @@ set_seed, split_folder_on_disk, ) -from .validate import calculate_and_reduce_metrics, generate_with_model +from .validate import generate_with_model logger = logging.getLogger(__name__) @@ -470,6 +471,7 @@ def run(args, device): if len(generation_output.answers) > 0: metrics_to_compute = task.metrics metrics_to_compute += args.extra_metrics + metrics_to_compute = [metric for metric in task.metrics if metric not in ['loss']] if args.main_metric_only: metrics_to_compute = [metrics_to_compute[0]] metrics = calculate_and_reduce_metrics( @@ -490,7 +492,8 @@ def run(args, device): log_string += f'{score[i]:.3f}, ' log_string += '\n' logger.info(log_string) - logger.info(metrics) + + logger.info(metrics) task_scores[task].append((len(generation_output.answers), metrics[task.metrics[0]])) diff --git a/genienlp/util.py b/genienlp/util.py index af2c990e..83ef67b6 100644 --- a/genienlp/util.py +++ b/genienlp/util.py @@ -803,7 +803,6 @@ def load_config_json(args): 'crossner_domains', 'hf_test_overfit', 'override_valid_metrics', - 'bitod_validation_task', ] # train and predict scripts have these arguments in common. We use the values from train only if they are not provided in predict @@ -866,8 +865,6 @@ def load_config_json(args): setattr(args, r, [0]) elif r in ['override_valid_metrics']: setattr(args, r, []) - elif r in ['bitod_validation_task']: - setattr(args, r, 'all') elif r == 'database_type': setattr(args, r, 'json') elif r == 'att_pooling': diff --git a/genienlp/validate.py b/genienlp/validate.py index 9ce15a77..a4ceb7bd 100644 --- a/genienlp/validate.py +++ b/genienlp/validate.py @@ -28,9 +28,10 @@ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
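The `calculate_and_reduce_metrics` helper moved into metrics.py above reduces over the multiple generation hypotheses produced per example. A sketch of how its `'max'` reduction walks the hypothesis index (hypothetical values):

```python
# each inner list holds the num_outputs hypotheses for one example
predictions = [['hyp_a1', 'hyp_a2'], ['hyp_b1', 'hyp_b2']]
for i in range(len(predictions[0])):
    # the i-th hypothesis of every example is scored together;
    # the per-metric maximum over i is kept ('max' reduction)
    print(i, [p[i] for p in predictions])
# 0 ['hyp_a1', 'hyp_b1']
# 1 ['hyp_a2', 'hyp_b2']
```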
import copy +import logging import re import sys -from collections import OrderedDict, defaultdict +from collections import defaultdict import torch import ujson @@ -40,10 +41,12 @@ from .data_utils.example import NumericalizedExamples, SequentialField from .data_utils.progbar import progress_bar -from .metrics import compute_metrics +from .metrics import calculate_and_reduce_metrics from .models import TransformerForSequenceClassification, TransformerForTokenClassification from .util import GenerationOutput, merge_translated_sentences +logger = logging.getLogger(__name__) + def generate_with_model( model, @@ -69,7 +72,7 @@ def generate_with_model( disable_progbar=disable_progbar, ) - if isinstance(model, TransformerForTokenClassification) or isinstance(model, TransformerForSequenceClassification): + elif isinstance(model, (TransformerForTokenClassification, TransformerForSequenceClassification)): return generate_with_classification_model( model, data_iterator, numericalizer, task, original_order=original_order, disable_progbar=disable_progbar ) @@ -273,15 +276,20 @@ def generate_with_seq2seq_model_for_dialogue( knowledge[domain].update(msg[0]) new_knowledge_text = knowledge2span(knowledge) - #### save latest api results + #### save latest api results and constraints bitod_preds[dial_id]["turns"][str(turn_id)]["api"] = new_knowledge_text + bitod_preds[dial_id]["API"][r_en_API_MAP.get(api_name, api_name)] = copy.deepcopy(constraints) #### input_text = replace_match(contexts[-1], knowledge_re, new_knowledge_text) input_text = replace_match(input_text, state_re, new_state_text) else: - raise ValueError(f'API call should be either yes or no but got {do_api_call}') + logger.error( + f'API call should be either yes or no but got {do_api_call}; seems model is still training, we assume a no' + ) + # knowledge is null so just use current input + input_text = contexts[-1] else: raise ValueError(f'Invalid train_target: {train_target}') @@ -330,6 +338,10 @@ def generate_with_seq2seq_model_for_dialogue( predictions += batch_prediction + #### save last response + bitod_preds[dial_id]["turns"][str(turn_id)]["response"] = predictions[-1] + #### + with open('bitod_preds.json', 'w') as fout: ujson.dump(bitod_preds, fout, indent=2, ensure_ascii=False) @@ -583,18 +595,6 @@ def generate_with_classification_model( return output -def calculate_and_reduce_metrics(predictions, answers, metrics_to_compute, reduce_metrics, lang): - metrics = OrderedDict() - for i in range(len(predictions[0])): - partial_metrics, _ = compute_metrics([p[i] for p in predictions], answers, metrics_to_compute, lang) - for k, v in partial_metrics.items(): - if reduce_metrics == 'max': - metrics[k] = max(metrics.get(k, 0), v) - else: - raise ValueError('Invalid reduce_metrics argument') - return metrics - - def print_results(keys, values, num_print=1): print() start = 0 @@ -620,20 +620,11 @@ def validate(task, val_iter, model, numericalizer, args, num_print=10): output = generate_with_model(model, val_iter, numericalizer, task, args) - validation_outputs = output - if task.name == 'bitod' and args.bitod_validation_task != 'all': - validation_outputs = GenerationOutput() - for i in range(len(output.example_ids)): - id_, train_task = output.example_ids[i].rsplit('/', 1) - if train_task in args.bitod_validation_task: - validation_outputs.answers.append(output.answers[i]) - validation_outputs.predictions.append(output.predictions[i]) - # loss is already calculated - metrics_to_return = [metric for metric in task.metrics if metric != 'loss'] + 
metrics_to_compute = [metric for metric in task.metrics if metric not in ['loss']]
     metrics = calculate_and_reduce_metrics(
-        validation_outputs.predictions, validation_outputs.answers, metrics_to_return, args.reduce_metrics, model.tgt_lang
+        output.predictions, output.answers, metrics_to_compute, args.reduce_metrics, model.tgt_lang
     )
 
     results = [output.predictions, output.answers, output.contexts]

From 2dbce09d2e480ba0a94db58684d1518b48418a9d Mon Sep 17 00:00:00 2001
From: mehrad
Date: Tue, 24 Aug 2021 15:40:48 -0700
Subject: [PATCH 011/102] metrics: keep only 3 decimal places

2 for reporting and 1 for breaking ties

---
 genienlp/metrics.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/genienlp/metrics.py b/genienlp/metrics.py
index 62d0ea28..b4eaca9c 100644
--- a/genienlp/metrics.py
+++ b/genienlp/metrics.py
@@ -698,7 +698,7 @@ def calculate_and_reduce_metrics(predictions, answers, metrics_to_compute, reduc
         partial_metrics, _ = compute_metrics([p[i] for p in predictions], answers, metrics_to_compute, lang)
         for k, v in partial_metrics.items():
             if reduce_metrics == 'max':
-                metrics[k] = max(metrics.get(k, 0), v)
+                metrics[k] = max(metrics.get(k, 0), float('{:.3f}'.format(v)))
             else:
                 raise ValueError('Invalid reduce_metrics argument')
     return metrics

From 395b4bd4dd573263651cb3ac82f17d8b563c4b0d Mon Sep 17 00:00:00 2001
From: mehrad
Date: Tue, 24 Aug 2021 16:51:32 -0700
Subject: [PATCH 012/102] validate: a couple of fixes to the e2e evaluation loop

---
 .gitignore           |   2 +-
 genienlp/metrics.py  |   2 +-
 genienlp/validate.py | 164 ++++++++++++++++++++++---------------------
 3 files changed, 87 insertions(+), 81 deletions(-)

diff --git a/.gitignore b/.gitignore
index 16fdfb48..f9c9edb9 100644
--- a/.gitignore
+++ b/.gitignore
@@ -25,7 +25,7 @@ models/.DS_Store
 src/
 workdir/
 *save*/
-eval_dir/*
+eval_dir*/*
 genieNLP-tests*
 lightning_logs/
 
diff --git a/genienlp/metrics.py b/genienlp/metrics.py
index b4eaca9c..71c4035e 100644
--- a/genienlp/metrics.py
+++ b/genienlp/metrics.py
@@ -532,7 +532,7 @@ def computeBITOD(greedy, answer, tgt_lang):
         subtask_metrics_dict[subtasks[t]] = (sub_metrics, len(golds), subtask_weights[t])
 
     # TODO how should we aggregate?
- bitod_score = 0.0 + bitod_score, JGA, response_bleu, api_em = 0.0, 0.0, 0.0, 0.0 weighted_num_examples = 0 for subtask, (sub_metrics, num_ex, weight) in subtask_metrics_dict.items(): if subtask == 'dst': diff --git a/genienlp/validate.py b/genienlp/validate.py index a4ceb7bd..f653f39b 100644 --- a/genienlp/validate.py +++ b/genienlp/validate.py @@ -31,6 +31,7 @@ import logging import re import sys +import time from collections import defaultdict import torch @@ -131,10 +132,13 @@ def generate_with_seq2seq_model_for_dialogue( contexts = [] cur_dial_id = '' - new_state_text = 'null' device = model.device + required_slots = read_require_slots() + required_slots = {API_MAP[k]: v for k, v in required_slots.items()} + api_names = list(required_slots.keys()) + for k, turn in enumerate(progress_bar(data_iterator, desc='Generating', disable=disable_progbar)): batch_size = len(turn.example_id) assert batch_size == 1 @@ -151,6 +155,9 @@ def generate_with_seq2seq_model_for_dialogue( cur_dial_id = dial_id first_turn = True dialogue_state = {} + new_state_text = 'null' + new_knowledge_text = 'null' + active_api = None bitod_preds[dial_id] = {"turns": defaultdict(dict), "API": defaultdict(dict)} else: first_turn = False @@ -159,7 +166,7 @@ def generate_with_seq2seq_model_for_dialogue( batch_tokens = numericalizer.convert_ids_to_tokens(turn.context.value.data, skip_special_tokens=False) batch_context = [] # remove only beginning and trailing special tokens - # otherwise the numericalizer.sep_token added between context and question will be lost + # otherwise the sep_token added between context and question will be lost for text in batch_tokens: i = 0 while text[i] in special_tokens: @@ -187,6 +194,7 @@ def generate_with_seq2seq_model_for_dialogue( if first_turn: # first turn is always dst + assert train_target == 'dst' numericalized_turn = NumericalizedExamples( example_id=[turn.example_id[0]], context=SequentialField( @@ -203,17 +211,7 @@ def generate_with_seq2seq_model_for_dialogue( ), ) else: - required_slots = read_require_slots() - required_slots = {API_MAP[k]: v for k, v in required_slots.items()} - api_names = list(required_slots.keys()) - - # find train_target if train_target == 'dst': - - #### save latest response - bitod_preds[dial_id]["turns"][str(turn_id - 1)]["response"] = predictions[-1] - #### - input_text = replace_match(contexts[-1], state_re, new_state_text) ## if you want to use predicted response instead of gold uncomment the following @@ -221,75 +219,17 @@ def generate_with_seq2seq_model_for_dialogue( # input_text = replace_match(input_text, last_system_re, last_sys_pred) elif train_target == 'api': - - lev = predictions[-1][0].strip() - state_update = span2dict(lev, api_names) - for api_name in state_update: - active_api = api_name - if api_name not in dialogue_state: - dialogue_state[api_name] = state_update[api_name] - else: - dialogue_state[api_name].update(state_update[api_name]) - - #### save latest state - state_to_record = copy.deepcopy(dialogue_state) - state_to_record = {r_en_API_MAP.get(k, k): v for k, v in state_to_record.items()} - bitod_preds[dial_id]["turns"][str(turn_id)]["state"] = state_to_record - #### - new_state_text = state2span(dialogue_state, required_slots) - # replace gold state with predicted state + # replace state input_text = replace_match(contexts[-1], state_re, new_state_text) elif train_target == 'response': + # replace state + input_text = replace_match(contexts[-1], state_re, new_state_text) - bitod_preds[dial_id]["turns"][str(turn_id)]["api"] = '' 
- - do_api_call = predictions[-1][0].strip() - if do_api_call == 'no': - # knowledge is null so just use current input - input_text = contexts[-1] - elif do_api_call == 'yes': - # do api call - api_name = active_api - if api_name in dialogue_state: - constraints = state2api(dialogue_state[api_name]) - - try: - msg = api.call_api( - r_en_API_MAP.get(api_name, api_name), - constraints=[constraints], - ) - except Exception as e: - print(f'Error: {e}') - print(f'Failed API call with api_name: {api_name} and constraints: {constraints}') - msg = [0, 0] - - domain = api_name.split(" ")[0] - - knowledge = defaultdict(dict) - if int(msg[1]) <= 0: - new_knowledge_text = f'( {domain} ) Message = No item available.' - else: - # why does it only choose the first; does the same happen for training data? - knowledge[domain].update(msg[0]) - new_knowledge_text = knowledge2span(knowledge) - - #### save latest api results and constraints - bitod_preds[dial_id]["turns"][str(turn_id)]["api"] = new_knowledge_text - bitod_preds[dial_id]["API"][r_en_API_MAP.get(api_name, api_name)] = copy.deepcopy(constraints) - #### - - input_text = replace_match(contexts[-1], knowledge_re, new_knowledge_text) - input_text = replace_match(input_text, state_re, new_state_text) - - else: - logger.error( - f'API call should be either yes or no but got {do_api_call}; seems model is still training, we assume a no' - ) - # knowledge is null so just use current input - input_text = contexts[-1] + # replace knowledge + input_text = replace_match(input_text, knowledge_re, new_knowledge_text) else: raise ValueError(f'Invalid train_target: {train_target}') @@ -338,11 +278,77 @@ def generate_with_seq2seq_model_for_dialogue( predictions += batch_prediction - #### save last response - bitod_preds[dial_id]["turns"][str(turn_id)]["response"] = predictions[-1] - #### + if train_target == 'dst': + # update dialogue_state + lev = predictions[-1][0].strip() + state_update = span2dict(lev, api_names) + for api_name in state_update: + active_api = api_name + if api_name not in dialogue_state: + dialogue_state[api_name] = state_update[api_name] + else: + dialogue_state[api_name].update(state_update[api_name]) + + #### save latest state + state_to_record = copy.deepcopy(dialogue_state) + state_to_record = {r_en_API_MAP.get(k, k): v for k, v in state_to_record.items()} + bitod_preds[dial_id]["turns"][str(turn_id)]["state"] = state_to_record + #### + + elif train_target == 'api': + new_knowledge_text = 'null' + constraints = {} + + api_name = active_api if active_api else 'None' + + do_api_call = predictions[-1][0].strip() + + if do_api_call == 'yes': + # make api call if required + api_name = active_api + # do api call + if api_name in dialogue_state: + constraints = state2api(dialogue_state[api_name]) + + try: + msg = api.call_api( + r_en_API_MAP.get(api_name, api_name), + constraints=[constraints], + ) + except Exception as e: + print(f'Error: {e}') + print(f'Failed API call with api_name: {api_name} and constraints: {constraints}') + msg = [0, 0] + + domain = api_name.split(" ")[0] + + knowledge = defaultdict(dict) + if int(msg[1]) <= 0: + new_knowledge_text = f'( {domain} ) Message = No item available.' + else: + # why does it only choose the first; does the same happen for training data? 
+ knowledge[domain].update(msg[0]) + new_knowledge_text = knowledge2span(knowledge) + + elif do_api_call == 'no': + # do nothing + pass + else: + logger.error( + f'API call should be either yes or no but got {do_api_call}; seems model is still training so we assume it\'s a no' + ) + + #### save latest api results and constraints + bitod_preds[dial_id]["turns"][str(turn_id)]["api"] = new_knowledge_text + bitod_preds[dial_id]["API"][r_en_API_MAP.get(api_name, api_name)] = copy.deepcopy(constraints) + #### + + if train_target == 'response': + #### save latest response + bitod_preds[dial_id]["turns"][str(turn_id)]["response"] = predictions[-1] + #### - with open('bitod_preds.json', 'w') as fout: + with open(f"{int(time.time())}_bitod_preds.json", 'w') as fout: ujson.dump(bitod_preds, fout, indent=2, ensure_ascii=False) if original_order is not None: From 328a06aca7cc15b39372e9a334e1cedf5f230f4d Mon Sep 17 00:00:00 2001 From: mehrad Date: Tue, 24 Aug 2021 17:42:44 -0700 Subject: [PATCH 013/102] setup.py: link bitod private repo --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index 83a595c6..f469c370 100644 --- a/setup.py +++ b/setup.py @@ -72,6 +72,6 @@ 'dill~=0.3', 'xgboost~=1.3', # BiToD - 'BiToD==0.0.1', + 'BiToD @ git+ssh://git@github.com/Mehrad0711/BiToD@main', ], ) From f3c8b61f3c11ea70b7ea2ffe095bda18b9fd411b Mon Sep 17 00:00:00 2001 From: mehrad Date: Wed, 25 Aug 2021 16:14:50 -0700 Subject: [PATCH 014/102] validate: diversify valid examples for printing otherwise we only see the longest examples. this is esp. helpful in bitod since there are long texts for directions. --- genienlp/train.py | 13 ++++++------- genienlp/validate.py | 32 ++++++++++++++++++++------------ 2 files changed, 26 insertions(+), 19 deletions(-) diff --git a/genienlp/train.py b/genienlp/train.py index fce875ec..f6774805 100644 --- a/genienlp/train.py +++ b/genienlp/train.py @@ -558,13 +558,12 @@ def train( # validate if should_validate(iteration, val_every, resume=args.resume, start_iteration=start_iteration): if args.print_train_examples_too: - names = ['answer', 'context'] - values = [ - numericalizer.reverse(batch.answer.value.data, 'answer'), - numericalizer.reverse(batch.context.value.data, 'context'), - ] - num_print = min(len(values[0]), args.num_print) - print_results(names, values, num_print=num_print) + results = { + 'answer': numericalizer.reverse(batch.answer.value.data, 'answer'), + 'context': numericalizer.reverse(batch.context.value.data, 'context'), + } + num_print = min(len(results['answer']), args.num_print) + print_results(results, num_print) deca_score = do_validate( iteration, diff --git a/genienlp/validate.py b/genienlp/validate.py index f653f39b..e6ca5c7d 100644 --- a/genienlp/validate.py +++ b/genienlp/validate.py @@ -601,14 +601,24 @@ def generate_with_classification_model( return output -def print_results(keys, values, num_print=1): +def print_results(results, num_print): print() - start = 0 - end = start + num_print - values = [val[start:end] for val in values] - for ex_idx in range(len(values[0])): - for key_idx, key in enumerate(keys): - value = values[key_idx][ex_idx] + + values = list(results.values()) + num_examples = len(values[0]) + process_values = [] + + start = int(num_examples / 4) + end = start + int(num_print / 2) + process_values += [val[start:end] for val in values] + + start = int(3 * num_examples / 4) + end = start + num_print - int(num_print / 2) + process_values += [val[start:end] for val in values] + + for ex_idx in 
range(len(process_values[0])): + for key_idx, key in enumerate(results.keys()): + value = process_values[key_idx][ex_idx] v = value[0] if isinstance(value, list) else value print(f'{key:>11}: {repr(v)}') print() @@ -622,18 +632,16 @@ def validate(task, val_iter, model, numericalizer, args, num_print=10): # get rid of the DataParallel wrapper model = model.module - names = ['beam search', 'answer', 'context'] - output = generate_with_model(model, val_iter, numericalizer, task, args) # loss is already calculated metrics_to_compute = [metric for metric in task.metrics if metric not in ['loss']] - metrics = calculate_and_reduce_metrics( output.predictions, output.answers, metrics_to_compute, args.reduce_metrics, model.tgt_lang ) - results = [output.predictions, output.answers, output.contexts] - print_results(names, results, num_print=num_print) + results = {'beam search': output.predictions, 'answer': output.answers, 'context': output.contexts} + + print_results(results, num_print) return output, metrics From c01b452006fb221d5d5967b6c2e2dd5ba3937161 Mon Sep 17 00:00:00 2001 From: mehrad Date: Wed, 25 Aug 2021 17:02:19 -0700 Subject: [PATCH 015/102] demo for bitod --- genienlp/models/base.py | 6 + genienlp/models/transformer_lstm.py | 8 +- genienlp/models/transformer_seq2seq.py | 8 +- genienlp/server.py | 47 +++- genienlp/validate.py | 370 ++++++++++++++++++++----- 5 files changed, 358 insertions(+), 81 deletions(-) diff --git a/genienlp/models/base.py b/genienlp/models/base.py index 2b34910c..c38cd89d 100644 --- a/genienlp/models/base.py +++ b/genienlp/models/base.py @@ -34,6 +34,7 @@ from transformers import PreTrainedModel from ..data_utils.numericalizer import TransformerNumericalizer +from ..tasks.almond_task import Translate logger = logging.getLogger(__name__) @@ -77,3 +78,8 @@ def add_new_vocab_from_data(self, tasks, resize_decoder=False): self.numericalizer.grow_vocab(tasks) if self.numericalizer.num_tokens > old_num_tokens: logger.info(f'Vocabulary has expanded to {self.numericalizer.num_tokens} tokens') + + def set_task_dependent_generation_kwargs(self, tasks): + self._output_scores = any('loss' in task.metrics for task in tasks) + self._output_attentions = any(isinstance(task, Translate) for task in tasks) + self._output_hidden_states = False diff --git a/genienlp/models/transformer_lstm.py b/genienlp/models/transformer_lstm.py index 13ea18b0..13495c11 100644 --- a/genienlp/models/transformer_lstm.py +++ b/genienlp/models/transformer_lstm.py @@ -35,7 +35,6 @@ from ..data_utils.numericalizer import TransformerNumericalizer from ..model_utils.transformers_utils import BertModelForNER, XLMRobertaModelForNER -from ..tasks.almond_task import Translate from ..util import adjust_language_code from .base import GenieModel from .identity_encoder import IdentityEncoder @@ -61,9 +60,10 @@ def __init__(self, config=None, *inputs, args, vocab_sets, tasks, save_directory config = AutoConfig.from_pretrained(encoder_embeddings, cache_dir=args.embeddings) args.dimension = config.hidden_size - self._output_scores = any('loss' in task.metrics for task in tasks) - self._output_attentions = any(isinstance(task, Translate) for task in tasks) - self._output_hidden_states = False + # tasks is not passed during initialization only in server mode + # call this function after task is recognized + if tasks: + self.set_task_dependent_generation_kwargs(tasks) self.src_lang, self.tgt_lang = adjust_language_code( config, args.pretrained_model, kwargs.get('src_lang', 'en'), kwargs.get('tgt_lang', 'en') diff --git 
a/genienlp/models/transformer_seq2seq.py b/genienlp/models/transformer_seq2seq.py index c619c2c0..7af8b5b3 100644 --- a/genienlp/models/transformer_seq2seq.py +++ b/genienlp/models/transformer_seq2seq.py @@ -35,7 +35,6 @@ from ..data_utils.numericalizer import TransformerNumericalizer from ..model_utils.transformers_utils import MULTILINGUAL_TOKENIZERS -from ..tasks.almond_task import Translate from ..util import ConfidenceFeatures, adjust_language_code from .base import GenieModel from .common import LabelSmoothingCrossEntropy @@ -54,9 +53,10 @@ def __init__(self, config=None, *inputs, args, tasks, vocab_sets, save_directory args.dimension = config.d_model self._is_bart_large = self.args.pretrained_model == 'facebook/bart-large' - self._output_scores = any('loss' in task.metrics for task in tasks) - self._output_attentions = any(isinstance(task, Translate) for task in tasks) - self._output_hidden_states = False + # tasks is not passed during initialization only in server mode + # call this function after task is recognized + if tasks: + self.set_task_dependent_generation_kwargs(tasks) self.src_lang, self.tgt_lang = adjust_language_code( config, args.pretrained_model, kwargs.get('src_lang', 'en'), kwargs.get('tgt_lang', 'en') diff --git a/genienlp/server.py b/genienlp/server.py index 9d00a8df..02764836 100644 --- a/genienlp/server.py +++ b/genienlp/server.py @@ -61,7 +61,12 @@ def parse_argv(parser): '--checkpoint_name', default='best.pth', help='Checkpoint file to use (relative to --path, defaults to best.pth)' ) parser.add_argument('--port', default=8401, type=int, help='TCP port to listen on') - parser.add_argument('--stdin', action='store_true', help='Interact on stdin/stdout instead of TCP') + parser.add_argument( + '--interact_mode', + default='tcp', + choices=['stdin', 'tcp', 'cmdline'], + help='Interact with stdin/stdout, user cmdline, or TCP', + ) parser.add_argument('--database_dir', type=str, help='Database folder containing all relevant files') parser.add_argument('--src_locale', default='en', help='locale tag of the input language to parse') parser.add_argument('--tgt_locale', default='en', help='locale tag of the target language to generate') @@ -119,7 +124,7 @@ def numericalize_examples(self, ex): # make a single batch with all examples return NumericalizedExamples.collate_batches(all_features, self.numericalizer, device=self.device) - def handle_request(self, request): + def handle_request(self, request, interactive=False): task_name = request['task'] if 'task' in request else 'generic' task = list(get_tasks([task_name], self.args, self._cached_task_names).values())[0] if task_name not in self._cached_task_names: @@ -160,6 +165,7 @@ def handle_request(self, request): self.ned_model.process_examples(examples, None, task.utterance_field) self.model.add_new_vocab_from_data([task]) + self.model.set_task_dependent_generation_kwargs([task]) batch = self.numericalize_examples(examples) try: @@ -190,6 +196,18 @@ def handle_request(self, request): for e_idx, estimator_scores in enumerate(output.confidence_scores): instance['score'][self.estimator_filenames[e_idx]] = float(estimator_scores[idx]) response.append(instance) + + elif interactive: + self.args.bitod_e2e_evaluation = True + output = generate_with_model( + self.model, + [batch], + self.numericalizer, + task, + self.args, + output_predictions_only=True, + interactive=True, + ) else: output = generate_with_model( self.model, [batch], self.numericalizer, task, self.args, output_predictions_only=True @@ -212,12 +230,12 @@ def 
handle_request(self, request): return response - def handle_json_request(self, line: str) -> str: + def handle_json_request(self, line: str, interactive=False) -> str: request = json.loads(line) if 'instances' in request: - return json.dumps({'id': request['id'], 'instances': self.handle_request(request)}) + '\n' + return json.dumps({'id': request['id'], 'instances': self.handle_request(request, interactive)}) + '\n' else: - response = self.handle_request(request) + response = self.handle_request(request, interactive) assert len(response) == 1 response = response[0] response['id'] = request['id'] @@ -259,14 +277,29 @@ def _run_stdin(self): except KeyboardInterrupt: pass + def _run_interactive(self): + try: + while True: + # raw_text = input('USER: ') + raw_text = '' + line = json.dumps( + {"task": "bitod", "id": "0", "context": "USER: " + raw_text.strip(), "question": "", "answer": ""} + ) + sys.stdout.write(self.handle_json_request(line, interactive=True)) + sys.stdout.flush() + except KeyboardInterrupt: + pass + def run(self): log_model_size(logger, self.model, self.args.model) self.model.to(self.device) self.model.eval() - if self.args.stdin: + if self.args.interact_mode == 'stdin': self._run_stdin() - else: + elif self.args.interact_mode == 'cmdline': + self._run_interactive() + elif self.args.interact_mode == 'tcp': self._run_tcp() diff --git a/genienlp/validate.py b/genienlp/validate.py index e6ca5c7d..58bfe398 100644 --- a/genienlp/validate.py +++ b/genienlp/validate.py @@ -60,18 +60,30 @@ def generate_with_model( original_order=None, confidence_estimators=None, disable_progbar=True, + interactive=False, ): if args.bitod_e2e_evaluation: - return generate_with_seq2seq_model_for_dialogue( - model, - data_iterator, - numericalizer, - task, - args, - output_predictions_only=output_predictions_only, - original_order=original_order, - disable_progbar=disable_progbar, - ) + if interactive: + return generate_with_seq2seq_model_for_dialogue_interactive( + model, + numericalizer, + task, + args, + output_predictions_only=output_predictions_only, + original_order=original_order, + disable_progbar=disable_progbar, + ) + else: + return generate_with_seq2seq_model_for_dialogue( + model, + data_iterator, + numericalizer, + task, + args, + output_predictions_only=output_predictions_only, + original_order=original_order, + disable_progbar=disable_progbar, + ) elif isinstance(model, (TransformerForTokenClassification, TransformerForSequenceClassification)): return generate_with_classification_model( @@ -97,6 +109,235 @@ def replace_match(input, re_pattern, replacement): return input.replace(match, replacement) +def generate_with_seq2seq_model_for_dialogue_interactive( + model, + numericalizer, + task, + args, + output_predictions_only=False, + original_order=None, + disable_progbar=True, +) -> GenerationOutput: + + bitod_preds = dict() + + predictions = [] + example_ids = [] + answers = [] + contexts = [] + + device = model.device + + required_slots = read_require_slots() + required_slots = {API_MAP[k]: v for k, v in required_slots.items()} + api_names = list(required_slots.keys()) + + dial_id = 'none' + turn_id = 1 + dialogue_state = {} + new_state_text = 'null' + new_knowledge_text = 'null' + active_api = None + bitod_preds[dial_id] = {"turns": defaultdict(dict), "API": defaultdict(dict)} + + convo_history = [] + convo_window = 3 + + train_target = 'response' + + next_target = {'dst': 'api', 'api': 'response', 'response': 'dst'} + + while True: + + try: + + batch_prediction = [] + + # becomes dst 
for first turn + train_target = next_target[train_target] + + hyperparameter_idx = 0 + + if train_target == 'dst': + + print('convo_history: ', convo_history) + # Hello, I am looking for a restaurant with Vegan Options. + # construct new input + raw_user_input = input('USER: ') + if raw_user_input == 'RESET': + generate_with_seq2seq_model_for_dialogue_interactive( + model, + numericalizer, + task, + args, + output_predictions_only=False, + original_order=None, + disable_progbar=True, + ) + break + elif raw_user_input == 'END': + break + + raw_user_input = 'USER: ' + raw_user_input.strip() + + convo_history.append(raw_user_input) + + input_text = f'DST: {new_state_text} {" ".join(convo_history[-convo_window:])}' + + elif train_target == 'api': + new_state_text = state2span(dialogue_state, required_slots) + + # replace state + input_text = f'API: {new_state_text} {" ".join(convo_history[-convo_window:])}' + + elif train_target == 'response': + + input_text = f'Response: {new_knowledge_text} {new_state_text} {" ".join(convo_history[-convo_window:])}' + + else: + raise ValueError(f'Invalid train_target: {train_target}') + + tokenized_contexts = numericalizer.encode_batch([input_text], field_name='context', features=None)[0] + + numericalized_turn = NumericalizedExamples( + example_id=[str(turn_id)], + context=SequentialField( + value=torch.tensor([tokenized_contexts.value], device=device), + length=torch.tensor([tokenized_contexts.length], device=device), + limited=torch.tensor([tokenized_contexts.limited], device=device), + feature=None, + ), + answer=SequentialField( + value=None, + length=None, + limited=None, + feature=None, + ), + ) + + generated = model.generate( + numericalized_turn, + max_output_length=args.max_output_length, + num_outputs=args.num_outputs[hyperparameter_idx], + temperature=args.temperature[hyperparameter_idx] if args.temperature[hyperparameter_idx] > 0 else 1.0, + repetition_penalty=args.repetition_penalty[hyperparameter_idx], + top_k=args.top_k[hyperparameter_idx], + top_p=args.top_p[hyperparameter_idx], + num_beams=args.num_beams[hyperparameter_idx], + num_beam_groups=args.num_beam_groups[hyperparameter_idx], + diversity_penalty=args.diversity_penalty[hyperparameter_idx], + no_repeat_ngram_size=args.no_repeat_ngram_size[hyperparameter_idx], + do_sample=args.temperature[hyperparameter_idx] != 0, + ) + + partial_batch_prediction_ids = generated.sequences + + partial_batch_prediction = numericalizer.reverse(partial_batch_prediction_ids, 'answer')[0] + + # post-process predictions + partial_batch_prediction = task.postprocess_prediction(turn_id, partial_batch_prediction) + + # put them into the right array + batch_prediction.append([partial_batch_prediction]) + + predictions += batch_prediction + + if train_target == 'dst': + # update dialogue_state + lev = predictions[-1][0].strip() + state_update = span2dict(lev, api_names) + for api_name in state_update: + active_api = api_name + if api_name not in dialogue_state: + dialogue_state[api_name] = state_update[api_name] + else: + dialogue_state[api_name].update(state_update[api_name]) + + #### save latest state + state_to_record = copy.deepcopy(dialogue_state) + state_to_record = {r_en_API_MAP.get(k, k): v for k, v in state_to_record.items()} + bitod_preds[dial_id]["turns"][str(turn_id)]["state"] = state_to_record + #### + + elif train_target == 'api': + new_knowledge_text = 'null' + constraints = {} + + api_name = active_api if active_api else 'None' + + do_api_call = predictions[-1][0].strip() + + if do_api_call == 
'yes': + # make api call if required + api_name = active_api + # do api call + if api_name in dialogue_state: + constraints = state2api(dialogue_state[api_name]) + + try: + msg = api.call_api( + r_en_API_MAP.get(api_name, api_name), + constraints=[constraints], + ) + except Exception as e: + print(f'Error: {e}') + print(f'Failed API call with api_name: {api_name} and constraints: {constraints}') + msg = [0, 0] + + domain = api_name.split(" ")[0] + + knowledge = defaultdict(dict) + if int(msg[1]) <= 0: + new_knowledge_text = f'( {domain} ) Message = No item available.' + else: + # why does it only choose the first; does the same happen for training data? + knowledge[domain].update(msg[0]) + new_knowledge_text = knowledge2span(knowledge) + + elif do_api_call == 'no': + # do nothing + pass + else: + logger.error( + f'API call should be either yes or no but got {do_api_call}; seems model is still training so we assume it\'s a no' + ) + + #### save latest api results and constraints + bitod_preds[dial_id]["turns"][str(turn_id)]["api"] = new_knowledge_text + bitod_preds[dial_id]["API"][r_en_API_MAP.get(api_name, api_name)] = copy.deepcopy(constraints) + #### + + if train_target == 'response': + #### save latest response + bitod_preds[dial_id]["turns"][str(turn_id)]["response"] = predictions[-1] + #### + + convo_history.append(predictions[-1][0]) + + except KeyboardInterrupt: + break + + with open(f"{int(time.time())}_bitod_preds.json", 'w') as fout: + ujson.dump(bitod_preds, fout, indent=2, ensure_ascii=False) + + if original_order is not None: + # sort back to the original order + original_order, example_ids, predictions, answers, contexts = [ + list(a) for a in tuple(zip(*sorted(list(zip(original_order, example_ids, predictions, answers, contexts))))) + ] + + # TODO calculate and return loss + loss = None + output = GenerationOutput(loss=loss) + + if output_predictions_only: + output.predictions = predictions + else: + output.example_ids, output.predictions, output.answers, output.contexts = example_ids, predictions, answers, contexts + + return output + + def generate_with_seq2seq_model_for_dialogue( model, data_iterator, @@ -153,14 +394,11 @@ def generate_with_seq2seq_model_for_dialogue( if cur_dial_id != dial_id: # new dialogue cur_dial_id = dial_id - first_turn = True dialogue_state = {} new_state_text = 'null' new_knowledge_text = 'null' active_api = None bitod_preds[dial_id] = {"turns": defaultdict(dict), "API": defaultdict(dict)} - else: - first_turn = False special_tokens = numericalizer._tokenizer.all_special_tokens batch_tokens = numericalizer.convert_ids_to_tokens(turn.context.value.data, skip_special_tokens=False) @@ -192,65 +430,65 @@ def generate_with_seq2seq_model_for_dialogue( # we always use gold history following common practice - if first_turn: - # first turn is always dst - assert train_target == 'dst' - numericalized_turn = NumericalizedExamples( - example_id=[turn.example_id[0]], - context=SequentialField( - value=turn.context.value[[0]], - length=turn.context.length[[0]], - limited=turn.context.limited[[0]], - feature=None, - ), - answer=SequentialField( - value=turn.answer.value[[0]], - length=turn.answer.value[[0]], - limited=turn.answer.value[[0]], - feature=None, - ), - ) - else: - if train_target == 'dst': - input_text = replace_match(contexts[-1], state_re, new_state_text) - - ## if you want to use predicted response instead of gold uncomment the following - # last_sys_pred = predictions[-1][0].strip() - # input_text = replace_match(input_text, last_system_re, 
last_sys_pred) - - elif train_target == 'api': - new_state_text = state2span(dialogue_state, required_slots) + # if first_turn: + # # first turn is always dst + # assert train_target == 'dst' + # numericalized_turn = NumericalizedExamples( + # example_id=[turn.example_id[0]], + # context=SequentialField( + # value=turn.context.value[[0]], + # length=turn.context.length[[0]], + # limited=turn.context.limited[[0]], + # feature=None, + # ), + # answer=SequentialField( + # value=turn.answer.value[[0]], + # length=turn.answer.value[[0]], + # limited=turn.answer.value[[0]], + # feature=None, + # ), + # ) + # else: + if train_target == 'dst': + input_text = replace_match(contexts[-1], state_re, new_state_text) - # replace state - input_text = replace_match(contexts[-1], state_re, new_state_text) + ## if you want to use predicted response instead of gold uncomment the following + # last_sys_pred = predictions[-1][0].strip() + # input_text = replace_match(input_text, last_system_re, last_sys_pred) - elif train_target == 'response': - # replace state - input_text = replace_match(contexts[-1], state_re, new_state_text) + elif train_target == 'api': + new_state_text = state2span(dialogue_state, required_slots) - # replace knowledge - input_text = replace_match(input_text, knowledge_re, new_knowledge_text) + # replace state + input_text = replace_match(contexts[-1], state_re, new_state_text) - else: - raise ValueError(f'Invalid train_target: {train_target}') + elif train_target == 'response': + # replace state + input_text = replace_match(contexts[-1], state_re, new_state_text) - tokenized_contexts = numericalizer.encode_batch([input_text], field_name='context', features=None)[0] + # replace knowledge + input_text = replace_match(input_text, knowledge_re, new_knowledge_text) - numericalized_turn = NumericalizedExamples( - example_id=[turn.example_id[0]], - context=SequentialField( - value=torch.tensor([tokenized_contexts.value], device=device), - length=torch.tensor([tokenized_contexts.length], device=device), - limited=torch.tensor([tokenized_contexts.limited], device=device), - feature=None, - ), - answer=SequentialField( - value=turn.answer.value[[0]], - length=turn.answer.value[[0]], - limited=turn.answer.value[[0]], - feature=None, - ), - ) + else: + raise ValueError(f'Invalid train_target: {train_target}') + + tokenized_contexts = numericalizer.encode_batch([input_text], field_name='context', features=None)[0] + + numericalized_turn = NumericalizedExamples( + example_id=[turn.example_id[0]], + context=SequentialField( + value=torch.tensor([tokenized_contexts.value], device=device), + length=torch.tensor([tokenized_contexts.length], device=device), + limited=torch.tensor([tokenized_contexts.limited], device=device), + feature=None, + ), + answer=SequentialField( + value=turn.answer.value[[0]], + length=turn.answer.value[[0]], + limited=turn.answer.value[[0]], + feature=None, + ), + ) generated = model.generate( numericalized_turn, From 163224194861897c3eadfd44139fb7f16c73a5d0 Mon Sep 17 00:00:00 2001 From: mehrad Date: Wed, 25 Aug 2021 22:52:22 -0700 Subject: [PATCH 016/102] validate: fix printing --- genienlp/validate.py | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/genienlp/validate.py b/genienlp/validate.py index 58bfe398..5a90e80f 100644 --- a/genienlp/validate.py +++ b/genienlp/validate.py @@ -844,19 +844,24 @@ def print_results(results, num_print): values = list(results.values()) num_examples = len(values[0]) - process_values = [] + # examples are sorted 
by length + # to get good diversity, get half of examples from second quartile start = int(num_examples / 4) end = start + int(num_print / 2) - process_values += [val[start:end] for val in values] + first_list = [val[start:end] for val in values] + # and the other half from fourth quartile start = int(3 * num_examples / 4) end = start + num_print - int(num_print / 2) - process_values += [val[start:end] for val in values] + second_list = [val[start:end] for val in values] - for ex_idx in range(len(process_values[0])): + # join examples + processed_values = [first + second for first, second in zip(first_list, second_list)] + + for ex_idx in range(len(processed_values[0])): for key_idx, key in enumerate(results.keys()): - value = process_values[key_idx][ex_idx] + value = processed_values[key_idx][ex_idx] v = value[0] if isinstance(value, list) else value print(f'{key:>11}: {repr(v)}') print() From f1f8984edafb9fba9114caf08c8dde4b5b6c931c Mon Sep 17 00:00:00 2001 From: mehrad Date: Fri, 27 Aug 2021 13:15:45 -0700 Subject: [PATCH 017/102] Write bitod_preds to eval_dir --- genienlp/predict.py | 1 + genienlp/validate.py | 6 +++++- 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/genienlp/predict.py b/genienlp/predict.py index 5cdf3048..5fc07e67 100644 --- a/genienlp/predict.py +++ b/genienlp/predict.py @@ -447,6 +447,7 @@ def run(args, device): output_confidence_features=args.save_confidence_features, confidence_estimators=confidence_estimators, disable_progbar=False, + eval_dir=eval_dir, ) if args.save_confidence_features: diff --git a/genienlp/validate.py b/genienlp/validate.py index 5a90e80f..ef0e0885 100644 --- a/genienlp/validate.py +++ b/genienlp/validate.py @@ -29,6 +29,7 @@ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
import copy import logging +import os import re import sys import time @@ -61,6 +62,7 @@ def generate_with_model( confidence_estimators=None, disable_progbar=True, interactive=False, + eval_dir=None, ): if args.bitod_e2e_evaluation: if interactive: @@ -80,6 +82,7 @@ def generate_with_model( numericalizer, task, args, + eval_dir, output_predictions_only=output_predictions_only, original_order=original_order, disable_progbar=disable_progbar, @@ -344,6 +347,7 @@ def generate_with_seq2seq_model_for_dialogue( numericalizer, task, args, + eval_dir, output_predictions_only=False, original_order=None, disable_progbar=True, @@ -586,7 +590,7 @@ def generate_with_seq2seq_model_for_dialogue( bitod_preds[dial_id]["turns"][str(turn_id)]["response"] = predictions[-1] #### - with open(f"{int(time.time())}_bitod_preds.json", 'w') as fout: + with open(os.path.join(eval_dir, 'bitod_preds.json'), 'w') as fout: ujson.dump(bitod_preds, fout, indent=2, ensure_ascii=False) if original_order is not None: From ddf37e936d315a28e27ee820b6a571cba460eb8f Mon Sep 17 00:00:00 2001 From: mehrad Date: Fri, 27 Aug 2021 15:14:20 -0700 Subject: [PATCH 018/102] Update bitod special_tokens for v5 --- genienlp/tasks/generic_task.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/genienlp/tasks/generic_task.py b/genienlp/tasks/generic_task.py index 440509fb..eebdf7ec 100644 --- a/genienlp/tasks/generic_task.py +++ b/genienlp/tasks/generic_task.py @@ -422,7 +422,8 @@ def __init__(self, name, args): 'API:', 'Response:', } - self.special_tokens = special_tokens_v1 | special_tokens_v2 + special_tokens_v5 = {'AGENT_ACTS:'} + self.special_tokens = special_tokens_v1 | special_tokens_v2 | special_tokens_v5 self._metrics = 'casedbleu' def utterance_field(self): From bdcb7aa73ec40fea7f0b520b57a11e7300f4723e Mon Sep 17 00:00:00 2001 From: mehrad Date: Sat, 28 Aug 2021 17:30:42 -0700 Subject: [PATCH 019/102] validate: warn if no messages are retrieved --- genienlp/validate.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/genienlp/validate.py b/genienlp/validate.py index ef0e0885..93380b37 100644 --- a/genienlp/validate.py +++ b/genienlp/validate.py @@ -283,8 +283,8 @@ def generate_with_seq2seq_model_for_dialogue_interactive( constraints=[constraints], ) except Exception as e: - print(f'Error: {e}') - print(f'Failed API call with api_name: {api_name} and constraints: {constraints}') + logger.error(f'Error: {e}') + logger.error(f'Failed API call with api_name: {api_name} and constraints: {constraints}') msg = [0, 0] domain = api_name.split(" ")[0] @@ -558,14 +558,19 @@ def generate_with_seq2seq_model_for_dialogue( constraints=[constraints], ) except Exception as e: - print(f'Error: {e}') - print(f'Failed API call with api_name: {api_name} and constraints: {constraints}') + logger.error(f'Error: {e}') + logger.error( + f'Failed API call with api_name: {api_name} and constraints: {constraints} for turn: {dial_id}/{turn_id}' + ) msg = [0, 0] domain = api_name.split(" ")[0] knowledge = defaultdict(dict) if int(msg[1]) <= 0: + logger.warning( + f'Message = No item available for api_name: {api_name} and constraints: {constraints} for turn: {dial_id}/{turn_id}' + ) new_knowledge_text = f'( {domain} ) Message = No item available.' else: # why does it only choose the first; does the same happen for training data? 
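A quick, self-contained illustration of the rounding that PATCH 011 above introduces in calculate_and_reduce_metrics: two decimal places are what get reported, and the third exists only so that near-identical scores can still be ordered. The two scores below are invented for illustration; only the float('{:.3f}'.format(v)) expression comes from the patch.

    # Toy demonstration of the 3-decimal rounding from PATCH 011 (scores are invented).
    a, b = 74.2312, 74.2348

    def keep3(v):
        # the same expression the patch applies before taking the max
        return float('{:.3f}'.format(v))

    print('{:.2f} vs {:.2f}'.format(a, b))  # 74.23 vs 74.23 -> tied at the 2 reported decimals
    print(keep3(a), keep3(b))               # 74.231 74.235  -> the third decimal breaks the tie
    print(max(keep3(a), keep3(b)))          # 74.235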
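PATCHES 014 and 016 above rework print_results so that, because validation examples arrive sorted by length, half of the printed examples are drawn from the second quartile and half from the fourth, instead of always showing the head of the list. A runnable sketch of just that selection step, using invented stand-in strings rather than real model outputs:

    # Quartile-based example selection from print_results (PATCHES 014/016), on toy data.
    num_print = 4
    examples = [f'example-{i:02d}' for i in range(20)]  # sorted by length in a real run

    num_examples = len(examples)

    # half of the printed examples come from the second quartile ...
    start = num_examples // 4
    first = examples[start : start + num_print // 2]

    # ... and the other half from the fourth quartile
    start = 3 * num_examples // 4
    second = examples[start : start + num_print - num_print // 2]

    print(first + second)  # ['example-05', 'example-06', 'example-15', 'example-16']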
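PATCH 020 below splits the end-to-end dialogue evaluation into an explicit dst -> api -> response cycle, where each step rebuilds the model input from the running state and knowledge strings before generating. A minimal sketch of that control flow with the model and the BiToD knowledge base replaced by stubs; apart from the DST:/API:/Response: input layout and the 3-turn context window, every name here is invented for illustration.

    # Stubbed-out sketch of the dst -> api -> response cycle from PATCH 020 (and the
    # interactive demo in PATCH 015). Only the input layout mirrors the real code.
    NEXT_TARGET = {'dst': 'api', 'api': 'response', 'response': 'dst'}
    CONVO_WINDOW = 3

    def stub_model(input_text):
        # stand-in for model.generate followed by numericalizer.reverse
        return f'<prediction for: {input_text}>'

    def build_input(target, state_text, knowledge_text, history):
        recent = ' '.join(history[-CONVO_WINDOW:])
        if target == 'dst':
            return f'DST: {state_text} {recent}'
        elif target == 'api':
            return f'API: {state_text} {recent}'
        elif target == 'response':
            return f'Response: {knowledge_text} {state_text} {recent}'
        raise ValueError(f'Invalid train_target: {target}')

    history = ['USER: I am looking for a restaurant with vegan options.']
    state_text, knowledge_text = 'null', 'null'
    target = 'dst'
    for _ in range(3):  # one full cycle over a single user turn
        prediction = stub_model(build_input(target, state_text, knowledge_text, history))
        print(prediction)
        # the real loop now updates the dialogue state (dst), decides whether to
        # call the knowledge base (api), or emits the system reply (response)
        target = NEXT_TARGET[target]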
From b8f476975e491ed8a4e3f7de863db12fa32d4207 Mon Sep 17 00:00:00 2001 From: mehrad Date: Mon, 30 Aug 2021 14:05:34 -0700 Subject: [PATCH 020/102] Refactoring --- genienlp/dial_validate.py | 247 +++++++++++++++++++++++++++++++++ genienlp/validate.py | 285 +++----------------------------------- setup.py | 1 + 3 files changed, 267 insertions(+), 266 deletions(-) create mode 100644 genienlp/dial_validate.py diff --git a/genienlp/dial_validate.py b/genienlp/dial_validate.py new file mode 100644 index 00000000..17f1b64f --- /dev/null +++ b/genienlp/dial_validate.py @@ -0,0 +1,247 @@ +import copy +import logging +import sys +import time +from collections import defaultdict + +import torch +import ujson +from BiToD.evaluate import r_en_API_MAP +from BiToD.knowledgebase import api +from BiToD.knowledgebase.en_zh_mappings import API_MAP +from BiToD.utils import knowledge2span, read_require_slots, span2state, state2constraints, state2span + +from genienlp.data_utils.example import NumericalizedExamples, SequentialField +from genienlp.util import GenerationOutput + +logger = logging.getLogger(__name__) + + +def generate_with_seq2seq_model_for_dialogue_interactive( + model, + numericalizer, + task, + args, + output_predictions_only=False, + original_order=None, + disable_progbar=True, +) -> GenerationOutput: + + bitod_preds = dict() + + predictions = [] + example_ids = [] + answers = [] + contexts = [] + + device = model.device + + required_slots = read_require_slots() + required_slots = {API_MAP[k]: v for k, v in required_slots.items()} + api_names = list(required_slots.keys()) + + dial_id = 'none' + turn_id = 1 + dialogue_state = {} + new_state_text = 'null' + new_knowledge_text = 'null' + active_api = None + bitod_preds[dial_id] = {"turns": defaultdict(dict), "API": defaultdict(dict)} + + convo_history = [] + convo_window = 3 + + train_target = 'response' + + next_target = {'dst': 'api', 'api': 'response', 'response': 'dst'} + + while True: + try: + + batch_prediction = [] + + # becomes dst for first turn + train_target = next_target[train_target] + + hyperparameter_idx = 0 + + if train_target == 'dst': + if convo_history: + print(convo_history[-1]) + else: + print('Hello! What are you looking for today?') + # Hello, I am looking for a restaurant with Vegan Options. 
+ # construct new input + raw_user_input = input('USER: ') + if raw_user_input == 'RESET': + generate_with_seq2seq_model_for_dialogue_interactive( + model, + numericalizer, + task, + args, + output_predictions_only=False, + original_order=None, + disable_progbar=True, + ) + break + elif raw_user_input == 'END': + sys.exit(0) + + raw_user_input = 'USER: ' + raw_user_input.strip() + + convo_history.append(raw_user_input) + + input_text = f'DST: {new_state_text} {" ".join(convo_history[-convo_window:])}' + + elif train_target == 'api': + new_state_text = state2span(dialogue_state, required_slots) + + # replace state + input_text = f'API: {new_state_text} {" ".join(convo_history[-convo_window:])}' + + elif train_target == 'response': + + input_text = f'Response: {new_knowledge_text} {new_state_text} {" ".join(convo_history[-convo_window:])}' + + else: + raise ValueError(f'Invalid train_target: {train_target}') + + tokenized_contexts = numericalizer.encode_batch([input_text], field_name='context', features=None)[0] + + numericalized_turn = NumericalizedExamples( + example_id=[str(turn_id)], + context=SequentialField( + value=torch.tensor([tokenized_contexts.value], device=device), + length=torch.tensor([tokenized_contexts.length], device=device), + limited=torch.tensor([tokenized_contexts.limited], device=device), + feature=None, + ), + answer=SequentialField( + value=None, + length=None, + limited=None, + feature=None, + ), + ) + + generated = model.generate( + numericalized_turn, + max_output_length=args.max_output_length, + num_outputs=args.num_outputs[hyperparameter_idx], + temperature=args.temperature[hyperparameter_idx] if args.temperature[hyperparameter_idx] > 0 else 1.0, + repetition_penalty=args.repetition_penalty[hyperparameter_idx], + top_k=args.top_k[hyperparameter_idx], + top_p=args.top_p[hyperparameter_idx], + num_beams=args.num_beams[hyperparameter_idx], + num_beam_groups=args.num_beam_groups[hyperparameter_idx], + diversity_penalty=args.diversity_penalty[hyperparameter_idx], + no_repeat_ngram_size=args.no_repeat_ngram_size[hyperparameter_idx], + do_sample=args.temperature[hyperparameter_idx] != 0, + ) + + partial_batch_prediction_ids = generated.sequences + + partial_batch_prediction = numericalizer.reverse(partial_batch_prediction_ids, 'answer')[0] + + # post-process predictions + partial_batch_prediction = task.postprocess_prediction(turn_id, partial_batch_prediction) + + # put them into the right array + batch_prediction.append([partial_batch_prediction]) + + predictions += batch_prediction + + if train_target == 'dst': + # update dialogue_state + lev = predictions[-1][0].strip() + state_update = span2state(lev, api_names) + for api_name in state_update: + active_api = api_name + if api_name not in dialogue_state: + dialogue_state[api_name] = state_update[api_name] + else: + dialogue_state[api_name].update(state_update[api_name]) + + #### save latest state + state_to_record = copy.deepcopy(dialogue_state) + state_to_record = {r_en_API_MAP.get(k, k): v for k, v in state_to_record.items()} + bitod_preds[dial_id]["turns"][str(turn_id)]["state"] = state_to_record + #### + + elif train_target == 'api': + new_knowledge_text = 'null' + constraints = {} + + api_name = active_api if active_api else 'None' + + do_api_call = predictions[-1][0].strip() + + if do_api_call == 'yes': + # make api call if required + api_name = active_api + # do api call + if api_name in dialogue_state: + constraints = state2constraints(dialogue_state[api_name]) + + try: + msg = api.call_api( + 
r_en_API_MAP.get(api_name, api_name), + constraints=[constraints], + ) + except Exception as e: + logger.error(f'Error: {e}') + logger.error(f'Failed API call with api_name: {api_name} and constraints: {constraints}') + msg = [0, 0] + + domain = api_name.split(" ")[0] + + knowledge = defaultdict(dict) + if int(msg[1]) <= 0: + new_knowledge_text = f'( {domain} ) Message = No item available.' + else: + # why does it only choose the first; does the same happen for training data? + knowledge[domain].update(msg[0]) + new_knowledge_text = knowledge2span(knowledge) + + elif do_api_call == 'no': + # do nothing + pass + else: + logger.error( + f'API call should be either yes or no but got {do_api_call}; seems model is still training so we assume it\'s a no' + ) + + #### save latest api results and constraints + bitod_preds[dial_id]["turns"][str(turn_id)]["api"] = new_knowledge_text + bitod_preds[dial_id]["API"][r_en_API_MAP.get(api_name, api_name)] = copy.deepcopy(constraints) + #### + + if train_target == 'response': + #### save latest response + bitod_preds[dial_id]["turns"][str(turn_id)]["response"] = predictions[-1] + #### + + convo_history.append('SYSTEM: ' + predictions[-1][0]) + + except KeyboardInterrupt: + break + + with open(f"{int(time.time())}_bitod_preds.json", 'w') as fout: + ujson.dump(bitod_preds, fout, indent=2, ensure_ascii=False) + + if original_order is not None: + # sort back to the original order + original_order, example_ids, predictions, answers, contexts = [ + list(a) for a in tuple(zip(*sorted(list(zip(original_order, example_ids, predictions, answers, contexts))))) + ] + + # TODO calculate and return loss + loss = None + output = GenerationOutput(loss=loss) + + if output_predictions_only: + output.predictions = predictions + else: + output.example_ids, output.predictions, output.answers, output.contexts = example_ids, predictions, answers, contexts + + return output diff --git a/genienlp/validate.py b/genienlp/validate.py index 93380b37..902f4a01 100644 --- a/genienlp/validate.py +++ b/genienlp/validate.py @@ -32,17 +32,18 @@ import os import re import sys -import time from collections import defaultdict +import dictdiffer import torch import ujson -from BiToD.evaluate import r_en_API_MAP, span2dict, state2api +from BiToD.evaluate import r_en_API_MAP, span2state, state2constraints from BiToD.knowledgebase import api from BiToD.preprocess import API_MAP, knowledge2span, read_require_slots, state2span from .data_utils.example import NumericalizedExamples, SequentialField from .data_utils.progbar import progress_bar +from .dial_validate import generate_with_seq2seq_model_for_dialogue_interactive from .metrics import calculate_and_reduce_metrics from .models import TransformerForSequenceClassification, TransformerForTokenClassification from .util import GenerationOutput, merge_translated_sentences @@ -112,235 +113,6 @@ def replace_match(input, re_pattern, replacement): return input.replace(match, replacement) -def generate_with_seq2seq_model_for_dialogue_interactive( - model, - numericalizer, - task, - args, - output_predictions_only=False, - original_order=None, - disable_progbar=True, -) -> GenerationOutput: - - bitod_preds = dict() - - predictions = [] - example_ids = [] - answers = [] - contexts = [] - - device = model.device - - required_slots = read_require_slots() - required_slots = {API_MAP[k]: v for k, v in required_slots.items()} - api_names = list(required_slots.keys()) - - dial_id = 'none' - turn_id = 1 - dialogue_state = {} - new_state_text = 'null' - 
new_knowledge_text = 'null' - active_api = None - bitod_preds[dial_id] = {"turns": defaultdict(dict), "API": defaultdict(dict)} - - convo_history = [] - convo_window = 3 - - train_target = 'response' - - next_target = {'dst': 'api', 'api': 'response', 'response': 'dst'} - - while True: - - try: - - batch_prediction = [] - - # becomes dst for first turn - train_target = next_target[train_target] - - hyperparameter_idx = 0 - - if train_target == 'dst': - - print('convo_history: ', convo_history) - # Hello, I am looking for a restaurant with Vegan Options. - # construct new input - raw_user_input = input('USER: ') - if raw_user_input == 'RESET': - generate_with_seq2seq_model_for_dialogue_interactive( - model, - numericalizer, - task, - args, - output_predictions_only=False, - original_order=None, - disable_progbar=True, - ) - break - elif raw_user_input == 'END': - break - - raw_user_input = 'USER: ' + raw_user_input.strip() - - convo_history.append(raw_user_input) - - input_text = f'DST: {new_state_text} {" ".join(convo_history[-convo_window:])}' - - elif train_target == 'api': - new_state_text = state2span(dialogue_state, required_slots) - - # replace state - input_text = f'API: {new_state_text} {" ".join(convo_history[-convo_window:])}' - - elif train_target == 'response': - - input_text = f'Response: {new_knowledge_text} {new_state_text} {" ".join(convo_history[-convo_window:])}' - - else: - raise ValueError(f'Invalid train_target: {train_target}') - - tokenized_contexts = numericalizer.encode_batch([input_text], field_name='context', features=None)[0] - - numericalized_turn = NumericalizedExamples( - example_id=[str(turn_id)], - context=SequentialField( - value=torch.tensor([tokenized_contexts.value], device=device), - length=torch.tensor([tokenized_contexts.length], device=device), - limited=torch.tensor([tokenized_contexts.limited], device=device), - feature=None, - ), - answer=SequentialField( - value=None, - length=None, - limited=None, - feature=None, - ), - ) - - generated = model.generate( - numericalized_turn, - max_output_length=args.max_output_length, - num_outputs=args.num_outputs[hyperparameter_idx], - temperature=args.temperature[hyperparameter_idx] if args.temperature[hyperparameter_idx] > 0 else 1.0, - repetition_penalty=args.repetition_penalty[hyperparameter_idx], - top_k=args.top_k[hyperparameter_idx], - top_p=args.top_p[hyperparameter_idx], - num_beams=args.num_beams[hyperparameter_idx], - num_beam_groups=args.num_beam_groups[hyperparameter_idx], - diversity_penalty=args.diversity_penalty[hyperparameter_idx], - no_repeat_ngram_size=args.no_repeat_ngram_size[hyperparameter_idx], - do_sample=args.temperature[hyperparameter_idx] != 0, - ) - - partial_batch_prediction_ids = generated.sequences - - partial_batch_prediction = numericalizer.reverse(partial_batch_prediction_ids, 'answer')[0] - - # post-process predictions - partial_batch_prediction = task.postprocess_prediction(turn_id, partial_batch_prediction) - - # put them into the right array - batch_prediction.append([partial_batch_prediction]) - - predictions += batch_prediction - - if train_target == 'dst': - # update dialogue_state - lev = predictions[-1][0].strip() - state_update = span2dict(lev, api_names) - for api_name in state_update: - active_api = api_name - if api_name not in dialogue_state: - dialogue_state[api_name] = state_update[api_name] - else: - dialogue_state[api_name].update(state_update[api_name]) - - #### save latest state - state_to_record = copy.deepcopy(dialogue_state) - state_to_record = 
{r_en_API_MAP.get(k, k): v for k, v in state_to_record.items()} - bitod_preds[dial_id]["turns"][str(turn_id)]["state"] = state_to_record - #### - - elif train_target == 'api': - new_knowledge_text = 'null' - constraints = {} - - api_name = active_api if active_api else 'None' - - do_api_call = predictions[-1][0].strip() - - if do_api_call == 'yes': - # make api call if required - api_name = active_api - # do api call - if api_name in dialogue_state: - constraints = state2api(dialogue_state[api_name]) - - try: - msg = api.call_api( - r_en_API_MAP.get(api_name, api_name), - constraints=[constraints], - ) - except Exception as e: - logger.error(f'Error: {e}') - logger.error(f'Failed API call with api_name: {api_name} and constraints: {constraints}') - msg = [0, 0] - - domain = api_name.split(" ")[0] - - knowledge = defaultdict(dict) - if int(msg[1]) <= 0: - new_knowledge_text = f'( {domain} ) Message = No item available.' - else: - # why does it only choose the first; does the same happen for training data? - knowledge[domain].update(msg[0]) - new_knowledge_text = knowledge2span(knowledge) - - elif do_api_call == 'no': - # do nothing - pass - else: - logger.error( - f'API call should be either yes or no but got {do_api_call}; seems model is still training so we assume it\'s a no' - ) - - #### save latest api results and constraints - bitod_preds[dial_id]["turns"][str(turn_id)]["api"] = new_knowledge_text - bitod_preds[dial_id]["API"][r_en_API_MAP.get(api_name, api_name)] = copy.deepcopy(constraints) - #### - - if train_target == 'response': - #### save latest response - bitod_preds[dial_id]["turns"][str(turn_id)]["response"] = predictions[-1] - #### - - convo_history.append(predictions[-1][0]) - - except KeyboardInterrupt: - break - - with open(f"{int(time.time())}_bitod_preds.json", 'w') as fout: - ujson.dump(bitod_preds, fout, indent=2, ensure_ascii=False) - - if original_order is not None: - # sort back to the original order - original_order, example_ids, predictions, answers, contexts = [ - list(a) for a in tuple(zip(*sorted(list(zip(original_order, example_ids, predictions, answers, contexts))))) - ] - - # TODO calculate and return loss - loss = None - output = GenerationOutput(loss=loss) - - if output_predictions_only: - output.predictions = predictions - else: - output.example_ids, output.predictions, output.answers, output.contexts = example_ids, predictions, answers, contexts - - return output - - def generate_with_seq2seq_model_for_dialogue( model, data_iterator, @@ -384,6 +156,8 @@ def generate_with_seq2seq_model_for_dialogue( required_slots = {API_MAP[k]: v for k, v in required_slots.items()} api_names = list(required_slots.keys()) + special_tokens = numericalizer._tokenizer.all_special_tokens + for k, turn in enumerate(progress_bar(data_iterator, desc='Generating', disable=disable_progbar)): batch_size = len(turn.example_id) assert batch_size == 1 @@ -404,7 +178,6 @@ def generate_with_seq2seq_model_for_dialogue( active_api = None bitod_preds[dial_id] = {"turns": defaultdict(dict), "API": defaultdict(dict)} - special_tokens = numericalizer._tokenizer.all_special_tokens batch_tokens = numericalizer.convert_ids_to_tokens(turn.context.value.data, skip_special_tokens=False) batch_context = [] # remove only beginning and trailing special tokens @@ -434,25 +207,6 @@ def generate_with_seq2seq_model_for_dialogue( # we always use gold history following common practice - # if first_turn: - # # first turn is always dst - # assert train_target == 'dst' - # numericalized_turn = 
NumericalizedExamples( - # example_id=[turn.example_id[0]], - # context=SequentialField( - # value=turn.context.value[[0]], - # length=turn.context.length[[0]], - # limited=turn.context.limited[[0]], - # feature=None, - # ), - # answer=SequentialField( - # value=turn.answer.value[[0]], - # length=turn.answer.value[[0]], - # limited=turn.answer.value[[0]], - # feature=None, - # ), - # ) - # else: if train_target == 'dst': input_text = replace_match(contexts[-1], state_re, new_state_text) @@ -486,12 +240,7 @@ def generate_with_seq2seq_model_for_dialogue( limited=torch.tensor([tokenized_contexts.limited], device=device), feature=None, ), - answer=SequentialField( - value=turn.answer.value[[0]], - length=turn.answer.value[[0]], - limited=turn.answer.value[[0]], - feature=None, - ), + answer=SequentialField(value=None, length=None, limited=None, feature=None), ) generated = model.generate( @@ -523,7 +272,7 @@ def generate_with_seq2seq_model_for_dialogue( if train_target == 'dst': # update dialogue_state lev = predictions[-1][0].strip() - state_update = span2dict(lev, api_names) + state_update = span2state(lev, api_names) for api_name in state_update: active_api = api_name if api_name not in dialogue_state: @@ -546,11 +295,13 @@ def generate_with_seq2seq_model_for_dialogue( do_api_call = predictions[-1][0].strip() if do_api_call == 'yes': - # make api call if required + # make api call api_name = active_api # do api call if api_name in dialogue_state: - constraints = state2api(dialogue_state[api_name]) + constraints = state2constraints(dialogue_state[api_name]) + domain = api_name.split(" ")[0] + knowledge = defaultdict(dict) try: msg = api.call_api( @@ -560,17 +311,19 @@ def generate_with_seq2seq_model_for_dialogue( except Exception as e: logger.error(f'Error: {e}') logger.error( - f'Failed API call with api_name: {api_name} and constraints: {constraints} for turn: {dial_id}/{turn_id}' + f'Failed API call with api_name: {api_name}, constraints: {constraints}, processed_query: {msg[2]}, for turn: {dial_id}/{turn_id}' ) - msg = [0, 0] + msg = [0, 0, 0] - domain = api_name.split(" ")[0] - - knowledge = defaultdict(dict) if int(msg[1]) <= 0: logger.warning( - f'Message = No item available for api_name: {api_name} and constraints: {constraints} for turn: {dial_id}/{turn_id}' + f'Message = No item available for api_name: {api_name}, constraints: {constraints}, processed_query: {msg[2]}, for turn: {dial_id}/{turn_id}' ) + gold_dial_state = span2state(state_re.search(contexts[-1]).group(1).strip(), api_names) + logger.warning( + f'state_diff: {list(dictdiffer.diff(dialogue_state[api_name], gold_dial_state[api_name]))}' + ) + new_knowledge_text = f'( {domain} ) Message = No item available.' else: # why does it only choose the first; does the same happen for training data? 
diff --git a/setup.py b/setup.py index f469c370..ee4cc346 100644 --- a/setup.py +++ b/setup.py @@ -58,6 +58,7 @@ 'datasets==1.12.1', 'seqeval==1.2.2', 'transformers==4.10.2', + 'dictdiffer~=0.9', 'sentencepiece==0.1.*', 'sentence-transformers==2.0.0', 'mosestokenizer~=1.1', From 99e7c3cf7fb2aaa3d3d10151a0914e83b594dfc2 Mon Sep 17 00:00:00 2001 From: mehrad Date: Mon, 30 Aug 2021 14:58:04 -0700 Subject: [PATCH 021/102] validate: more cleanups --- genienlp/validate.py | 41 ++++++++++++++++++++--------------------- 1 file changed, 20 insertions(+), 21 deletions(-) diff --git a/genienlp/validate.py b/genienlp/validate.py index 902f4a01..81477bae 100644 --- a/genienlp/validate.py +++ b/genienlp/validate.py @@ -148,6 +148,9 @@ def generate_with_seq2seq_model_for_dialogue( answers = [] contexts = [] + # TODO: handle multiple responses + hyperparameter_idx = 0 + cur_dial_id = '' device = model.device @@ -178,8 +181,9 @@ def generate_with_seq2seq_model_for_dialogue( active_api = None bitod_preds[dial_id] = {"turns": defaultdict(dict), "API": defaultdict(dict)} - batch_tokens = numericalizer.convert_ids_to_tokens(turn.context.value.data, skip_special_tokens=False) batch_context = [] + batch_tokens = numericalizer.convert_ids_to_tokens(turn.context.value.data, skip_special_tokens=False) + # remove only beginning and trailing special tokens # otherwise the sep_token added between context and question will be lost for text in batch_tokens: @@ -202,14 +206,10 @@ def generate_with_seq2seq_model_for_dialogue( ] answers += batch_answer - # iterate through turns - hyperparameter_idx = 0 - - # we always use gold history following common practice - if train_target == 'dst': input_text = replace_match(contexts[-1], state_re, new_state_text) + ## we always use gold history following common practice ## if you want to use predicted response instead of gold uncomment the following # last_sys_pred = predictions[-1][0].strip() # input_text = replace_match(input_text, last_system_re, last_sys_pred) @@ -257,6 +257,7 @@ def generate_with_seq2seq_model_for_dialogue( no_repeat_ngram_size=args.no_repeat_ngram_size[hyperparameter_idx], do_sample=args.temperature[hyperparameter_idx] != 0, ) + partial_batch_prediction_ids = generated.sequences partial_batch_prediction = numericalizer.reverse(partial_batch_prediction_ids, 'answer')[0] @@ -288,36 +289,31 @@ def generate_with_seq2seq_model_for_dialogue( elif train_target == 'api': new_knowledge_text = 'null' - constraints = {} - - api_name = active_api if active_api else 'None' - do_api_call = predictions[-1][0].strip() if do_api_call == 'yes': # make api call api_name = active_api - # do api call + if api_name in dialogue_state: constraints = state2constraints(dialogue_state[api_name]) domain = api_name.split(" ")[0] knowledge = defaultdict(dict) try: - msg = api.call_api( - r_en_API_MAP.get(api_name, api_name), - constraints=[constraints], - ) + msg = api.call_api(r_en_API_MAP.get(api_name, api_name), constraints=[constraints]) except Exception as e: logger.error(f'Error: {e}') logger.error( - f'Failed API call with api_name: {api_name}, constraints: {constraints}, processed_query: {msg[2]}, for turn: {dial_id}/{turn_id}' + f'Failed API call with api_name: {api_name}, constraints: {constraints},' + f' processed_query: {msg[2]}, for turn: {dial_id}/{turn_id}' ) msg = [0, 0, 0] if int(msg[1]) <= 0: logger.warning( - f'Message = No item available for api_name: {api_name}, constraints: {constraints}, processed_query: {msg[2]}, for turn: {dial_id}/{turn_id}' + f'Message = No item 
available for api_name: {api_name}, constraints: {constraints},' + f' processed_query: {msg[2]}, for turn: {dial_id}/{turn_id}' ) gold_dial_state = span2state(state_re.search(contexts[-1]).group(1).strip(), api_names) logger.warning( @@ -326,21 +322,24 @@ def generate_with_seq2seq_model_for_dialogue( new_knowledge_text = f'( {domain} ) Message = No item available.' else: - # why does it only choose the first; does the same happen for training data? + # always choose highest ranking results (having deterministic api results) knowledge[domain].update(msg[0]) new_knowledge_text = knowledge2span(knowledge) + #### save latest api constraints + bitod_preds[dial_id]["API"][r_en_API_MAP.get(api_name, api_name)] = copy.deepcopy(constraints) + #### + elif do_api_call == 'no': # do nothing pass else: logger.error( - f'API call should be either yes or no but got {do_api_call}; seems model is still training so we assume it\'s a no' + f'API call should be either yes or no but got {do_api_call}. Seems model is still training so we assume it\'s a no' ) - #### save latest api results and constraints + #### save latest api results bitod_preds[dial_id]["turns"][str(turn_id)]["api"] = new_knowledge_text - bitod_preds[dial_id]["API"][r_en_API_MAP.get(api_name, api_name)] = copy.deepcopy(constraints) #### if train_target == 'response': From 5fa59402100f0909c2e4a8bfcd4a27f061c5c422 Mon Sep 17 00:00:00 2001 From: mehrad Date: Mon, 30 Aug 2021 17:19:34 -0700 Subject: [PATCH 022/102] dial_validate: update --- genienlp/dial_validate.py | 65 +++++++++++++++++---------------------- 1 file changed, 29 insertions(+), 36 deletions(-) diff --git a/genienlp/dial_validate.py b/genienlp/dial_validate.py index 17f1b64f..249e74ac 100644 --- a/genienlp/dial_validate.py +++ b/genienlp/dial_validate.py @@ -26,6 +26,7 @@ def generate_with_seq2seq_model_for_dialogue_interactive( original_order=None, disable_progbar=True, ) -> GenerationOutput: + from termcolor import colored bitod_preds = dict() @@ -51,28 +52,28 @@ def generate_with_seq2seq_model_for_dialogue_interactive( convo_history = [] convo_window = 3 + hyperparameter_idx = 0 + train_target = 'response' next_target = {'dst': 'api', 'api': 'response', 'response': 'dst'} while True: try: - batch_prediction = [] # becomes dst for first turn train_target = next_target[train_target] - hyperparameter_idx = 0 - if train_target == 'dst': if convo_history: - print(convo_history[-1]) + print(colored(f'SYSTEM: {convo_history[-1]}', 'red', attrs=['bold'])) else: - print('Hello! What are you looking for today?') + print(colored('SYSTEM: Hello! What are you looking for today?', 'red', attrs=['bold'])) + # Hello, I am looking for a restaurant with Vegan Options. 
# construct new input - raw_user_input = input('USER: ') + raw_user_input = input(colored('USER: ', 'green', attrs=['bold'])) if raw_user_input == 'RESET': generate_with_seq2seq_model_for_dialogue_interactive( model, @@ -86,6 +87,8 @@ def generate_with_seq2seq_model_for_dialogue_interactive( break elif raw_user_input == 'END': sys.exit(0) + elif raw_user_input == 'STATE': + print(f'dialogue state: {dialogue_state}') raw_user_input = 'USER: ' + raw_user_input.strip() @@ -116,12 +119,7 @@ def generate_with_seq2seq_model_for_dialogue_interactive( limited=torch.tensor([tokenized_contexts.limited], device=device), feature=None, ), - answer=SequentialField( - value=None, - length=None, - limited=None, - feature=None, - ), + answer=SequentialField(value=None, length=None, limited=None, feature=None), ) generated = model.generate( @@ -170,50 +168,45 @@ def generate_with_seq2seq_model_for_dialogue_interactive( elif train_target == 'api': new_knowledge_text = 'null' - constraints = {} - - api_name = active_api if active_api else 'None' - do_api_call = predictions[-1][0].strip() if do_api_call == 'yes': - # make api call if required + # make api call api_name = active_api - # do api call + if api_name in dialogue_state: constraints = state2constraints(dialogue_state[api_name]) + domain = api_name.split(" ")[0] + knowledge = defaultdict(dict) try: - msg = api.call_api( - r_en_API_MAP.get(api_name, api_name), - constraints=[constraints], - ) + msg = api.call_api(r_en_API_MAP.get(api_name, api_name), constraints=[constraints]) except Exception as e: logger.error(f'Error: {e}') - logger.error(f'Failed API call with api_name: {api_name} and constraints: {constraints}') - msg = [0, 0] - - domain = api_name.split(" ")[0] + logger.error( + f'Failed API call with api_name: {api_name}, constraints: {constraints},' + f' processed_query: {msg[2]}, for turn: {dial_id}/{turn_id}' + ) + msg = [0, 0, 0] - knowledge = defaultdict(dict) if int(msg[1]) <= 0: + logger.warning( + f'Message = No item available for api_name: {api_name}, constraints: {constraints},' + f' processed_query: {msg[2]}, for turn: {dial_id}/{turn_id}' + ) + new_knowledge_text = f'( {domain} ) Message = No item available.' else: - # why does it only choose the first; does the same happen for training data? 
+ # always choose highest ranking results (having deterministic api results) knowledge[domain].update(msg[0]) new_knowledge_text = knowledge2span(knowledge) - elif do_api_call == 'no': - # do nothing - pass - else: - logger.error( - f'API call should be either yes or no but got {do_api_call}; seems model is still training so we assume it\'s a no' - ) + #### save latest api constraints + bitod_preds[dial_id]["API"][r_en_API_MAP.get(api_name, api_name)] = copy.deepcopy(constraints) + #### #### save latest api results and constraints bitod_preds[dial_id]["turns"][str(turn_id)]["api"] = new_knowledge_text - bitod_preds[dial_id]["API"][r_en_API_MAP.get(api_name, api_name)] = copy.deepcopy(constraints) #### if train_target == 'response': From f4d00df9bab0bc69ac009ea6c3d25eb4dc51228c Mon Sep 17 00:00:00 2001 From: mehrad Date: Mon, 30 Aug 2021 18:48:42 -0700 Subject: [PATCH 023/102] bitod: update special tokens for v7 --- genienlp/tasks/generic_task.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/genienlp/tasks/generic_task.py b/genienlp/tasks/generic_task.py index eebdf7ec..9b3fb1a5 100644 --- a/genienlp/tasks/generic_task.py +++ b/genienlp/tasks/generic_task.py @@ -423,7 +423,8 @@ def __init__(self, name, args): 'Response:', } special_tokens_v5 = {'AGENT_ACTS:'} - self.special_tokens = special_tokens_v1 | special_tokens_v2 | special_tokens_v5 + special_tokens_v7 = {'ACTS:'} + self.special_tokens = special_tokens_v1 | special_tokens_v2 | special_tokens_v5 | special_tokens_v7 self._metrics = 'casedbleu' def utterance_field(self): From 3491060b8e76a5f0d0b76f40eb2947feb8cd356c Mon Sep 17 00:00:00 2001 From: mehrad Date: Wed, 1 Sep 2021 14:22:29 -0700 Subject: [PATCH 024/102] align_and_replace: fix bug for marian models for single language marian models src_lang and tgt_lang will be set to None in "adjust_language_code" --- genienlp/model_utils/translation.py | 19 ++++++++++++------- genienlp/models/transformer_seq2seq.py | 3 +++ genienlp/tasks/almond_task.py | 9 ++++++++- genienlp/validate.py | 7 ++++++- 4 files changed, 29 insertions(+), 9 deletions(-) diff --git a/genienlp/model_utils/translation.py b/genienlp/model_utils/translation.py index e87a8699..a212eb47 100644 --- a/genienlp/model_utils/translation.py +++ b/genienlp/model_utils/translation.py @@ -85,7 +85,9 @@ def return_token_word_mapping(tokens, tokenizer): return token2word_mapping, word2token_span_mapping -def align_and_replace(src_tokens, tgt_tokens, tokenizer, sample_layer_attention_pooled, src_spans, remove_output_quotation): +def align_and_replace( + src_tokens, tgt_tokens, tokenizer, tgt_lang, sample_layer_attention_pooled, src_spans, remove_output_quotation +): src_quotation_symbol = '"' # M2M100Tokenizer has missing tokens in its fixed vocabulary and encodes them as unknown (https://github.com/pytorch/fairseq/issues/3463) @@ -121,12 +123,15 @@ def align_and_replace(src_tokens, tgt_tokens, tokenizer, sample_layer_attention_ expanded_matches = [cur_match] # translation turned digit into words - if ( - len(cur_match) == 1 - and cur_match[0].isdigit() - and (tokenizer.tgt_lang in CONVERTER_CLASSES or tokenizer.tgt_lang[:2] in CONVERTER_CLASSES) - ): - expanded_matches.append([num2words(cur_match[0], lang=tokenizer.tgt_lang, to='cardinal')]) + if len(cur_match) == 1 and cur_match[0].isdigit(): + converter = None + if tgt_lang in CONVERTER_CLASSES: + converter = CONVERTER_CLASSES[tgt_lang] + elif tgt_lang[:2] in CONVERTER_CLASSES: + converter = CONVERTER_CLASSES[tgt_lang[:2]] + + if converter and 
hasattr(converter, 'str_to_number'):
+ expanded_matches.append([num2words(cur_match[0], lang=tgt_lang, to='cardinal')])

 for match in expanded_matches:
 count, beg_indices = count_substring(tgt_words, match)
diff --git a/genienlp/models/transformer_seq2seq.py b/genienlp/models/transformer_seq2seq.py
index 7af8b5b3..056c49ef 100644
--- a/genienlp/models/transformer_seq2seq.py
+++ b/genienlp/models/transformer_seq2seq.py
@@ -58,6 +58,9 @@ def __init__(self, config=None, *inputs, args, tasks, vocab_sets, save_directory
 if tasks:
 self.set_task_dependent_generation_kwargs(tasks)

+ # only used for Marian models. Adjusted language codes passed to the numericalizer will be None for models trained on single language pairs
+ self.orig_src_lang, self.orig_tgt_lang = kwargs.get('src_lang', 'en'), kwargs.get('tgt_lang', 'en')
+
 self.src_lang, self.tgt_lang = adjust_language_code(
 config, args.pretrained_model, kwargs.get('src_lang', 'en'), kwargs.get('tgt_lang', 'en')
 )
diff --git a/genienlp/tasks/almond_task.py b/genienlp/tasks/almond_task.py
index 37491a9e..4eb8b5ca 100644
--- a/genienlp/tasks/almond_task.py
+++ b/genienlp/tasks/almond_task.py
@@ -381,6 +381,7 @@ def batch_postprocess_prediction_ids(self, batch_example_ids, batch_src_ids, bat

 numericalizer = kwargs.pop('numericalizer')
 cross_attentions = kwargs.pop('cross_attentions')
+ tgt_lang = kwargs.pop('tgt_lang')
 num_outputs = len(batch_tgt_ids) // len(batch_src_ids)

 # TODO _tokenizer should not be private
@@ -448,7 +449,13 @@ def batch_postprocess_prediction_ids(self, batch_example_ids, batch_src_ids, bat
 if self.args.do_alignment:
 src_spans = self.input_spans[example_id]
 text = align_and_replace(
- src_tokens, tgt_tokens, tokenizer, cross_att, src_spans, self.args.align_remove_output_quotation
+ src_tokens,
+ tgt_tokens,
+ tokenizer,
+ tgt_lang,
+ cross_att,
+ src_spans,
+ self.args.align_remove_output_quotation,
 )
 else:
 text = tokenizer.convert_tokens_to_string(tgt_tokens)
diff --git a/genienlp/validate.py b/genienlp/validate.py
index 81477bae..5d4f65c6 100644
--- a/genienlp/validate.py
+++ b/genienlp/validate.py
@@ -398,6 +398,11 @@ def generate_with_seq2seq_model(
 answers = []
 contexts = []

+ if numericalizer._tokenizer.tgt_lang:
+ tgt_lang = numericalizer._tokenizer.tgt_lang
+ else:
+ tgt_lang = model.orig_tgt_lang
+
 for batch in progress_bar(data_iterator, desc='Generating', disable=disable_progbar):
 batch_size = len(batch.example_id)
 batch_prediction = [[] for _ in range(batch_size)]
@@ -454,7 +459,7 @@ def generate_with_seq2seq_model(
 cross_attentions = cross_attentions[-1, ...] 
# postprocess prediction ids - kwargs = {'numericalizer': numericalizer, 'cross_attentions': cross_attentions} + kwargs = {'numericalizer': numericalizer, 'cross_attentions': cross_attentions, 'tgt_lang': tgt_lang} partial_batch_prediction_ids = task.batch_postprocess_prediction_ids( batch_example_ids, batch.context.value.data, partial_batch_prediction_ids, **kwargs ) From 6d2029528b74599939f2f375dd0c871547bf863d Mon Sep 17 00:00:00 2001 From: mehrad Date: Wed, 1 Sep 2021 23:31:17 -0700 Subject: [PATCH 025/102] Update knowledge representation --- genienlp/dial_validate.py | 7 ++++--- genienlp/validate.py | 9 +++++---- 2 files changed, 9 insertions(+), 7 deletions(-) diff --git a/genienlp/dial_validate.py b/genienlp/dial_validate.py index 249e74ac..2b51c720 100644 --- a/genienlp/dial_validate.py +++ b/genienlp/dial_validate.py @@ -89,6 +89,7 @@ def generate_with_seq2seq_model_for_dialogue_interactive( sys.exit(0) elif raw_user_input == 'STATE': print(f'dialogue state: {dialogue_state}') + continue raw_user_input = 'USER: ' + raw_user_input.strip() @@ -176,7 +177,7 @@ def generate_with_seq2seq_model_for_dialogue_interactive( if api_name in dialogue_state: constraints = state2constraints(dialogue_state[api_name]) - domain = api_name.split(" ")[0] + # domain = api_name.split(" ")[0] knowledge = defaultdict(dict) try: @@ -195,10 +196,10 @@ def generate_with_seq2seq_model_for_dialogue_interactive( f' processed_query: {msg[2]}, for turn: {dial_id}/{turn_id}' ) - new_knowledge_text = f'( {domain} ) Message = No item available.' + new_knowledge_text = f'( {api_name} ) Message = No item available.' else: # always choose highest ranking results (having deterministic api results) - knowledge[domain].update(msg[0]) + knowledge[api_name].update(msg[0]) new_knowledge_text = knowledge2span(knowledge) #### save latest api constraints diff --git a/genienlp/validate.py b/genienlp/validate.py index 5d4f65c6..af5b73c8 100644 --- a/genienlp/validate.py +++ b/genienlp/validate.py @@ -37,9 +37,10 @@ import dictdiffer import torch import ujson -from BiToD.evaluate import r_en_API_MAP, span2state, state2constraints +from BiToD.evaluate import r_en_API_MAP from BiToD.knowledgebase import api from BiToD.preprocess import API_MAP, knowledge2span, read_require_slots, state2span +from BiToD.utils import span2state, state2constraints from .data_utils.example import NumericalizedExamples, SequentialField from .data_utils.progbar import progress_bar @@ -297,7 +298,7 @@ def generate_with_seq2seq_model_for_dialogue( if api_name in dialogue_state: constraints = state2constraints(dialogue_state[api_name]) - domain = api_name.split(" ")[0] + # domain = api_name.split(" ")[0] knowledge = defaultdict(dict) try: @@ -320,10 +321,10 @@ def generate_with_seq2seq_model_for_dialogue( f'state_diff: {list(dictdiffer.diff(dialogue_state[api_name], gold_dial_state[api_name]))}' ) - new_knowledge_text = f'( {domain} ) Message = No item available.' + new_knowledge_text = f'( {api_name} ) Message = No item available.' 
else: # always choose highest ranking results (having deterministic api results) - knowledge[domain].update(msg[0]) + knowledge[api_name].update(msg[0]) new_knowledge_text = knowledge2span(knowledge) #### save latest api constraints From f0b3ed068483a13a6a0b3471d4f7c9325118cb58 Mon Sep 17 00:00:00 2001 From: mehrad Date: Thu, 2 Sep 2021 11:24:24 -0700 Subject: [PATCH 026/102] Fix alignment for marian models --- genienlp/tasks/almond_task.py | 12 ++++++------ genienlp/tasks/base_task.py | 2 +- genienlp/validate.py | 21 ++++++++++++++------- 3 files changed, 21 insertions(+), 14 deletions(-) diff --git a/genienlp/tasks/almond_task.py b/genienlp/tasks/almond_task.py index 4eb8b5ca..0193fd8f 100644 --- a/genienlp/tasks/almond_task.py +++ b/genienlp/tasks/almond_task.py @@ -93,7 +93,7 @@ def get_splits(self, root, **kwargs): return AlmondDataset.return_splits(path=os.path.join(root, 'almond'), make_example=self._make_example, **kwargs) def batch_postprocess_prediction_ids(self, batch_example_ids, batch_src_ids, batch_tgt_ids, **kwargs): - return batch_tgt_ids + return batch_tgt_ids, None def postprocess_prediction(self, example_id, prediction): @@ -462,12 +462,12 @@ def batch_postprocess_prediction_ids(self, batch_example_ids, batch_src_ids, bat all_text_outputs.append(text) - with tokenizer.as_target_tokenizer(): - partial_batch_prediction_ids = tokenizer.batch_encode_plus( - all_text_outputs, padding=True, return_tensors='pt' - )['input_ids'] + with tokenizer.as_target_tokenizer(): + partial_batch_prediction_ids = tokenizer.batch_encode_plus(all_text_outputs, padding=True, return_tensors='pt')[ + 'input_ids' + ] - return partial_batch_prediction_ids + return partial_batch_prediction_ids, all_text_outputs def _make_example(self, parts, dir_name=None, **kwargs): # answer has to be provided by default unless doing prediction diff --git a/genienlp/tasks/base_task.py b/genienlp/tasks/base_task.py index 70cb0850..b88ac2cf 100644 --- a/genienlp/tasks/base_task.py +++ b/genienlp/tasks/base_task.py @@ -71,7 +71,7 @@ def get_splits(self, root, **kwargs): return generic_dataset.JSON.splits(root=root, name=self.name, **kwargs) def batch_postprocess_prediction_ids(self, batch_example_ids, batch_src_ids, batch_tgt_ids, **kwargs): - return batch_tgt_ids + return batch_tgt_ids, None def postprocess_prediction(self, example_id, prediction): return prediction diff --git a/genienlp/validate.py b/genienlp/validate.py index af5b73c8..c615b256 100644 --- a/genienlp/validate.py +++ b/genienlp/validate.py @@ -41,6 +41,7 @@ from BiToD.knowledgebase import api from BiToD.preprocess import API_MAP, knowledge2span, read_require_slots, state2span from BiToD.utils import span2state, state2constraints +from transformers import MarianTokenizer from .data_utils.example import NumericalizedExamples, SequentialField from .data_utils.progbar import progress_bar @@ -461,16 +462,22 @@ def generate_with_seq2seq_model( # postprocess prediction ids kwargs = {'numericalizer': numericalizer, 'cross_attentions': cross_attentions, 'tgt_lang': tgt_lang} - partial_batch_prediction_ids = task.batch_postprocess_prediction_ids( + partial_batch_prediction_ids, partial_batch_words = task.batch_postprocess_prediction_ids( batch_example_ids, batch.context.value.data, partial_batch_prediction_ids, **kwargs ) - if output_confidence_features or output_confidence_scores: - partial_batch_confidence_features = model.confidence_features( - batch=batch, predictions=partial_batch_prediction_ids, mc_dropout_num=args.mc_dropout_num - ) - - 
partial_batch_prediction = numericalizer.reverse(partial_batch_prediction_ids, 'answer') + # MarianTokenizer uses two different spm models for encoding source vs target language. + # in almond_translate we postprocess text with alignment which gives code-switched sentences. + # encoding a code-switched sentence with either spm will omit tokens from the other language + # so now we will return both the actual processed text and the encoded version + if isinstance(numericalizer._tokenizer, MarianTokenizer): + partial_batch_prediction = partial_batch_words + else: + if output_confidence_features or output_confidence_scores: + partial_batch_confidence_features = model.confidence_features( + batch=batch, predictions=partial_batch_prediction_ids, mc_dropout_num=args.mc_dropout_num + ) + partial_batch_prediction = numericalizer.reverse(partial_batch_prediction_ids, 'answer') def get_example_index(i): return (i // args.num_outputs[hyperparameter_idx]) % batch_size From 54f6689843352aa22022adbc4ddc8efb52561351 Mon Sep 17 00:00:00 2001 From: mehrad Date: Thu, 2 Sep 2021 11:51:33 -0700 Subject: [PATCH 027/102] Fix text splitting for cjk langs --- genienlp/tasks/almond_task.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/genienlp/tasks/almond_task.py b/genienlp/tasks/almond_task.py index 0193fd8f..86418e7c 100644 --- a/genienlp/tasks/almond_task.py +++ b/genienlp/tasks/almond_task.py @@ -289,18 +289,20 @@ def _make_example(self, parts, dir_name=None, **kwargs): def inside_spans(start, spans): + if not spans: + return False for span in spans: if span[0] <= start < span[1]: return True return False -def return_sentences(text, regex_pattern, src_char_spans, lang): +def return_sentences(text, regex_pattern, src_char_spans, is_cjk=False): sentences = [] cur = 0 for m in re.finditer(regex_pattern, text, flags=re.U): - if src_char_spans and not inside_spans(m.start(0), src_char_spans): - sentences.append(text[cur : m.start(0) + (1 if lang in ['zh', 'ja', 'ko'] else 0)]) + if not inside_spans(m.start(0), src_char_spans): + sentences.append(text[cur : m.start(0) + (1 if is_cjk else 0)]) cur = m.end(0) if cur != len(text): sentences.append(text[cur:]) @@ -309,9 +311,9 @@ def return_sentences(text, regex_pattern, src_char_spans, lang): def split_text_into_sentences(text, lang, src_char_spans): if lang in ['en']: - sentences = return_sentences(text, '(? 
Date: Mon, 13 Sep 2021 14:30:18 -0700 Subject: [PATCH 028/102] util: fix minor bug --- genienlp/util.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/genienlp/util.py b/genienlp/util.py index 83ef67b6..3e27d53b 100644 --- a/genienlp/util.py +++ b/genienlp/util.py @@ -844,7 +844,8 @@ def load_config_json(args): elif r in ( 'do_ned', 'do_alignment', - 'align_preserve_input_quotation' 'align_remove_output_quotation', + 'align_preserve_input_quotation', + 'align_remove_output_quotation', 'use_encoder_loss', 'almond_has_multiple_programs', 'almond_lang_as_question', From 2d23d27c75ac73afcc3807ee4a92dfaa2288c19c Mon Sep 17 00:00:00 2001 From: mehrad Date: Mon, 13 Sep 2021 14:38:11 -0700 Subject: [PATCH 029/102] validate: fix bug if alignment is not used with Marian --- genienlp/validate.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/genienlp/validate.py b/genienlp/validate.py index c615b256..c4d99c57 100644 --- a/genienlp/validate.py +++ b/genienlp/validate.py @@ -445,6 +445,7 @@ def generate_with_seq2seq_model( do_sample=args.temperature[hyperparameter_idx] != 0, # if temperature==0, we do not sample ) partial_batch_prediction_ids = generated.sequences + partial_batch_words = None if model._output_attentions: cross_attentions = generated.cross_attentions @@ -466,11 +467,12 @@ def generate_with_seq2seq_model( batch_example_ids, batch.context.value.data, partial_batch_prediction_ids, **kwargs ) - # MarianTokenizer uses two different spm models for encoding source vs target language. - # in almond_translate we postprocess text with alignment which gives code-switched sentences. + # MarianTokenizer uses two different spm models for encoding source and target languages. + # in almond_translate we postprocess text with alignment which produces code-switched sentences. # encoding a code-switched sentence with either spm will omit tokens from the other language - # so now we will return both the actual processed text and the encoded version - if isinstance(numericalizer._tokenizer, MarianTokenizer): + # so we have to return both the processed and encoded text. + # we need to return encoded text too since confidence_features requires ids + if isinstance(numericalizer._tokenizer, MarianTokenizer) and partial_batch_words: partial_batch_prediction = partial_batch_words else: if output_confidence_features or output_confidence_scores: From 3744789d56e04c9c3cec9085cb2860c64a1802ac Mon Sep 17 00:00:00 2001 From: mehrad Date: Mon, 13 Sep 2021 14:49:36 -0700 Subject: [PATCH 030/102] Add option to return raw translation outputs --- genienlp/predict.py | 41 +++++++++++++++++++++++++++++------------ genienlp/util.py | 33 +++++++++++++++++++++++++++------ genienlp/validate.py | 39 ++++++++++++++++++++++++++++++++++++--- 3 files changed, 92 insertions(+), 21 deletions(-) diff --git a/genienlp/predict.py b/genienlp/predict.py index 5fc07e67..270d7c3e 100644 --- a/genienlp/predict.py +++ b/genienlp/predict.py @@ -219,6 +219,12 @@ def parse_argv(parser): help='split examples with multiple sentences into individual examples', ) + parser.add_argument( + '--translate_return_raw_outputs', + action='store_true', + help='return raw translation as well as ones post-processed with alignment. 
this is useful for STS filtering.', + ) + parser.add_argument('--plot_heatmaps', action='store_true', help='whether to plot cross-attention heatmaps') parser.add_argument( '--do_alignment', @@ -275,7 +281,7 @@ def check_args(args): ) -def prepare_data(args, device, src_lang): +def prepare_data(args, src_lang): datasets = [] paths = [] @@ -391,7 +397,7 @@ def run(args, device): tgt_lang=tgt_lang, ) - val_sets = prepare_data(args, device, src_lang) + val_sets = prepare_data(args, src_lang) model.add_new_vocab_from_data(args.tasks) iters = prepare_data_iterators(args, val_sets, model.numericalizer, device) @@ -412,21 +418,20 @@ def run(args, device): # single language task if language is None or 'multilingual' not in task.name: prediction_file_name = os.path.join(eval_dir, task.name + '.tsv') + raw_prediction_file_name = os.path.join(eval_dir, task.name + '.raw.tsv') results_file_name = os.path.join(eval_dir, task.name + '.results.json') # multi language task else: prediction_file_name = os.path.join(eval_dir, task.name + '_{}.tsv'.format(language)) + raw_prediction_file_name = os.path.join(eval_dir, task.name + '_{}.raw.tsv'.format(language)) results_file_name = os.path.join(eval_dir, task.name + '_{}.results.json'.format(language)) - if os.path.exists(prediction_file_name): - if args.overwrite: - logger.warning(f'{prediction_file_name} already exists -- overwriting **') - else: - raise OSError(f'{prediction_file_name} already exists') - if os.path.exists(results_file_name): - if args.overwrite: - logger.warning(f'{results_file_name} already exists -- overwriting **') - else: - raise OSError(f'{results_file_name} already exists') + + for fname in [prediction_file_name, raw_prediction_file_name, results_file_name]: + if os.path.exists(fname): + if args.overwrite: + logger.warning(f'{fname} already exists -- overwriting **') + else: + raise OSError(f'{fname} already exists') if args.calibrator_paths is not None: confidence_estimators = [] @@ -469,6 +474,18 @@ def run(args, device): line += '\t' + str(score[i]) prediction_file.write(line + '\n') + if args.translate_return_raw_outputs: + with open(raw_prediction_file_name, 'w' + ('' if args.overwrite else '+')) as prediction_file: + for i in range(len(generation_output.example_ids)): + line = ( + generation_output.example_ids[i] + + '\t' + + '\t'.join(generation_output.raw_predictions[i]) + + '\t' + + generation_output.answers[i] + ) # all outputs separated by '\t' + prediction_file.write(line + '\n') + if len(generation_output.answers) > 0: metrics_to_compute = task.metrics metrics_to_compute += args.extra_metrics diff --git a/genienlp/util.py b/genienlp/util.py index 3e27d53b..55eacc17 100644 --- a/genienlp/util.py +++ b/genienlp/util.py @@ -240,6 +240,7 @@ def __init__( loss: Optional[float] = None, example_ids: Optional[List] = None, predictions: Optional[List] = None, + raw_predictions: Optional[List] = None, answers: Optional[List] = None, contexts: Optional[List] = None, confidence_features: Optional[List] = None, @@ -248,6 +249,7 @@ def __init__( self.loss = loss self.example_ids = example_ids self.predictions = predictions + self.raw_predictions = raw_predictions self.answers = answers self.contexts = contexts self.confidence_features = confidence_features @@ -638,30 +640,42 @@ def ned_dump_entity_type_pairs(dataset, path, name, utterance_field): fout.write(ujson.dumps({"sentence": sentence, "aliases": entities, "thingtalk_types": ent_types}) + '\n') -def merge_translated_sentences(example_ids, predictions, answers, contexts, 
confidence_features, src_lang, tgt_lang): +def merge_translated_sentences( + example_ids, predictions, raw_predictions, answers, contexts, confidence_features, src_lang, tgt_lang +): new_example_ids = [] new_predictions = [] + new_raw_predictions = [] new_answers = [] new_contexts = [] new_confidence_features = [] - cur_pred, cur_context, cur_answer = [], [], [] + cur_pred, cur_raw_pred, cur_context, cur_answer = [], [], [], [] i = 0 src_concat_token = '' if src_lang in ['zh', 'ja', 'ko'] else ' ' tgt_concat_token = '' if tgt_lang in ['zh', 'ja', 'ko'] else ' ' while i < len(predictions): - ex_id, pred, ans, ctxt, cf_feat = example_ids[i], predictions[i], answers[i], contexts[i], confidence_features[i] + ex_id, pred, raw_pred, ans, ctxt, cf_feat = ( + example_ids[i], + predictions[i], + raw_predictions[i], + answers[i], + contexts[i], + confidence_features[i], + ) if '@' in ex_id: id_, split_id = ex_id.rsplit('@', 1) cur_id = id_ while id_ == cur_id: cur_pred.append(pred) + cur_raw_pred.append(raw_pred) cur_context.append(ctxt) cur_answer.append(ans) i += 1 if i < len(predictions): - ex_id, pred, ans, ctxt, cf_feat = ( + ex_id, pred, raw_pred, ans, ctxt, cf_feat = ( example_ids[i], predictions[i], + raw_predictions[i], answers[i], contexts[i], confidence_features[i], @@ -677,22 +691,29 @@ def merge_translated_sentences(example_ids, predictions, answers, contexts, conf new_predictions.append( [tgt_concat_token.join([cur_pred[j][0] for j in range(len(cur_pred))]) for i in range(len(cur_pred[0]))] ) + new_raw_predictions.append( + [ + tgt_concat_token.join([cur_raw_pred[j][0] for j in range(len(cur_raw_pred))]) + for i in range(len(cur_raw_pred[0])) + ] + ) new_contexts.append(src_concat_token.join(cur_context)) new_answers.append(src_concat_token.join(cur_answer)) new_confidence_features.append(cf_feat) # reset block - cur_pred, cur_context, cur_answer = [], [], [] + cur_pred, cur_raw_pred, cur_context, cur_answer = [], [], [], [] else: new_example_ids.append(ex_id) new_predictions.append(pred) + new_raw_predictions.append(raw_pred) new_contexts.append(ctxt) new_answers.append(ans) new_confidence_features.append(cf_feat) i += 1 - return new_example_ids, new_predictions, new_answers, new_contexts, new_confidence_features + return new_example_ids, new_predictions, new_raw_predictions, new_answers, new_contexts, new_confidence_features def get_mbart_lang(orig_lang): diff --git a/genienlp/validate.py b/genienlp/validate.py index c4d99c57..6853d6fa 100644 --- a/genienlp/validate.py +++ b/genienlp/validate.py @@ -395,6 +395,7 @@ def generate_with_seq2seq_model( total_loss = 0.0 if model._output_scores else None output_confidence_scores = confidence_estimators is not None predictions = [] + raw_predictions = [] confidence_features = [] example_ids = [] answers = [] @@ -408,6 +409,7 @@ def generate_with_seq2seq_model( for batch in progress_bar(data_iterator, desc='Generating', disable=disable_progbar): batch_size = len(batch.example_id) batch_prediction = [[] for _ in range(batch_size)] + batch_raw_prediction = [[] for _ in range(batch_size)] batch_confidence_features = [[] for _ in range(batch_size)] batch_example_ids = batch.example_id @@ -463,6 +465,10 @@ def generate_with_seq2seq_model( # postprocess prediction ids kwargs = {'numericalizer': numericalizer, 'cross_attentions': cross_attentions, 'tgt_lang': tgt_lang} + + if args.translate_return_raw_outputs: + partial_batch_raw_prediction_ids = partial_batch_prediction_ids + partial_batch_prediction_ids, partial_batch_words = 
task.batch_postprocess_prediction_ids( batch_example_ids, batch.context.value.data, partial_batch_prediction_ids, **kwargs ) @@ -484,6 +490,15 @@ def generate_with_seq2seq_model( def get_example_index(i): return (i // args.num_outputs[hyperparameter_idx]) % batch_size + if args.translate_return_raw_outputs: + partial_batch_raw_prediction = numericalizer.reverse(partial_batch_raw_prediction_ids, 'answer') + for i in range(len(partial_batch_prediction)): + partial_batch_raw_prediction[i] = task.postprocess_prediction( + batch_example_ids[get_example_index(i)], partial_batch_raw_prediction[i] + ) + for i in range(len(partial_batch_prediction)): + batch_raw_prediction[get_example_index(i)].append(partial_batch_raw_prediction[i]) + # post-process predictions for i in range(len(partial_batch_prediction)): partial_batch_prediction[i] = task.postprocess_prediction( @@ -498,24 +513,40 @@ def get_example_index(i): predictions += batch_prediction confidence_features += batch_confidence_features + raw_predictions += batch_raw_prediction if total_loss is not None: total_loss /= len(example_ids) if original_order is not None: # sort back to the original order - original_order, example_ids, predictions, answers, contexts, confidence_features = [ + original_order, example_ids, predictions, raw_predictions, answers, contexts, confidence_features = [ list(a) for a in tuple( - zip(*sorted(list(zip(original_order, example_ids, predictions, answers, contexts, confidence_features)))) + zip( + *sorted( + list( + zip( + original_order, + example_ids, + predictions, + raw_predictions, + answers, + contexts, + confidence_features, + ) + ) + ) + ) ) ] if getattr(args, 'translate_example_split', False): # stitch sentences back together - example_ids, predictions, answers, contexts, confidence_features = merge_translated_sentences( + example_ids, predictions, raw_predictions, answers, contexts, confidence_features = merge_translated_sentences( example_ids, predictions, + raw_predictions, answers, contexts, confidence_features, @@ -540,6 +571,8 @@ def get_example_index(i): for estimator in confidence_estimators: confidence_scores = estimator.estimate(confidence_features) output.confidence_scores.append(confidence_scores) + if args.translate_return_raw_outputs: + output.raw_predictions = raw_predictions return output From 9265b3ecb8797992a68d92aa72944c78e99f1552 Mon Sep 17 00:00:00 2001 From: mehrad Date: Mon, 13 Sep 2021 16:44:25 -0700 Subject: [PATCH 031/102] Use dateparser to reduce need of alignment for dates --- genienlp/arguments.py | 2 +- genienlp/model_utils/translation.py | 30 +++++++++++++++++++---------- genienlp/tasks/almond_task.py | 6 ++++-- genienlp/validate.py | 25 +++++++++++++++++++----- setup.py | 1 + 5 files changed, 46 insertions(+), 18 deletions(-) diff --git a/genienlp/arguments.py b/genienlp/arguments.py index 2dc81512..4e10bfff 100644 --- a/genienlp/arguments.py +++ b/genienlp/arguments.py @@ -673,6 +673,6 @@ def post_parse_train_specific(args): setattr(args, x, os.path.join(args.root, getattr(args, x))) save_args(args, force_overwrite=True) - + args = check_and_update_generation_args(args) return args diff --git a/genienlp/model_utils/translation.py b/genienlp/model_utils/translation.py index a212eb47..4adf103e 100644 --- a/genienlp/model_utils/translation.py +++ b/genienlp/model_utils/translation.py @@ -30,6 +30,7 @@ import numpy as np import torch +from dateparser.conf import Settings from num2words import CONVERTER_CLASSES, num2words from transformers import SPIECE_UNDERLINE, M2M100Tokenizer 
@@ -86,7 +87,14 @@ def return_token_word_mapping(tokens, tokenizer): def align_and_replace( - src_tokens, tgt_tokens, tokenizer, tgt_lang, sample_layer_attention_pooled, src_spans, remove_output_quotation + src_tokens, + tgt_tokens, + sample_layer_attention_pooled, + src_spans, + tgt_lang, + tokenizer, + remove_output_quotation, + date_parser=None, ): src_quotation_symbol = '"' @@ -123,15 +131,17 @@ def align_and_replace( expanded_matches = [cur_match] # translation turned digit into words - if len(cur_match) == 1 and cur_match[0].isdigit(): - converter = None - if tgt_lang in CONVERTER_CLASSES: - converter = CONVERTER_CLASSES[tgt_lang] - elif tgt_lang[:2] in CONVERTER_CLASSES: - converter = CONVERTER_CLASSES[tgt_lang[:2]] - - if converter and hasattr(converter, 'str_to_number'): - expanded_matches.append([num2words(cur_match[0], lang=tgt_lang, to='cardinal')]) + if ( + len(cur_match) == 1 + and cur_match[0].isdigit() + and (tgt_lang in CONVERTER_CLASSES or tgt_lang[:2] in CONVERTER_CLASSES) + ): + match = int(cur_match[0]) + expanded_matches.append([num2words(match, lang=tgt_lang, to='cardinal')]) + + # find translation of dates + elif date_parser: + expanded_matches.append(date_parser.translate(' '.join(cur_match), settings=Settings()).split(' ')) for match in expanded_matches: count, beg_indices = count_substring(tgt_words, match) diff --git a/genienlp/tasks/almond_task.py b/genienlp/tasks/almond_task.py index 86418e7c..c825735a 100644 --- a/genienlp/tasks/almond_task.py +++ b/genienlp/tasks/almond_task.py @@ -384,6 +384,7 @@ def batch_postprocess_prediction_ids(self, batch_example_ids, batch_src_ids, bat numericalizer = kwargs.pop('numericalizer') cross_attentions = kwargs.pop('cross_attentions') tgt_lang = kwargs.pop('tgt_lang') + date_parser = kwargs.pop('date_parser') num_outputs = len(batch_tgt_ids) // len(batch_src_ids) # TODO _tokenizer should not be private @@ -453,11 +454,12 @@ def batch_postprocess_prediction_ids(self, batch_example_ids, batch_src_ids, bat text = align_and_replace( src_tokens, tgt_tokens, - tokenizer, - tgt_lang, cross_att, src_spans, + tgt_lang, + tokenizer, self.args.align_remove_output_quotation, + date_parser=date_parser, ) else: text = tokenizer.convert_tokens_to_string(tgt_tokens) diff --git a/genienlp/validate.py b/genienlp/validate.py index 6853d6fa..35c3cabc 100644 --- a/genienlp/validate.py +++ b/genienlp/validate.py @@ -41,6 +41,7 @@ from BiToD.knowledgebase import api from BiToD.preprocess import API_MAP, knowledge2span, read_require_slots, state2span from BiToD.utils import span2state, state2constraints +from dateparser.languages import default_loader from transformers import MarianTokenizer from .data_utils.example import NumericalizedExamples, SequentialField @@ -406,6 +407,15 @@ def generate_with_seq2seq_model( else: tgt_lang = model.orig_tgt_lang + if numericalizer._tokenizer.src_lang: + src_lang = numericalizer._tokenizer.src_lang + else: + src_lang = model.orig_src_lang + + date_parser = default_loader.get_locale(src_lang[:2]) + + translate_return_raw_outputs = getattr(args, 'translate_return_raw_outputs', False) + for batch in progress_bar(data_iterator, desc='Generating', disable=disable_progbar): batch_size = len(batch.example_id) batch_prediction = [[] for _ in range(batch_size)] @@ -464,9 +474,14 @@ def generate_with_seq2seq_model( cross_attentions = cross_attentions[-1, ...] 
# postprocess prediction ids - kwargs = {'numericalizer': numericalizer, 'cross_attentions': cross_attentions, 'tgt_lang': tgt_lang} - - if args.translate_return_raw_outputs: + kwargs = { + 'numericalizer': numericalizer, + 'cross_attentions': cross_attentions, + 'tgt_lang': tgt_lang, + 'date_parser': date_parser, + } + + if translate_return_raw_outputs: partial_batch_raw_prediction_ids = partial_batch_prediction_ids partial_batch_prediction_ids, partial_batch_words = task.batch_postprocess_prediction_ids( @@ -490,7 +505,7 @@ def generate_with_seq2seq_model( def get_example_index(i): return (i // args.num_outputs[hyperparameter_idx]) % batch_size - if args.translate_return_raw_outputs: + if translate_return_raw_outputs: partial_batch_raw_prediction = numericalizer.reverse(partial_batch_raw_prediction_ids, 'answer') for i in range(len(partial_batch_prediction)): partial_batch_raw_prediction[i] = task.postprocess_prediction( @@ -571,7 +586,7 @@ def get_example_index(i): for estimator in confidence_estimators: confidence_scores = estimator.estimate(confidence_features) output.confidence_scores.append(confidence_scores) - if args.translate_return_raw_outputs: + if translate_return_raw_outputs: output.raw_predictions = raw_predictions return output diff --git a/setup.py b/setup.py index ee4cc346..c01b1993 100644 --- a/setup.py +++ b/setup.py @@ -54,6 +54,7 @@ 'bert-score~=0.3', 'requests~=2.22', 'num2words>=0.5.10', + 'dateparser>=1.0.0', 'termcolor~=1.1', 'datasets==1.12.1', 'seqeval==1.2.2', From 0a99ef0884d156bcf47eabd24384af56a955c932 Mon Sep 17 00:00:00 2001 From: mehrad Date: Tue, 14 Sep 2021 15:21:22 -0700 Subject: [PATCH 032/102] bitod: add special tokens for v9 --- genienlp/tasks/generic_task.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/genienlp/tasks/generic_task.py b/genienlp/tasks/generic_task.py index 9b3fb1a5..0c85fa33 100644 --- a/genienlp/tasks/generic_task.py +++ b/genienlp/tasks/generic_task.py @@ -424,7 +424,8 @@ def __init__(self, name, args): } special_tokens_v5 = {'AGENT_ACTS:'} special_tokens_v7 = {'ACTS:'} - self.special_tokens = special_tokens_v1 | special_tokens_v2 | special_tokens_v5 | special_tokens_v7 + special_tokens_v9 = {'USER_ACTS:'} + self.special_tokens = special_tokens_v1 | special_tokens_v2 | special_tokens_v5 | special_tokens_v7 | special_tokens_v9 self._metrics = 'casedbleu' def utterance_field(self): From 4a260ccfe8b83a2e28a8479c447d1f5eddfac290 Mon Sep 17 00:00:00 2001 From: mehrad Date: Thu, 16 Sep 2021 14:35:45 -0700 Subject: [PATCH 033/102] Improve heuristics for arabic numbers --- genienlp/data_utils/almond_utils.py | 5 +++++ genienlp/model_utils/translation.py | 21 +++++++++++++++------ 2 files changed, 20 insertions(+), 6 deletions(-) diff --git a/genienlp/data_utils/almond_utils.py b/genienlp/data_utils/almond_utils.py index 942dd143..62f1623f 100644 --- a/genienlp/data_utils/almond_utils.py +++ b/genienlp/data_utils/almond_utils.py @@ -69,6 +69,11 @@ 'da': 'Danish', } +NUMBER_MAPPING = { + 'en': ('0', '1', '2', '3', '4', '5', '6', '7', '8', '9'), + 'fa': ('Ū°', 'Ū±', 'Ū²', 'Ū³', 'Ū“', 'Ūµ', 'Ū¶', 'Ū·', 'Ūø', 'Ū¹'), +} + CJK_RANGES = [ (ord(u"\u3300"), ord(u"\u33ff")), diff --git a/genienlp/model_utils/translation.py b/genienlp/model_utils/translation.py index 4adf103e..ab05758e 100644 --- a/genienlp/model_utils/translation.py +++ b/genienlp/model_utils/translation.py @@ -34,6 +34,8 @@ from num2words import CONVERTER_CLASSES, num2words from transformers import SPIECE_UNDERLINE, M2M100Tokenizer +from 
genienlp.data_utils.almond_utils import NUMBER_MAPPING
+
 logger = logging.getLogger(__name__)


@@ -131,13 +133,20 @@ def align_and_replace(
 expanded_matches = [cur_match]

 # translation turned digit into words
- if (
- len(cur_match) == 1
- and cur_match[0].isdigit()
- and (tgt_lang in CONVERTER_CLASSES or tgt_lang[:2] in CONVERTER_CLASSES)
- ):
+ if len(cur_match) == 1 and cur_match[0].isdigit():
+ # int converts arabic digits to english
 match = int(cur_match[0])
- expanded_matches.append([num2words(match, lang=tgt_lang, to='cardinal')])
+ if tgt_lang in CONVERTER_CLASSES or tgt_lang[:2] in CONVERTER_CLASSES:
+ expanded_matches.append([num2words(match, lang=tgt_lang, to='cardinal')])
+
+ if any(tgt_lang.startswith(lang) for lang in ['fa', 'ar']):
+ match = str(match)
+ src_numbers = NUMBER_MAPPING['en']
+ tgt_numbers = NUMBER_MAPPING['fa']
+ if match in src_numbers:
+ index = src_numbers.index(match)
+ tgt_number = tgt_numbers[index]
+ expanded_matches.append([tgt_number])

 # find translation of dates
 elif date_parser:

From 5c9456456d1ecea65cb2860c64a1802ac Mon Sep 17 00:00:00 2001
From: mehrad 
Date: Thu, 16 Sep 2021 17:18:35 -0700
Subject: [PATCH 034/102] validate: fix bug in replace_match

Only replace the captured match with the new text, not the whole input.
Otherwise, if the input contains other tokens that exactly match the
pattern, they will be replaced too.
---
 genienlp/validate.py | 18 +++++++++++-------
 1 file changed, 11 insertions(+), 7 deletions(-)

diff --git a/genienlp/validate.py b/genienlp/validate.py
index 35c3cabc..80ca1a50 100644
--- a/genienlp/validate.py
+++ b/genienlp/validate.py
@@ -111,9 +111,13 @@ def generate_with_model(
 )


-def replace_match(input, re_pattern, replacement):
- match = re_pattern.search(input).group(1).strip()
- return input.replace(match, replacement)
+def replace_capturing_group(input, re_pattern, replacement):
+ # replace first captured group in the input with replacement using regex re_pattern
+ whole_match = re_pattern.search(input).group(0).strip()
+ captured_match = re_pattern.search(input).group(1).strip()
+ new_whole_match = whole_match.replace(captured_match, replacement)
+ new_input = re.sub(re_pattern, new_whole_match, input)
+ return new_input


 def generate_with_seq2seq_model_for_dialogue(
@@ -210,7 +214,7 @@ def generate_with_seq2seq_model_for_dialogue(
 answers += batch_answer

 if train_target == 'dst':
- input_text = replace_match(contexts[-1], state_re, new_state_text)
+ input_text = replace_capturing_group(contexts[-1], state_re, new_state_text)

 ## we always use gold history following common practice
 ## if you want to use predicted response instead of gold uncomment the following
@@ -221,14 +225,14 @@ def generate_with_seq2seq_model_for_dialogue(
 new_state_text = state2span(dialogue_state, required_slots)

 # replace state
- input_text = replace_match(contexts[-1], state_re, new_state_text)
+ input_text = replace_capturing_group(contexts[-1], state_re, new_state_text)

 elif train_target == 'response':
 # replace state
- input_text = replace_match(contexts[-1], state_re, new_state_text)
+ input_text = replace_capturing_group(contexts[-1], state_re, new_state_text)

 # replace knowledge
- input_text = replace_match(input_text, knowledge_re, new_knowledge_text)
+ input_text = replace_capturing_group(input_text, knowledge_re, new_knowledge_text)

 else:
 raise ValueError(f'Invalid train_target: {train_target}')

From 4eda762d97acbe363921e9a3bf4b5939a828cfd7 Mon Sep 17 00:00:00 2001
From: mehrad 
Date: Fri, 17 Sep 2021 14:11:08 -0700
Subject: [PATCH 035/102] Move dialogue 
loop code from server to a separate file --- genienlp/__main__.py | 16 +++- genienlp/dial_validate.py | 82 ++++++----------- genienlp/run_dialogue_loop.py | 163 ++++++++++++++++++++++++++++++++++ 3 files changed, 203 insertions(+), 58 deletions(-) create mode 100644 genienlp/run_dialogue_loop.py diff --git a/genienlp/__main__.py b/genienlp/__main__.py index aef196dc..60e38d7a 100755 --- a/genienlp/__main__.py +++ b/genienlp/__main__.py @@ -30,7 +30,19 @@ import argparse -from . import arguments, cache_embeddings, calibrate, export, kfserver, predict, run_bootleg, server, train, write_kf_metrics +from . import ( + arguments, + cache_embeddings, + calibrate, + export, + kfserver, + predict, + run_bootleg, + run_dialogue_loop, + server, + train, + write_kf_metrics, +) from .ned.scripts import analyze_bootleg_results, oracle_vs_bootleg from .paraphrase import run_generation, run_lm_finetuning from .paraphrase.scripts import clean_paraphrasing_dataset, dialog_to_tsv, split_dataset, transform_dataset @@ -94,6 +106,8 @@ # kf commands 'kfserver': ('Export KFServing interface to predict', server.parse_argv, kfserver.main), 'write-kf-metrics': ('Write KF evaluation metrics', write_kf_metrics.parse_argv, write_kf_metrics.main), + # bitod + 'run-dialogue-loop': ('Interact with dialogue agent', run_dialogue_loop.parse_argv, run_dialogue_loop.main), } diff --git a/genienlp/dial_validate.py b/genienlp/dial_validate.py index 2b51c720..a032357b 100644 --- a/genienlp/dial_validate.py +++ b/genienlp/dial_validate.py @@ -10,32 +10,26 @@ from BiToD.knowledgebase import api from BiToD.knowledgebase.en_zh_mappings import API_MAP from BiToD.utils import knowledge2span, read_require_slots, span2state, state2constraints, state2span +from termcolor import colored from genienlp.data_utils.example import NumericalizedExamples, SequentialField -from genienlp.util import GenerationOutput logger = logging.getLogger(__name__) -def generate_with_seq2seq_model_for_dialogue_interactive( - model, - numericalizer, - task, - args, - output_predictions_only=False, - original_order=None, - disable_progbar=True, -) -> GenerationOutput: - from termcolor import colored +def generate_with_seq2seq_model_for_dialogue_interactive(e2e_model, nlg_model, e2e_task, nlg_task): bitod_preds = dict() predictions = [] - example_ids = [] - answers = [] - contexts = [] - device = model.device + e2e_numericalizer = e2e_model.numericalizer + # nlg_numericalizer = nlg_model.numericalizer + + e2e_args = e2e_model.args + # nlg_args = nlg_model.args + + device = e2e_model.device required_slots = read_require_slots() required_slots = {API_MAP[k]: v for k, v in required_slots.items()} @@ -71,19 +65,10 @@ def generate_with_seq2seq_model_for_dialogue_interactive( else: print(colored('SYSTEM: Hello! What are you looking for today?', 'red', attrs=['bold'])) - # Hello, I am looking for a restaurant with Vegan Options. 
# construct new input raw_user_input = input(colored('USER: ', 'green', attrs=['bold'])) if raw_user_input == 'RESET': - generate_with_seq2seq_model_for_dialogue_interactive( - model, - numericalizer, - task, - args, - output_predictions_only=False, - original_order=None, - disable_progbar=True, - ) + generate_with_seq2seq_model_for_dialogue_interactive(e2e_model, nlg_model, e2e_task, nlg_task) break elif raw_user_input == 'END': sys.exit(0) @@ -110,7 +95,7 @@ def generate_with_seq2seq_model_for_dialogue_interactive( else: raise ValueError(f'Invalid train_target: {train_target}') - tokenized_contexts = numericalizer.encode_batch([input_text], field_name='context', features=None)[0] + tokenized_contexts = e2e_numericalizer.encode_batch([input_text], field_name='context', features=None)[0] numericalized_turn = NumericalizedExamples( example_id=[str(turn_id)], @@ -123,27 +108,27 @@ def generate_with_seq2seq_model_for_dialogue_interactive( answer=SequentialField(value=None, length=None, limited=None, feature=None), ) - generated = model.generate( + generated = e2e_model.generate( numericalized_turn, - max_output_length=args.max_output_length, - num_outputs=args.num_outputs[hyperparameter_idx], - temperature=args.temperature[hyperparameter_idx] if args.temperature[hyperparameter_idx] > 0 else 1.0, - repetition_penalty=args.repetition_penalty[hyperparameter_idx], - top_k=args.top_k[hyperparameter_idx], - top_p=args.top_p[hyperparameter_idx], - num_beams=args.num_beams[hyperparameter_idx], - num_beam_groups=args.num_beam_groups[hyperparameter_idx], - diversity_penalty=args.diversity_penalty[hyperparameter_idx], - no_repeat_ngram_size=args.no_repeat_ngram_size[hyperparameter_idx], - do_sample=args.temperature[hyperparameter_idx] != 0, + max_output_length=e2e_args.max_output_length, + num_outputs=e2e_args.num_outputs[hyperparameter_idx], + temperature=e2e_args.temperature[hyperparameter_idx] if e2e_args.temperature[hyperparameter_idx] > 0 else 1.0, + repetition_penalty=e2e_args.repetition_penalty[hyperparameter_idx], + top_k=e2e_args.top_k[hyperparameter_idx], + top_p=e2e_args.top_p[hyperparameter_idx], + num_beams=e2e_args.num_beams[hyperparameter_idx], + num_beam_groups=e2e_args.num_beam_groups[hyperparameter_idx], + diversity_penalty=e2e_args.diversity_penalty[hyperparameter_idx], + no_repeat_ngram_size=e2e_args.no_repeat_ngram_size[hyperparameter_idx], + do_sample=e2e_args.temperature[hyperparameter_idx] != 0, ) partial_batch_prediction_ids = generated.sequences - partial_batch_prediction = numericalizer.reverse(partial_batch_prediction_ids, 'answer')[0] + partial_batch_prediction = e2e_numericalizer.reverse(partial_batch_prediction_ids, 'answer')[0] # post-process predictions - partial_batch_prediction = task.postprocess_prediction(turn_id, partial_batch_prediction) + partial_batch_prediction = e2e_task.postprocess_prediction(turn_id, partial_batch_prediction) # put them into the right array batch_prediction.append([partial_batch_prediction]) @@ -222,20 +207,3 @@ def generate_with_seq2seq_model_for_dialogue_interactive( with open(f"{int(time.time())}_bitod_preds.json", 'w') as fout: ujson.dump(bitod_preds, fout, indent=2, ensure_ascii=False) - - if original_order is not None: - # sort back to the original order - original_order, example_ids, predictions, answers, contexts = [ - list(a) for a in tuple(zip(*sorted(list(zip(original_order, example_ids, predictions, answers, contexts))))) - ] - - # TODO calculate and return loss - loss = None - output = GenerationOutput(loss=loss) - - if 
output_predictions_only: - output.predictions = predictions - else: - output.example_ids, output.predictions, output.answers, output.contexts = example_ids, predictions, answers, contexts - - return output diff --git a/genienlp/run_dialogue_loop.py b/genienlp/run_dialogue_loop.py new file mode 100644 index 00000000..8f96bfcf --- /dev/null +++ b/genienlp/run_dialogue_loop.py @@ -0,0 +1,163 @@ +# +# Copyright (c) 2018, Salesforce, Inc. +# The Board of Trustees of the Leland Stanford Junior University +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# * Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# * Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# * Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +import copy +import logging +from pprint import pformat + +import torch + +from genienlp.dial_validate import generate_with_seq2seq_model_for_dialogue_interactive + +from . import models +from .arguments import check_and_update_generation_args +from .tasks.registry import get_tasks +from .util import get_devices, load_config_json, set_seed + +logger = logging.getLogger(__name__) + + +def parse_argv(parser): + parser.add_argument('--e2e_path', type=str, required=True) + parser.add_argument('--nlg_path', type=str, required=True) + + parser.add_argument( + '--devices', default=[0], nargs='+', type=int, help='a list of devices that can be used (multi-gpu currently WIP)' + ) + parser.add_argument('--seed', default=123, type=int, help='Random seed.') + parser.add_argument('--embeddings', default='.embeddings', type=str, help='where to save embeddings.') + parser.add_argument( + '--checkpoint_name', default='best.pth', help='Checkpoint file to use (relative to --path, defaults to best.pth)' + ) + + parser.add_argument('--database_dir', type=str, help='Database folder containing all relevant files') + parser.add_argument('--src_locale', default='en', help='locale tag of the input language to parse') + parser.add_argument('--tgt_locale', default='en', help='locale tag of the target language to generate') + parser.add_argument('--inference_name', default='nlp', help='name used by kfserving inference service, alphanumeric only') + + # These are generation hyperparameters. 
Each one can be a list of values, in which case we generate `num_outputs` outputs for each set of hyperparameters.
+ parser.add_argument("--num_outputs", type=int, nargs='+', default=[1], help='number of sequences to output per input')
+ parser.add_argument("--temperature", type=float, nargs='+', default=[0.0], help="temperature of 0 implies greedy sampling")
+ parser.add_argument(
+ "--repetition_penalty",
+ type=float,
+ nargs='+',
+ default=[1.0],
+ help="primarily useful for CTRL model; in that case, use 1.2",
+ )
+ parser.add_argument("--top_k", type=int, nargs='+', default=[0], help='0 disables top-k filtering')
+ parser.add_argument("--top_p", type=float, nargs='+', default=[1.0], help='1.0 disables top-p filtering')
+ parser.add_argument("--num_beams", type=int, nargs='+', default=[1], help='1 disables beam search')
+ parser.add_argument("--num_beam_groups", type=int, nargs='+', default=[1], help='1 disables diverse beam search')
+ parser.add_argument("--diversity_penalty", type=float, nargs='+', default=[0.0], help='0 disables diverse beam search')
+ parser.add_argument(
+ "--no_repeat_ngram_size",
+ type=int,
+ nargs='+',
+ default=[0],
+ help='ngrams of this size cannot be repeated in the output. 0 disables it.',
+ )
+ parser.add_argument('--max_output_length', default=150, type=int, help='maximum output length for generation')
+
+
+class DialogueLoop(object):
+ def __init__(self, e2e_model, nlg_model):
+ self.e2e_model = e2e_model
+ self.nlg_model = nlg_model
+
+ def run(self):
+ e2e_task = list(get_tasks(['bitod'], self.e2e_model.args).values())[0]
+ nlg_task = list(get_tasks(['bitod_nlg'], self.nlg_model.args).values())[0]
+
+ self.e2e_model.add_new_vocab_from_data([e2e_task])
+ self.nlg_model.add_new_vocab_from_data([nlg_task])
+
+ self.e2e_model.set_task_dependent_generation_kwargs([e2e_task])
+ self.nlg_model.set_task_dependent_generation_kwargs([nlg_task])
+
+ with torch.no_grad():
+ generate_with_seq2seq_model_for_dialogue_interactive(
+ self.e2e_model,
+ self.nlg_model,
+ e2e_task,
+ nlg_task,
+ )
+
+
+def init(args):
+
+ set_seed(args)
+
+ devices = get_devices()
+ device = devices[0]  # server only runs on a single device
+
+ e2e_args = copy.deepcopy(args)
+ e2e_args.path = args.e2e_path
+ load_config_json(e2e_args)
+ check_and_update_generation_args(e2e_args)
+
+ E2EModel = getattr(models, e2e_args.model)
+ e2e_model, _ = E2EModel.load(
+ e2e_args.path,
+ model_checkpoint_file=e2e_args.checkpoint_name,
+ args=e2e_args,
+ device=device,
+ src_lang=e2e_args.src_locale,
+ tgt_lang=e2e_args.tgt_locale,
+ )
+ e2e_model.to(device)
+ e2e_model.eval()
+ logger.info(f'Arguments:\n{pformat(vars(e2e_args))}')
+ logger.info(f'Loading from {e2e_args.best_checkpoint}')
+
+ nlg_args = copy.deepcopy(args)
+ nlg_args.path = args.nlg_path
+ load_config_json(nlg_args)
+ check_and_update_generation_args(nlg_args)
+ NLGModel = getattr(models, nlg_args.model)
+ nlg_model, _ = NLGModel.load(
+ nlg_args.path,
+ model_checkpoint_file=nlg_args.checkpoint_name,
+ args=nlg_args,
+ device=device,
+ src_lang=nlg_args.src_locale,
+ tgt_lang=nlg_args.tgt_locale,
+ )
+ nlg_model.to(device)
+ nlg_model.eval()
+ logger.info(f'Arguments:\n{pformat(vars(nlg_args))}')
+ logger.info(f'Loading from {nlg_args.best_checkpoint}')
+
+ return e2e_model, nlg_model
+
+
+def main(args):
+ e2e_model, nlg_model = init(args)
+ loop = DialogueLoop(e2e_model, nlg_model)
+ loop.run()
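A minimal sketch of driving this new loop programmatically rather than through the genienlp CLI; the checkpoint paths are hypothetical placeholders, and only parse_argv and main from the module added above are assumed:

import argparse
from genienlp.run_dialogue_loop import parse_argv, main

parser = argparse.ArgumentParser()
parse_argv(parser)
# hypothetical checkpoint directories; point these at real trained models
args = parser.parse_args(['--e2e_path', 'models/bitod_e2e', '--nlg_path', 'models/bitod_nlg'])
main(args)

From a4166e153575a1d805a339030bf07b10f7b38ebf Mon Sep 17 00:00:00 2001
From: mehrad
Date: Wed, 15 Sep 2021 11:28:16 -0700
Subject: [PATCH 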
036/102] Add bitod_nlg task --- genienlp/dial_validate.py | 87 ++++++++++++++++++++----------- genienlp/tasks/generic_dataset.py | 8 +-- genienlp/tasks/generic_task.py | 37 ++++++++++++- 3 files changed, 97 insertions(+), 35 deletions(-) diff --git a/genienlp/dial_validate.py b/genienlp/dial_validate.py index a032357b..da43fedc 100644 --- a/genienlp/dial_validate.py +++ b/genienlp/dial_validate.py @@ -17,6 +17,42 @@ logger = logging.getLogger(__name__) +def numericalize_example(input_text, numericalizer, turn_id, device): + if isinstance(input_text, str): + input_text = [input_text] + tokenized_contexts = numericalizer.encode_batch(input_text, field_name='context', features=None)[0] + + numericalized_turn = NumericalizedExamples( + example_id=[str(turn_id)], + context=SequentialField( + value=torch.tensor([tokenized_contexts.value], device=device), + length=torch.tensor([tokenized_contexts.length], device=device), + limited=torch.tensor([tokenized_contexts.limited], device=device), + feature=None, + ), + answer=SequentialField(value=None, length=None, limited=None, feature=None), + ) + + return numericalized_turn + + +def generate(model, args, numericalized_turn, hyperparameter_idx): + return model.generate( + numericalized_turn, + max_output_length=args.max_output_length, + num_outputs=args.num_outputs[hyperparameter_idx], + temperature=args.temperature[hyperparameter_idx] if args.temperature[hyperparameter_idx] > 0 else 1.0, + repetition_penalty=args.repetition_penalty[hyperparameter_idx], + top_k=args.top_k[hyperparameter_idx], + top_p=args.top_p[hyperparameter_idx], + num_beams=args.num_beams[hyperparameter_idx], + num_beam_groups=args.num_beam_groups[hyperparameter_idx], + diversity_penalty=args.diversity_penalty[hyperparameter_idx], + no_repeat_ngram_size=args.no_repeat_ngram_size[hyperparameter_idx], + do_sample=args.temperature[hyperparameter_idx] != 0, + ) + + def generate_with_seq2seq_model_for_dialogue_interactive(e2e_model, nlg_model, e2e_task, nlg_task): bitod_preds = dict() @@ -24,10 +60,10 @@ def generate_with_seq2seq_model_for_dialogue_interactive(e2e_model, nlg_model, e predictions = [] e2e_numericalizer = e2e_model.numericalizer - # nlg_numericalizer = nlg_model.numericalizer + nlg_numericalizer = nlg_model.numericalizer e2e_args = e2e_model.args - # nlg_args = nlg_model.args + nlg_args = nlg_model.args device = e2e_model.device @@ -44,6 +80,7 @@ def generate_with_seq2seq_model_for_dialogue_interactive(e2e_model, nlg_model, e bitod_preds[dial_id] = {"turns": defaultdict(dict), "API": defaultdict(dict)} convo_history = [] + nlg_responses = [] convo_window = 3 hyperparameter_idx = 0 @@ -61,7 +98,7 @@ def generate_with_seq2seq_model_for_dialogue_interactive(e2e_model, nlg_model, e if train_target == 'dst': if convo_history: - print(colored(f'SYSTEM: {convo_history[-1]}', 'red', attrs=['bold'])) + print(colored(f'SYSTEM: {nlg_responses[-1]}', 'red', attrs=['bold'])) else: print(colored('SYSTEM: Hello! 
What are you looking for today?', 'red', attrs=['bold'])) @@ -95,33 +132,8 @@ def generate_with_seq2seq_model_for_dialogue_interactive(e2e_model, nlg_model, e else: raise ValueError(f'Invalid train_target: {train_target}') - tokenized_contexts = e2e_numericalizer.encode_batch([input_text], field_name='context', features=None)[0] - - numericalized_turn = NumericalizedExamples( - example_id=[str(turn_id)], - context=SequentialField( - value=torch.tensor([tokenized_contexts.value], device=device), - length=torch.tensor([tokenized_contexts.length], device=device), - limited=torch.tensor([tokenized_contexts.limited], device=device), - feature=None, - ), - answer=SequentialField(value=None, length=None, limited=None, feature=None), - ) - - generated = e2e_model.generate( - numericalized_turn, - max_output_length=e2e_args.max_output_length, - num_outputs=e2e_args.num_outputs[hyperparameter_idx], - temperature=e2e_args.temperature[hyperparameter_idx] if e2e_args.temperature[hyperparameter_idx] > 0 else 1.0, - repetition_penalty=e2e_args.repetition_penalty[hyperparameter_idx], - top_k=e2e_args.top_k[hyperparameter_idx], - top_p=e2e_args.top_p[hyperparameter_idx], - num_beams=e2e_args.num_beams[hyperparameter_idx], - num_beam_groups=e2e_args.num_beam_groups[hyperparameter_idx], - diversity_penalty=e2e_args.diversity_penalty[hyperparameter_idx], - no_repeat_ngram_size=e2e_args.no_repeat_ngram_size[hyperparameter_idx], - do_sample=e2e_args.temperature[hyperparameter_idx] != 0, - ) + numericalized_turn = numericalize_example(input_text, e2e_numericalizer, turn_id, device) + generated = generate(e2e_model, e2e_args, numericalized_turn, hyperparameter_idx) partial_batch_prediction_ids = generated.sequences @@ -196,8 +208,21 @@ def generate_with_seq2seq_model_for_dialogue_interactive(e2e_model, nlg_model, e #### if train_target == 'response': + # turn dialogue acts into actual responses + numericalized_turn = numericalize_example(predictions[-1][0], e2e_numericalizer, turn_id, device) + generated = generate(nlg_model, nlg_args, numericalized_turn, hyperparameter_idx) + + partial_batch_prediction_ids = generated.sequences + + partial_batch_prediction = nlg_numericalizer.reverse(partial_batch_prediction_ids, 'answer')[0] + + # post-process predictions + partial_batch_prediction = nlg_task.postprocess_prediction(turn_id, partial_batch_prediction) + + nlg_responses.append(partial_batch_prediction) + #### save latest response - bitod_preds[dial_id]["turns"][str(turn_id)]["response"] = predictions[-1] + bitod_preds[dial_id]["turns"][str(turn_id)]["response"] = nlg_responses[-1] #### convo_history.append('SYSTEM: ' + predictions[-1][0]) diff --git a/genienlp/tasks/generic_dataset.py b/genienlp/tasks/generic_dataset.py index 0ceab003..8bbd351d 100644 --- a/genienlp/tasks/generic_dataset.py +++ b/genienlp/tasks/generic_dataset.py @@ -1979,14 +1979,16 @@ def __init__(self, path, *, make_example, **kwargs): with open(path) as fin: data = ujson.load(fin)['data'] for turn in data: - examples.append(make_example(turn)) + processed = make_example(turn) + if processed: + examples.append(processed) if subsample is not None and len(examples) >= subsample: break super().__init__(examples, **kwargs) - - if kwargs['e2e_evaluation']: + + if kwargs.get('e2e_evaluation', False): self.eval_sort_key_fn = None self.eval_batch_size_fn = default_batch_fn diff --git a/genienlp/tasks/generic_task.py b/genienlp/tasks/generic_task.py index 0c85fa33..3811b85e 100644 --- a/genienlp/tasks/generic_task.py +++ b/genienlp/tasks/generic_task.py 
@@ -426,7 +426,7 @@ def __init__(self, name, args): special_tokens_v7 = {'ACTS:'} special_tokens_v9 = {'USER_ACTS:'} self.special_tokens = special_tokens_v1 | special_tokens_v2 | special_tokens_v5 | special_tokens_v7 | special_tokens_v9 - self._metrics = 'casedbleu' + self._metrics = ['em', 'casedbleu'] def utterance_field(self): return 'context' @@ -450,3 +450,38 @@ def _make_example(self, turn, **kwargs): def get_splits(self, root, **kwargs): kwargs['e2e_evaluation'] = self.args.bitod_e2e_evaluation return BiTODDataset.return_splits(path=root, make_example=self._make_example, **kwargs) + + +@register_task('bitod_nlg') +class BiTODNLG(BaseTask): + def __init__(self, name, args): + super().__init__(name, args) + self.special_tokens = {} + self._metrics = ['casedbleu'] + + def utterance_field(self): + return 'context' + + def _make_example(self, turn, **kwargs): + if 'response' not in turn: + return None + + dial_id, turn_id, answer, train_target, response = ( + turn['dial_id'], + turn['turn_id'], + turn['output_text'], + turn['train_target'], + turn['response'], + ) + question = '' + + assert train_target == 'response' + + example_id = '/'.join([dial_id, str(turn_id), train_target]) + + return Example.from_raw( + self.name + '/' + str(example_id), answer, question, response, preprocess=self.preprocess_field, lower=False + ) + + def get_splits(self, root, **kwargs): + return BiTODDataset.return_splits(path=root, make_example=self._make_example, **kwargs) From e9d5ad5bd1dbfad8c078fb0f2d254d5fc652a1c3 Mon Sep 17 00:00:00 2001 From: mehrad Date: Thu, 23 Sep 2021 10:49:30 -0700 Subject: [PATCH 037/102] server: remove bitod interactive mode --- genienlp/server.py | 47 +++++++------------------------------------- genienlp/validate.py | 35 +++++++++++---------------------- 2 files changed, 18 insertions(+), 64 deletions(-) diff --git a/genienlp/server.py b/genienlp/server.py index 02764836..9d00a8df 100644 --- a/genienlp/server.py +++ b/genienlp/server.py @@ -61,12 +61,7 @@ def parse_argv(parser): '--checkpoint_name', default='best.pth', help='Checkpoint file to use (relative to --path, defaults to best.pth)' ) parser.add_argument('--port', default=8401, type=int, help='TCP port to listen on') - parser.add_argument( - '--interact_mode', - default='tcp', - choices=['stdin', 'tcp', 'cmdline'], - help='Interact with stdin/stdout, user cmdline, or TCP', - ) + parser.add_argument('--stdin', action='store_true', help='Interact on stdin/stdout instead of TCP') parser.add_argument('--database_dir', type=str, help='Database folder containing all relevant files') parser.add_argument('--src_locale', default='en', help='locale tag of the input language to parse') parser.add_argument('--tgt_locale', default='en', help='locale tag of the target language to generate') @@ -124,7 +119,7 @@ def numericalize_examples(self, ex): # make a single batch with all examples return NumericalizedExamples.collate_batches(all_features, self.numericalizer, device=self.device) - def handle_request(self, request, interactive=False): + def handle_request(self, request): task_name = request['task'] if 'task' in request else 'generic' task = list(get_tasks([task_name], self.args, self._cached_task_names).values())[0] if task_name not in self._cached_task_names: @@ -165,7 +160,6 @@ def handle_request(self, request, interactive=False): self.ned_model.process_examples(examples, None, task.utterance_field) self.model.add_new_vocab_from_data([task]) - self.model.set_task_dependent_generation_kwargs([task]) batch = 
self.numericalize_examples(examples) try: @@ -196,18 +190,6 @@ def handle_request(self, request, interactive=False): for e_idx, estimator_scores in enumerate(output.confidence_scores): instance['score'][self.estimator_filenames[e_idx]] = float(estimator_scores[idx]) response.append(instance) - - elif interactive: - self.args.bitod_e2e_evaluation = True - output = generate_with_model( - self.model, - [batch], - self.numericalizer, - task, - self.args, - output_predictions_only=True, - interactive=True, - ) else: output = generate_with_model( self.model, [batch], self.numericalizer, task, self.args, output_predictions_only=True @@ -230,12 +212,12 @@ def handle_request(self, request, interactive=False): return response - def handle_json_request(self, line: str, interactive=False) -> str: + def handle_json_request(self, line: str) -> str: request = json.loads(line) if 'instances' in request: - return json.dumps({'id': request['id'], 'instances': self.handle_request(request, interactive)}) + '\n' + return json.dumps({'id': request['id'], 'instances': self.handle_request(request)}) + '\n' else: - response = self.handle_request(request, interactive) + response = self.handle_request(request) assert len(response) == 1 response = response[0] response['id'] = request['id'] @@ -277,29 +259,14 @@ def _run_stdin(self): except KeyboardInterrupt: pass - def _run_interactive(self): - try: - while True: - # raw_text = input('USER: ') - raw_text = '' - line = json.dumps( - {"task": "bitod", "id": "0", "context": "USER: " + raw_text.strip(), "question": "", "answer": ""} - ) - sys.stdout.write(self.handle_json_request(line, interactive=True)) - sys.stdout.flush() - except KeyboardInterrupt: - pass - def run(self): log_model_size(logger, self.model, self.args.model) self.model.to(self.device) self.model.eval() - if self.args.interact_mode == 'stdin': + if self.args.stdin: self._run_stdin() - elif self.args.interact_mode == 'cmdline': - self._run_interactive() - elif self.args.interact_mode == 'tcp': + else: self._run_tcp() diff --git a/genienlp/validate.py b/genienlp/validate.py index 80ca1a50..ed43f372 100644 --- a/genienlp/validate.py +++ b/genienlp/validate.py @@ -46,7 +46,6 @@ from .data_utils.example import NumericalizedExamples, SequentialField from .data_utils.progbar import progress_bar -from .dial_validate import generate_with_seq2seq_model_for_dialogue_interactive from .metrics import calculate_and_reduce_metrics from .models import TransformerForSequenceClassification, TransformerForTokenClassification from .util import GenerationOutput, merge_translated_sentences @@ -65,32 +64,20 @@ def generate_with_model( original_order=None, confidence_estimators=None, disable_progbar=True, - interactive=False, eval_dir=None, ): if args.bitod_e2e_evaluation: - if interactive: - return generate_with_seq2seq_model_for_dialogue_interactive( - model, - numericalizer, - task, - args, - output_predictions_only=output_predictions_only, - original_order=original_order, - disable_progbar=disable_progbar, - ) - else: - return generate_with_seq2seq_model_for_dialogue( - model, - data_iterator, - numericalizer, - task, - args, - eval_dir, - output_predictions_only=output_predictions_only, - original_order=original_order, - disable_progbar=disable_progbar, - ) + return generate_with_seq2seq_model_for_dialogue( + model, + data_iterator, + numericalizer, + task, + args, + eval_dir, + output_predictions_only=output_predictions_only, + original_order=original_order, + disable_progbar=disable_progbar, + ) elif isinstance(model, 
(TransformerForTokenClassification, TransformerForSequenceClassification)): return generate_with_classification_model( From 2369fd47dadee5ccf3a1fb7de271824f99f28b12 Mon Sep 17 00:00:00 2001 From: mehrad Date: Thu, 23 Sep 2021 13:25:21 -0700 Subject: [PATCH 038/102] validate: use lang to choose the db for api calling --- genienlp/validate.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/genienlp/validate.py b/genienlp/validate.py index ed43f372..d6f16961 100644 --- a/genienlp/validate.py +++ b/genienlp/validate.py @@ -295,7 +295,11 @@ def generate_with_seq2seq_model_for_dialogue( knowledge = defaultdict(dict) try: - msg = api.call_api(r_en_API_MAP.get(api_name, api_name), constraints=[constraints]) + msg = api.call_api( + r_en_API_MAP.get(api_name, api_name), + constraints=[constraints], + lang=numericalizer._tokenizer.src_lang, + ) except Exception as e: logger.error(f'Error: {e}') logger.error( From a85a67aad76e957ca72368b2cd5d50f6adffcc9a Mon Sep 17 00:00:00 2001 From: mehrad Date: Thu, 23 Sep 2021 15:20:27 -0700 Subject: [PATCH 039/102] dial_loop: add template option for nlg --- genienlp/dial_validate.py | 23 ++++++++------- genienlp/run_dialogue_loop.py | 54 +++++++++++++++++++---------------- 2 files changed, 43 insertions(+), 34 deletions(-) diff --git a/genienlp/dial_validate.py b/genienlp/dial_validate.py index da43fedc..7b30a888 100644 --- a/genienlp/dial_validate.py +++ b/genienlp/dial_validate.py @@ -60,13 +60,13 @@ def generate_with_seq2seq_model_for_dialogue_interactive(e2e_model, nlg_model, e predictions = [] e2e_numericalizer = e2e_model.numericalizer - nlg_numericalizer = nlg_model.numericalizer - e2e_args = e2e_model.args - nlg_args = nlg_model.args - device = e2e_model.device + if e2e_args.nlg_type == 'neural': + nlg_numericalizer = nlg_model.numericalizer + nlg_args = nlg_model.args + required_slots = read_require_slots() required_slots = {API_MAP[k]: v for k, v in required_slots.items()} api_names = list(required_slots.keys()) @@ -209,15 +209,18 @@ def generate_with_seq2seq_model_for_dialogue_interactive(e2e_model, nlg_model, e if train_target == 'response': # turn dialogue acts into actual responses - numericalized_turn = numericalize_example(predictions[-1][0], e2e_numericalizer, turn_id, device) - generated = generate(nlg_model, nlg_args, numericalized_turn, hyperparameter_idx) + if e2e_args.nlg_type == 'neural': + numericalized_turn = numericalize_example(predictions[-1][0], e2e_numericalizer, turn_id, device) + generated = generate(nlg_model, nlg_args, numericalized_turn, hyperparameter_idx) - partial_batch_prediction_ids = generated.sequences + partial_batch_prediction_ids = generated.sequences - partial_batch_prediction = nlg_numericalizer.reverse(partial_batch_prediction_ids, 'answer')[0] + partial_batch_prediction = nlg_numericalizer.reverse(partial_batch_prediction_ids, 'answer')[0] - # post-process predictions - partial_batch_prediction = nlg_task.postprocess_prediction(turn_id, partial_batch_prediction) + # post-process predictions + partial_batch_prediction = nlg_task.postprocess_prediction(turn_id, partial_batch_prediction) + else: + partial_batch_prediction = nlg_model.generate(predictions[-1][0]) nlg_responses.append(partial_batch_prediction) diff --git a/genienlp/run_dialogue_loop.py b/genienlp/run_dialogue_loop.py index 8f96bfcf..613d4276 100644 --- a/genienlp/run_dialogue_loop.py +++ b/genienlp/run_dialogue_loop.py @@ -32,6 +32,7 @@ from pprint import pformat import torch +from BiToD.templates.generate_template_response import 
TemplateResponseGenerator from genienlp.dial_validate import generate_with_seq2seq_model_for_dialogue_interactive @@ -45,7 +46,8 @@ def parse_argv(parser): parser.add_argument('--e2e_path', type=str, required=True) - parser.add_argument('--nlg_path', type=str, required=True) + parser.add_argument('--nlg_path', type=str) + parser.add_argument('--nlg_type', type=str, choices=['neural', 'template'], default='template') parser.add_argument( '--devices', default=[0], nargs='+', type=int, help='a list of devices that can be used (multi-gpu currently WIP)' @@ -93,13 +95,15 @@ def __init__(self, e2e_model, nlg_model): def run(self): e2e_task = list(get_tasks(['bitod'], self.e2e_model.args).values())[0] - nlg_task = list(get_tasks(['bitod_nlg'], self.nlg_model.args).values())[0] - self.e2e_model.add_new_vocab_from_data([e2e_task]) - self.nlg_model.add_new_vocab_from_data([nlg_task]) - self.e2e_model.set_task_dependent_generation_kwargs([e2e_task]) - self.nlg_model.set_task_dependent_generation_kwargs([nlg_task]) + + if self.e2e_model.args.nlg_type == 'neural': + nlg_task = list(get_tasks(['bitod_nlg'], self.nlg_model.args).values())[0] + self.nlg_model.add_new_vocab_from_data([nlg_task]) + self.nlg_model.set_task_dependent_generation_kwargs([nlg_task]) + else: + nlg_task = None with torch.no_grad(): generate_with_seq2seq_model_for_dialogue_interactive( @@ -111,7 +115,6 @@ def run(self): def init(args): - set_seed(args) devices = get_devices() @@ -136,23 +139,26 @@ def init(args): logger.info(f'Arguments:\n{pformat(vars(e2e_args))}') logger.info(f'Loading from {e2e_args.best_checkpoint}') - nlg_args = copy.deepcopy(args) - nlg_args.path = args.nlg_path - load_config_json(nlg_args) - check_and_update_generation_args(nlg_args) - NLGModel = getattr(models, nlg_args.model) - nlg_model, _ = NLGModel.load( - nlg_args.path, - model_checkpoint_file=nlg_args.checkpoint_name, - args=nlg_args, - device=device, - src_lang=nlg_args.src_locale, - tgt_lang=nlg_args.tgt_locale, - ) - nlg_model.to(device) - nlg_model.eval() - logger.info(f'Arguments:\n{pformat(vars(nlg_args))}') - logger.info(f'Loading from {nlg_args.best_checkpoint}') + if args.nlg_type == 'neural': + nlg_args = copy.deepcopy(args) + nlg_args.path = args.nlg_path + load_config_json(nlg_args) + check_and_update_generation_args(nlg_args) + NLGModel = getattr(models, nlg_args.model) + nlg_model, _ = NLGModel.load( + nlg_args.path, + model_checkpoint_file=nlg_args.checkpoint_name, + args=nlg_args, + device=device, + src_lang=nlg_args.src_locale, + tgt_lang=nlg_args.tgt_locale, + ) + nlg_model.to(device) + nlg_model.eval() + logger.info(f'Arguments:\n{pformat(vars(nlg_args))}') + logger.info(f'Loading from {nlg_args.best_checkpoint}') + else: + nlg_model = TemplateResponseGenerator(args.tgt_locale, filename='translated') return e2e_model, nlg_model From 1a7f810bd1257a38acb688bd1949607c7d7d6c12 Mon Sep 17 00:00:00 2001 From: mehrad Date: Fri, 24 Sep 2021 10:36:26 -0700 Subject: [PATCH 040/102] optimizer: set scale_parameter to False for Adafactor --- genienlp/model_utils/optimizer.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/genienlp/model_utils/optimizer.py b/genienlp/model_utils/optimizer.py index 88941c03..b3fc05e8 100644 --- a/genienlp/model_utils/optimizer.py +++ b/genienlp/model_utils/optimizer.py @@ -55,7 +55,9 @@ def init_opt(args, model, logger): elif args.optimizer == 'adamw': opt = AdamW(model.params, lr=args.lr_multiply, weight_decay=args.weight_decay) elif args.optimizer == 'adafactor': - opt = 
Adafactor(model.params, lr=args.lr_multiply, weight_decay=args.weight_decay, relative_step=False) + opt = Adafactor( + model.params, lr=args.lr_multiply, weight_decay=args.weight_decay, relative_step=False, scale_parameter=False + ) elif args.optimizer == 'radam': import radam From ca0032a936f59b5debe9068017c3f5f74484cd84 Mon Sep 17 00:00:00 2001 From: mehrad Date: Fri, 24 Sep 2021 11:15:23 -0700 Subject: [PATCH 041/102] validate: fix minor bug --- genienlp/validate.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/genienlp/validate.py b/genienlp/validate.py index d6f16961..b9dcf8a6 100644 --- a/genienlp/validate.py +++ b/genienlp/validate.py @@ -294,6 +294,7 @@ def generate_with_seq2seq_model_for_dialogue( # domain = api_name.split(" ")[0] knowledge = defaultdict(dict) + msg = [0, 0, 0] try: msg = api.call_api( r_en_API_MAP.get(api_name, api_name), @@ -306,7 +307,6 @@ def generate_with_seq2seq_model_for_dialogue( f'Failed API call with api_name: {api_name}, constraints: {constraints},' f' processed_query: {msg[2]}, for turn: {dial_id}/{turn_id}' ) - msg = [0, 0, 0] if int(msg[1]) <= 0: logger.warning( From 26a031e4272f3e41683d4b8e892b0ce1f3bbb18f Mon Sep 17 00:00:00 2001 From: mehrad Date: Fri, 24 Sep 2021 13:38:04 -0700 Subject: [PATCH 042/102] bitod: add special tokens for v11 --- genienlp/tasks/generic_task.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/genienlp/tasks/generic_task.py b/genienlp/tasks/generic_task.py index 3811b85e..53864e60 100644 --- a/genienlp/tasks/generic_task.py +++ b/genienlp/tasks/generic_task.py @@ -425,7 +425,15 @@ def __init__(self, name, args): special_tokens_v5 = {'AGENT_ACTS:'} special_tokens_v7 = {'ACTS:'} special_tokens_v9 = {'USER_ACTS:'} - self.special_tokens = special_tokens_v1 | special_tokens_v2 | special_tokens_v5 | special_tokens_v7 | special_tokens_v9 + special_tokens_v11 = {'', '', ''} + self.special_tokens = ( + special_tokens_v1 + | special_tokens_v2 + | special_tokens_v5 + | special_tokens_v7 + | special_tokens_v9 + | special_tokens_v11 + ) self._metrics = ['em', 'casedbleu'] def utterance_field(self): From d797532916ea5169756eabf1a785a4f8bd0bf5bc Mon Sep 17 00:00:00 2001 From: mehrad Date: Sun, 26 Sep 2021 15:20:58 -0700 Subject: [PATCH 043/102] bitod: add special tokens for v13 --- genienlp/tasks/generic_task.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/genienlp/tasks/generic_task.py b/genienlp/tasks/generic_task.py index 53864e60..18e2d1e1 100644 --- a/genienlp/tasks/generic_task.py +++ b/genienlp/tasks/generic_task.py @@ -426,6 +426,7 @@ def __init__(self, name, args): special_tokens_v7 = {'ACTS:'} special_tokens_v9 = {'USER_ACTS:'} special_tokens_v11 = {'', '', ''} + special_tokens_v13 = {'PREV_AGENT_ACTS'} self.special_tokens = ( special_tokens_v1 | special_tokens_v2 @@ -433,6 +434,7 @@ def __init__(self, name, args): | special_tokens_v7 | special_tokens_v9 | special_tokens_v11 + | special_tokens_v13 ) self._metrics = ['em', 'casedbleu'] From 267060b3442e7f1ce91d703114152383ff77fc77 Mon Sep 17 00:00:00 2001 From: mehrad Date: Mon, 27 Sep 2021 10:44:56 -0700 Subject: [PATCH 044/102] .pre-commit-config: version updates --- .pre-commit-config.yaml | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 48898281..1eaedd5e 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -8,16 +8,16 @@ repos: - id: trailing-whitespace exclude: ^(tests/dataset/|tests/database/|tests/expected_results/) - 
repo: https://github.com/hadialqattan/pycln
- rev: 0.0.4
+ rev: v1.0.3
hooks:
- id: pycln
args: [--config=pyproject.toml]
- repo: https://github.com/PyCQA/isort
- rev: 5.9.2
+ rev: 5.9.3
hooks:
- id: isort
- repo: https://github.com/psf/black
- rev: 21.7b0
+ rev: 21.9b0
hooks:
- id: black
language_version: python3.8
From ba22f48a72af7c66bb45f926f4150529b1eb49be Mon Sep 17 00:00:00 2001
From: mehrad
Date: Tue, 28 Sep 2021 14:25:28 -0700
Subject: [PATCH 045/102] run_dialogue_loop: choose template name
---
genienlp/run_dialogue_loop.py | 7 +++++--
1 file changed, 5 insertions(+), 2 deletions(-)
diff --git a/genienlp/run_dialogue_loop.py b/genienlp/run_dialogue_loop.py
index 613d4276..631e41e8 100644
--- a/genienlp/run_dialogue_loop.py
+++ b/genienlp/run_dialogue_loop.py
@@ -47,7 +47,9 @@ def parse_argv(parser):
parser.add_argument('--e2e_path', type=str, required=True)
parser.add_argument('--nlg_path', type=str)
- parser.add_argument('--nlg_type', type=str, choices=['neural', 'template'], default='template')
+ parser.add_argument(
+ '--nlg_type', type=str, choices=['neural', 'template-translated', 'template-human'], default='template-translated'
+ )
parser.add_argument(
'--devices', default=[0], nargs='+', type=int, help='a list of devices that can be used (multi-gpu currently WIP)'
@@ -158,7 +160,8 @@ def init(args):
logger.info(f'Arguments:\n{pformat(vars(nlg_args))}')
logger.info(f'Loading from {nlg_args.best_checkpoint}')
else:
- nlg_model = TemplateResponseGenerator(args.tgt_locale, filename='translated')
+ _, filename = args.nlg_type.split('-')
+ nlg_model = TemplateResponseGenerator(args.tgt_locale, filename=filename)
return e2e_model, nlg_model
From 0248986647a8fd5f3372c291ddea5e365ab36a99 Mon Sep 17 00:00:00 2001
From: mehrad
Date: Tue, 28 Sep 2021 14:43:41 -0700
Subject: [PATCH 046/102] dial_validate: start conversation in target language
---
genienlp/dial_validate.py | 11 ++++++++++-
1 file changed, 10 insertions(+), 1 deletion(-)
diff --git a/genienlp/dial_validate.py b/genienlp/dial_validate.py
index 7b30a888..3a8702d4 100644
--- a/genienlp/dial_validate.py
+++ b/genienlp/dial_validate.py
@@ -16,6 +16,12 @@
logger = logging.getLogger(__name__)
+INIT_SYS_MESSAGE = {
+ 'en': 'Hello! How can I help you today?',
+ 'fa': 'Ų³Ł„Ų§Ł…! Ų§Ł…Ų±ŁˆŲ² Ś†Ų·ŁˆŲ± Ł…ŪŒ ŲŖŁˆŲ§Ł†Ł… ŲØŁ‡ Ų“Ł…Ų§ Ś©Ł…Ś© Ś©Ł†Ł…ŲŸ',
+ 'zh': '你儽ļ¼ ęˆ‘ä»Šå¤©čƒ½åø®åˆ°ä½ ä»€ä¹ˆļ¼Ÿ',
+}
+
def numericalize_example(input_text, numericalizer, turn_id, device):
if isinstance(input_text, str):
@@ -63,6 +69,8 @@ def generate_with_seq2seq_model_for_dialogue_interactive(e2e_model, nlg_model, e
e2e_args = e2e_model.args
device = e2e_model.device
+ tgt_lang = e2e_model.tgt_lang
+
if e2e_args.nlg_type == 'neural':
nlg_numericalizer = nlg_model.numericalizer
nlg_args = nlg_model.args
@@ -100,7 +108,8 @@
if convo_history:
print(colored(f'SYSTEM: {nlg_responses[-1]}', 'red', attrs=['bold']))
else:
- print(colored('SYSTEM: Hello! What are you looking for today?', 'red', attrs=['bold']))
+ tgt_lang = tgt_lang[:2]
+ print(colored(f'SYSTEM: {INIT_SYS_MESSAGE[tgt_lang]}', 'red', attrs=['bold']))
From 6fdcc6b5bb4e2e3db36568eb8fd0afb5f8b6f404 Mon Sep 17 00:00:00 2001
From: mehrad
Date: Wed, 1 Dec 2021 13:43:58 -0800
Subject: [PATCH 047/102] Fix when loss is chosen as deca_score
---
genienlp/train.py | 3 ++-
1 file changed, 2 insertions(+), 1 deletion(-)
diff --git a/genienlp/train.py b/genienlp/train.py
index 6c566580..f9847787 100644
--- a/genienlp/train.py
+++ b/genienlp/train.py
@@ -247,7 +247,8 @@ def do_validate(
deca_metric = val_task.metrics[0]
if deca_metric == 'loss':
- deca_score += val_loss
+ # negative since lower loss corresponds to better performance
+ deca_score += -1.0 * val_loss
else:
deca_score += metric_dict[deca_metric]
From 9fadb55b1878227ed46495ed7fc8bedb6e05094a Mon Sep 17 00:00:00 2001
From: mehrad
Date: Thu, 2 Dec 2021 10:34:31 -0800
Subject: [PATCH 048/102] predict: ignore loss in validation metrics
---
genienlp/predict.py | 14 +++++++++++---
1 file changed, 11 insertions(+), 3 deletions(-)
diff --git a/genienlp/predict.py b/genienlp/predict.py
index 3eb214d7..b22fd874 100644
--- a/genienlp/predict.py
+++ b/genienlp/predict.py
@@ -547,10 +547,18 @@ def main(args):
if args.override_valid_metrics:
assert len(args.override_valid_metrics) == len(args.tasks)
+ new_metrics = []
for task, metrics in zip(args.tasks, args.override_valid_metrics):
- # backward compatibility for models validated on sacrebleu (now casedbleu)
- metrics = [m if m != 'sacrebleu' else 'casedbleu' for m in metrics]
- task.metrics = metrics
+ for m in metrics:
+ # remove loss from validation metrics
+ if m == 'loss':
+ continue
+ # backward compatibility for models validated on sacrebleu (now casedbleu)
+ if m == 'sacrebleu':
+ m = 'casedbleu'
+ new_metrics.append(m)
+
+ task.metrics = new_metrics
if len(devices) > 1:
logger.info(f'Independent multi-GPU generation on following devices: {devices}')
From 20aae07dbc23e852b7d2393455673d91e665ddeb Mon Sep 17 00:00:00 2001
From: mehrad
Date: Sat, 4 Dec 2021 16:48:15 -0800
Subject: [PATCH 049/102] min_output_length is tunable now
---
genienlp/arguments.py | 9 ++++++++-
genienlp/dial_validate.py | 1 +
genienlp/models/transformer_lstm.py | 5 +++--
genienlp/models/transformer_seq2seq.py | 3 ++-
genienlp/predict.py | 7 +++++++
genienlp/server.py | 12 +++++++++---
genienlp/util.py | 1 +
genienlp/validate.py | 2 ++
8 files changed, 33 insertions(+), 7 deletions(-)
diff --git a/genienlp/arguments.py b/genienlp/arguments.py
index 4e10bfff..3665ac6c 100644
--- a/genienlp/arguments.py
+++ b/genienlp/arguments.py
@@ -186,6 +186,13 @@ def parse_argv(parser):
parser.add_argument('--eval_set_name', type=str, help='Evaluation dataset name to use during training')
parser.add_argument('--max_output_length', default=150, type=int, help='maximum output length for generation')
+ parser.add_argument(
+ '--min_output_length',
+ default=3,
+ type=int,
+ help='minimum output length for generation; '
+ 'default is 3 for most multilingual models: BOS, language code, and one token; otherwise it is 2',
+ )
parser.add_argument('--max_generative_vocab', default=50000, type=int, help='max vocabulary for the generative softmax')
parser.add_argument('--subsample', default=20000000, type=int, help='subsample the datasets')
parser.add_argument('--preserve_case', action='store_false', dest='lower', help='whether to preserve casing for all text')
@@ -673,6 +680,6 @@ def post_parse_train_specific(args):
setattr(args, x, os.path.join(args.root, getattr(args, x)))
save_args(args, force_overwrite=True)
-
+
args = check_and_update_generation_args(args)
return args
diff --git a/genienlp/dial_validate.py b/genienlp/dial_validate.py
index 3a8702d4..4c8fdd89 100644
--- a/genienlp/dial_validate.py
+++ b/genienlp/dial_validate.py
@@ -46,6 +46,7 @@ def generate(model, args, numericalized_turn, hyperparameter_idx):
return model.generate(
numericalized_turn,
max_output_length=args.max_output_length,
+ min_output_length=args.min_output_length,
num_outputs=args.num_outputs[hyperparameter_idx],
temperature=args.temperature[hyperparameter_idx] if args.temperature[hyperparameter_idx] > 0 else 1.0,
repetition_penalty=args.repetition_penalty[hyperparameter_idx],
diff --git a/genienlp/models/transformer_lstm.py b/genienlp/models/transformer_lstm.py
index a4d6f655..5271c784 100644
--- a/genienlp/models/transformer_lstm.py
+++ b/genienlp/models/transformer_lstm.py
@@ -202,6 +202,7 @@ def generate(
self,
batch,
max_output_length,
+ min_output_length,
num_outputs,
temperature,
repetition_penalty,
@@ -226,7 +227,7 @@ def generate(
input_ids=input_ids,
batch=batch,
max_length=max_output_length,
- min_length=2,  # generate at least one token after BOS
+ min_length=min_output_length,
bos_token_id=self.decoder.init_idx,
pad_token_id=self.numericalizer.decoder_vocab.pad_idx,
early_stopping=True,
@@ -241,7 +242,7 @@
diversity_penalty=diversity_penalty,
no_repeat_ngram_size=no_repeat_ngram_size,
do_sample=do_sample,
- generation_dict={'max_output_length': max_output_length},
+ generation_dict={'max_output_length': max_output_length, 'min_output_length': min_output_length},
encoder_output=encoder_output,
output_scores=self._output_scores,
output_attentions=self._output_attentions,
diff --git a/genienlp/models/transformer_seq2seq.py b/genienlp/models/transformer_seq2seq.py
index cfeb0c41..cd8f0af5 100644
--- a/genienlp/models/transformer_seq2seq.py
+++ b/genienlp/models/transformer_seq2seq.py
@@ -169,6 +169,7 @@ def generate(
self,
batch,
max_output_length,
+ min_output_length,
num_outputs,
temperature,
repetition_penalty,
@@ -187,7 +188,7 @@ def generate(
generated = self.model.generate(
input_ids=input_ids,
max_length=max_output_length,
- min_length=3,  # generate at least one token after BOS and language code
+ min_length=min_output_length,
bos_token_id=self.numericalizer.init_id,
pad_token_id=self.numericalizer.pad_id,
early_stopping=False,
diff --git a/genienlp/predict.py b/genienlp/predict.py
index b22fd874..3d4f6b00 100644
--- a/genienlp/predict.py
+++ b/genienlp/predict.py
@@ -164,6 +164,13 @@ parser.add_argument(
help='ngrams of this size cannot be repeated in the output. 0 disables it.',
)
parser.add_argument('--max_output_length', default=150, type=int, help='maximum output length for generation')
+ parser.add_argument(
+ '--min_output_length',
+ default=3,
+ type=int,
+ help='minimum output length for generation; '
+ 'default is 3 for most multilingual models: BOS, language code, and one token; otherwise it is 2',
+ )
# These are used for confidence calibration
parser.add_argument(
diff --git a/genienlp/server.py b/genienlp/server.py
index c1b220dd..9f700ba0 100644
--- a/genienlp/server.py
+++ b/genienlp/server.py
@@ -61,6 +61,7 @@
'repetition_penalty',
'temperature',
'max_output_length',
+ 'min_output_length',
'src_locale',
'tgt_locale',
'do_alignment',
@@ -110,6 +111,13 @@ def parse_argv(parser):
help='ngrams of this size cannot be repeated in the output. 0 disables it.',
)
parser.add_argument('--max_output_length', default=150, type=int, help='maximum output length for generation')
+ parser.add_argument(
+ '--min_output_length',
+ default=3,
+ type=int,
+ help='minimum output length for generation; '
+ 'default is 3 for most multilingual models: BOS, language code, and one token; otherwise it is 2',
+ )
# for confidence estimation:
parser.add_argument(
@@ -231,9 +239,7 @@ def _predict_batch(self, batch, task, args):
instance['score'][self.estimator_filenames[e_idx]] = float(estimator_scores[idx])
response.append(instance)
else:
- output = generate_with_model(
- self.model, [batch], self.numericalizer, task, args, output_predictions_only=True
- )
+ output = generate_with_model(self.model, [batch], self.numericalizer, task, args, output_predictions_only=True)
if sum(args.num_outputs) > 1:
response = []
for idx, predictions in enumerate(output.predictions):
diff --git a/genienlp/util.py b/genienlp/util.py
index 7f9a5cdc..a273f27d 100644
--- a/genienlp/util.py
+++ b/genienlp/util.py
@@ -841,6 +841,7 @@ def load_config_json(args):
'repetition_penalty',
'temperature',
'max_output_length',
+ 'min_output_length',
'reduce_metrics',
'database_dir',
]
diff --git a/genienlp/validate.py b/genienlp/validate.py
index 523b51f9..e224d357 100644
--- a/genienlp/validate.py
+++ b/genienlp/validate.py
@@ -240,6 +240,7 @@ def generate_with_seq2seq_model_for_dialogue(
generated = model.generate(
numericalized_turn,
max_output_length=args.max_output_length,
+ min_output_length=args.min_output_length,
num_outputs=args.num_outputs[hyperparameter_idx],
temperature=args.temperature[hyperparameter_idx] if args.temperature[hyperparameter_idx] > 0 else 1.0,
repetition_penalty=args.repetition_penalty[hyperparameter_idx],
@@ -440,6 +441,7 @@ def generate_with_seq2seq_model(
generated = model.generate(
batch,
max_output_length=args.max_output_length,
+ min_output_length=args.min_output_length,
num_outputs=args.num_outputs[hyperparameter_idx] if args.temperature[hyperparameter_idx] != 0 else 1,
temperature=args.temperature[hyperparameter_idx] if args.temperature[hyperparameter_idx] > 0 else 1.0,
repetition_penalty=args.repetition_penalty[hyperparameter_idx],
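To make the new knob concrete, here is a self-contained sketch of what it guards, using an illustrative public mBART-50 checkpoint rather than one of the models trained in this repo; genienlp passes min_output_length straight through as min_length in the HuggingFace generate() call shown above:

import torch
from transformers import MBart50TokenizerFast, MBartForConditionalGeneration

tokenizer = MBart50TokenizerFast.from_pretrained('facebook/mbart-large-50', src_lang='en_XX')
model = MBartForConditionalGeneration.from_pretrained('facebook/mbart-large-50')

batch = tokenizer('hello there', return_tensors='pt')
with torch.no_grad():
    # BOS and the forced target-language code already occupy two decoder
    # positions, so min_length=3 is the smallest setting that still forces
    # at least one real content token before EOS
    generated = model.generate(**batch, min_length=3, max_length=20,
                               forced_bos_token_id=tokenizer.lang_code_to_id['en_XX'])
print(tokenizer.batch_decode(generated, skip_special_tokens=True))

From 784e34f404e0f439d4bd427ccb21d16ce03bb93e Mon Sep 17 00:00:00 2001
From: mehrad
Date: Sun, 5 Dec 2021 11:11:02 -0800
Subject: [PATCH 050/102] predict: remove defaults for max_output_length and min_output_length so they're read from train args
---
genienlp/predict.py | 3 +--
1 file changed, 1 insertion(+), 2 deletions(-)
diff --git a/genienlp/predict.py b/genienlp/predict.py
index 3d4f6b00..e44a2c1f 100644
--- a/genienlp/predict.py
+++ b/genienlp/predict.py
@@ -163,10 +163,9 @@
default=[0],
help='ngrams of this size cannot be repeated in the output. 0 disables it.',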
)
- parser.add_argument('--max_output_length', default=150, type=int, help='maximum output length for generation')
+ parser.add_argument('--max_output_length', type=int, help='maximum output length for generation')
parser.add_argument(
'--min_output_length',
- default=3,
type=int,
help='minimum output length for generation; '
'default is 3 for most multilingual models: BOS, language code, and one token; otherwise it is 2',
)
From 11a000d567113245d02fa3cebfd9847859a22722 Mon Sep 17 00:00:00 2001
From: mehrad
Date: Sun, 5 Dec 2021 12:15:09 -0800
Subject: [PATCH 051/102] Add polynomial learning rate scheduler
---
genienlp/arguments.py | 14 +++++++++++++-
genienlp/model_utils/optimizer.py | 9 +++++++++
2 files changed, 22 insertions(+), 1 deletion(-)
diff --git a/genienlp/arguments.py b/genienlp/arguments.py
index 3665ac6c..efd143c8 100644
--- a/genienlp/arguments.py
+++ b/genienlp/arguments.py
@@ -325,7 +325,7 @@ def parse_argv(parser):
'--lr_schedule',
type=str,
default='transformer',
- choices=['transformer', 'constant', 'linear', 'sgd', 'cosine', 'multi_linear_cosine'],
+ choices=['transformer', 'constant', 'linear', 'sgd', 'cosine', 'polynomial', 'multi_linear_cosine'],
help='The learning rate strategy. All of them can be used with or without warmup.',
)
parser.add_argument(
@@ -334,6 +334,18 @@
type=float,
help='Multiplier for the `transformer` learning rate scheduler, constant value for `constant` and maximum value for `linear` and `cosine` schedulers.',
)
+ parser.add_argument(
+ '--lr_poly_end',
+ default=1e-7,
+ type=float,
+ help='Final learning rate for polynomial learning rate scheduler',
+ )
+ parser.add_argument(
+ '--lr_poly_power',
+ default=1.0,
+ type=float,
+ help='Decay power factor for polynomial learning rate scheduler',
+ )
parser.add_argument('--weight_decay', default=0.0, type=float, help='weight L2 regularization')
parser.add_argument(
'-gas',
diff --git a/genienlp/model_utils/optimizer.py b/genienlp/model_utils/optimizer.py
index e22734e2..12318862 100644
--- a/genienlp/model_utils/optimizer.py
+++ b/genienlp/model_utils/optimizer.py
@@ -8,6 +8,7 @@
get_constant_schedule_with_warmup,
get_cosine_schedule_with_warmup,
get_linear_schedule_with_warmup,
+ get_polynomial_decay_schedule_with_warmup,
)
@@ -62,6 +63,14 @@ def init_opt(args, model, logger):
num_training_steps=num_training_steps,
num_warmup_steps=args.warmup,
)
+ elif args.lr_schedule == 'polynomial':
+ scheduler = get_polynomial_decay_schedule_with_warmup(
+ opt,
+ num_training_steps=num_training_steps,
+ num_warmup_steps=args.warmup,
+ lr_end=args.lr_poly_end,
+ power=args.lr_poly_power,
+ )
elif args.lr_schedule == 'cosine':
scheduler = get_cosine_schedule_with_warmup(
opt,
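Since the new branch maps one-to-one onto the HuggingFace helper, a minimal usage sketch may help; the toy model and step counts here are illustrative, not the values used in actual training runs:

import torch
from transformers import get_polynomial_decay_schedule_with_warmup

model = torch.nn.Linear(8, 8)  # stand-in for the real seq2seq model
opt = torch.optim.AdamW(model.parameters(), lr=1e-4)
# warm up for 100 steps, then decay from 1e-4 toward lr_end over the
# remaining steps along a power-1.0 (i.e. linear) polynomial curve
scheduler = get_polynomial_decay_schedule_with_warmup(
    opt, num_warmup_steps=100, num_training_steps=1000, lr_end=1e-7, power=1.0
)
for _ in range(5):
    opt.step()
    scheduler.step()
print(scheduler.get_last_lr())

From c17ada2492d9d281c4a4be4ad448b7e612ee4d0a Mon Sep 17 00:00:00 2001
From: mehrad
Date: Mon, 6 Dec 2021 11:25:17 -0800
Subject: [PATCH 052/102] metrics: calculate em for response (acts) too
---
genienlp/metrics.py | 18 +++++++++---------
1 file changed, 9 insertions(+), 9 deletions(-)
diff --git a/genienlp/metrics.py b/genienlp/metrics.py
index b7fd81a6..17795c71 100644
--- a/genienlp/metrics.py
+++ b/genienlp/metrics.py
@@ -517,13 +517,13 @@ def computeBITOD(greedy, answer, tgt_lang):
num_examples = len(answer)
subtask_metrics_dict = defaultdict(tuple)
subtasks = ['dst', 'api', 'response']
- subtask_metrics = [['em'], ['em'], ['casedbleu']]
+ subtask_metrics = [['em'], ['em'], ['em']]
subtask_weights = [1, 1, 1]
for t in range(len(subtasks)):
preds, golds = [], []
for i in range(t, num_examples, 3):
preds.append(greedy[i])
golds.append(answer[i])
@@ -532,7 +532,7 @@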
subtask_metrics_dict[subtasks[t]] = (sub_metrics, len(golds), subtask_weights[t]) # TODO how should we aggregate? - bitod_score, JGA, response_bleu, api_em = 0.0, 0.0, 0.0, 0.0 + bitod_score, JGA, response_em, api_em = 0.0, 0.0, 0.0, 0.0 weighted_num_examples = 0 for subtask, (sub_metrics, num_ex, weight) in subtask_metrics_dict.items(): if subtask == 'dst': @@ -542,13 +542,13 @@ def computeBITOD(greedy, answer, tgt_lang): bitod_score += weight * (sub_metrics['em'] * num_ex) api_em = sub_metrics['em'] elif subtask == 'response': - bitod_score += weight * (sub_metrics['casedbleu'] * num_ex) - response_bleu = sub_metrics['casedbleu'] + bitod_score += weight * (sub_metrics['em'] * num_ex) + response_em = sub_metrics['em'] weighted_num_examples += weight * num_ex bitod_score /= weighted_num_examples - return bitod_score, JGA, response_bleu, api_em + return bitod_score, JGA, response_em, api_em def compute_metrics(greedy, answer, requested_metrics: Iterable, lang): @@ -571,10 +571,10 @@ def compute_metrics(greedy, answer, requested_metrics: Iterable, lang): if not isinstance(answer[0], list): answer = [[a] for a in answer] if 'bitod_score' in requested_metrics: - requested_metrics += ['JGA', 'response_bleu', 'api_em'] - bitod_score, JGA, response_bleu, api_em = computeBITOD(greedy, answer, lang) - metric_keys += ['bitod_score', 'JGA', 'response_bleu', 'api_em'] - metric_values += [bitod_score, JGA, response_bleu, api_em] + requested_metrics += ['JGA', 'response_em', 'api_em'] + bitod_score, JGA, response_em, api_em = computeBITOD(greedy, answer, lang) + metric_keys += ['bitod_score', 'JGA', 'response_em', 'api_em'] + metric_values += [bitod_score, JGA, response_em, api_em] if 'lfem' in requested_metrics: lfem, answer = computeLFEM(greedy, answer) metric_keys += ['lfem'] From a6a67379b8d599e71b4b317021457f15957f07d0 Mon Sep 17 00:00:00 2001 From: mehrad Date: Mon, 6 Dec 2021 11:25:27 -0800 Subject: [PATCH 053/102] Fix HKMTR response missing price and estimated_time slots --- genienlp/validate.py | 20 +++++++++++++++++++- 1 file changed, 19 insertions(+), 1 deletion(-) diff --git a/genienlp/validate.py b/genienlp/validate.py index e224d357..03c7a00a 100644 --- a/genienlp/validate.py +++ b/genienlp/validate.py @@ -40,7 +40,7 @@ from BiToD.evaluate import r_en_API_MAP from BiToD.knowledgebase import api from BiToD.preprocess import API_MAP, knowledge2span, read_require_slots, state2span -from BiToD.utils import span2state, state2constraints +from BiToD.utils import action2span, span2action, span2state, state2constraints from dateparser.languages import default_loader from transformers import MarianTokenizer @@ -146,6 +146,7 @@ def generate_with_seq2seq_model_for_dialogue( hyperparameter_idx = 0 cur_dial_id = '' + knowledge = None device = model.device @@ -258,6 +259,23 @@ def generate_with_seq2seq_model_for_dialogue( partial_batch_prediction = numericalizer.reverse(partial_batch_prediction_ids, 'answer')[0] # post-process predictions + lang = numericalizer._tokenizer.src_lang[:2] + if ( + train_target == 'response' + and re.search(rf'\( HKMTR {lang} \)', partial_batch_prediction) + and 'shortest_path' in partial_batch_prediction + ): + action_dict = span2action(partial_batch_prediction, api_names) + domain = f'HKMTR {lang}' + metro_slots = set(item['slot'] for item in action_dict[domain]) + for slot in ['estimated_time', 'price']: + if knowledge and slot in knowledge[domain] and slot not in metro_slots: + action_dict[domain].append( + {'act': 'offer', 'slot': slot, 'relation': 'equal_to', 'value': 
[knowledge[domain][slot]]} + ) + + partial_batch_prediction = action2span(action_dict[domain], domain, lang) + partial_batch_prediction = task.postprocess_prediction(batch_example_ids[0], partial_batch_prediction) # put them into the right array From 216e1aa80996e36a7341b476418e9c3828a52b79 Mon Sep 17 00:00:00 2001 From: mehrad Date: Mon, 6 Dec 2021 12:03:41 -0800 Subject: [PATCH 054/102] Fix weather response missing min and max temperature slots --- genienlp/validate.py | 18 +++++++++++++++++- 1 file changed, 17 insertions(+), 1 deletion(-) diff --git a/genienlp/validate.py b/genienlp/validate.py index 03c7a00a..87803856 100644 --- a/genienlp/validate.py +++ b/genienlp/validate.py @@ -263,7 +263,7 @@ def generate_with_seq2seq_model_for_dialogue( if ( train_target == 'response' and re.search(rf'\( HKMTR {lang} \)', partial_batch_prediction) - and 'shortest_path' in partial_batch_prediction + and 'offer shortest_path equal_to' in partial_batch_prediction ): action_dict = span2action(partial_batch_prediction, api_names) domain = f'HKMTR {lang}' @@ -276,6 +276,22 @@ def generate_with_seq2seq_model_for_dialogue( partial_batch_prediction = action2span(action_dict[domain], domain, lang) + if ( + train_target == 'response' + and re.search(r'\( weathers search \)', partial_batch_prediction) + and 'offer weather equal_to' in partial_batch_prediction + ): + action_dict = span2action(partial_batch_prediction, api_names) + domain = 'weathers search' + weather_slots = set(item['slot'] for item in action_dict[domain]) + for slot in ['max_temp', 'min_temp']: + if knowledge and slot in knowledge[domain] and slot not in weather_slots: + action_dict[domain].append( + {'act': 'offer', 'slot': slot, 'relation': 'equal_to', 'value': [knowledge[domain][slot]]} + ) + + partial_batch_prediction = action2span(action_dict[domain], domain, lang) + partial_batch_prediction = task.postprocess_prediction(batch_example_ids[0], partial_batch_prediction) # put them into the right array From 239185927ee234257b0d761f193405a0d7455c9a Mon Sep 17 00:00:00 2001 From: mehrad Date: Mon, 6 Dec 2021 12:19:26 -0800 Subject: [PATCH 055/102] Raise exceptions if user-specified data length limits don't agree with the dataset --- genienlp/util.py | 28 ++++++++++++++++++++++++++++ 1 file changed, 28 insertions(+) diff --git a/genienlp/util.py b/genienlp/util.py index a273f27d..c1a15289 100644 --- a/genienlp/util.py +++ b/genienlp/util.py @@ -45,6 +45,8 @@ from transformers import MarianConfig, MBartConfig from transformers.models.mbart50.tokenization_mbart50 import FAIRSEQ_LANGUAGE_CODES +from genienlp.tasks.generic_dataset import all_tokens_fn, input_tokens_fn + from .data_utils.almond_utils import token_type_regex from .data_utils.example import NumericalizedExamples from .data_utils.iterator import LengthSortedIterator @@ -578,6 +580,9 @@ def make_data_loader(dataset, numericalizer, batch_size, device=None, train=Fals context_lengths = [ex.context.length for ex in all_features] answer_lengths = [ex.answer.length for ex in all_features] + min_output_length = numericalizer.args.min_output_length + max_output_length = numericalizer.args.max_output_length + logger.info( f'context lengths (min, mean, max): {np.min(context_lengths)}, {int(np.mean(context_lengths))}, {np.max(context_lengths)}' ) @@ -592,6 +597,29 @@ def make_data_loader(dataset, numericalizer, batch_size, device=None, train=Fals sort_key_fn = getattr(dataset, 'eval_sort_key_fn', dataset.sort_key_fn) batch_size_fn = getattr(dataset, 'eval_batch_size_fn', dataset.batch_size_fn) 
+ if batch_size_fn == input_tokens_fn:
+ min_batch_length = np.min(context_lengths)
+ elif batch_size_fn == all_tokens_fn:
+ min_batch_length = np.min(context_lengths) + np.min(answer_lengths)
+ else:
+ min_batch_length = 1
+
+ if min_batch_length > batch_size:
+ raise ValueError(
+ f'The minimum example length in your dataset is {min_batch_length} but your batch size is {batch_size}.'
+ f' Thus no examples will be processed. Consider increasing batch_size'
+ )
+ if np.min(answer_lengths) < min_output_length:
+ raise ValueError(
+ f'The minimum output length in your dataset is {np.min(answer_lengths)} but you have set --min_output_length to {min_output_length}.'
+ f' Consider reducing it'
+ )
+ if np.max(answer_lengths) > max_output_length:
+ raise ValueError(
+ f'The maximum output length in your dataset is {np.max(answer_lengths)} but you have set --max_output_length to {max_output_length}.'
+ f' Consider increasing it'
+ )
+
sampler = LengthSortedIterator(
all_features,
batch_size=batch_size,
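The intent of these guards, reduced to a standalone sketch with toy numbers; the last assertion is the one that fires here, mirroring the ValueError raised for --max_output_length:

import numpy as np

# toy per-example token lengths
context_lengths = np.array([12, 30, 45])
answer_lengths = np.array([4, 9, 7])
batch_size, min_output_length, max_output_length = 40, 3, 8

# when batching by total tokens, the smallest single example must fit
min_batch_length = context_lengths.min() + answer_lengths.min()
assert min_batch_length <= batch_size, 'no example fits in one batch'
assert answer_lengths.min() >= min_output_length, 'min_output_length set too high'
assert answer_lengths.max() <= max_output_length, 'max_output_length set too low'  # fails: 9 > 8

From e20b988ab65eaf93d847ca6122a4a5dcdfb2b35f Mon Sep 17 00:00:00 2001
From: mehrad
Date: Tue, 7 Dec 2021 22:42:35 -0800
Subject: [PATCH 056/102] Use casedbleu again for response in bitod_score
---
genienlp/metrics.py | 18 +++++++++---------
genienlp/validate.py | 7 ++-----
2 files changed, 11 insertions(+), 14 deletions(-)
diff --git a/genienlp/metrics.py b/genienlp/metrics.py
index 17795c71..b7fd81a6 100644
--- a/genienlp/metrics.py
+++ b/genienlp/metrics.py
@@ -518,7 +518,7 @@ def computeBITOD(greedy, answer, tgt_lang):
subtask_metrics_dict = defaultdict(tuple)
subtasks = ['dst', 'api', 'response']
- subtask_metrics = [['em'], ['em'], ['em']]
+ subtask_metrics = [['em'], ['em'], ['casedbleu']]
subtask_weights = [1, 1, 1]
for t in range(len(subtasks)):
@@ -532,7 +532,7 @@
subtask_metrics_dict[subtasks[t]] = (sub_metrics, len(golds), subtask_weights[t])
# TODO how should we aggregate?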
- bitod_score, JGA, response_em, api_em = 0.0, 0.0, 0.0, 0.0 + bitod_score, JGA, response_bleu, api_em = 0.0, 0.0, 0.0, 0.0 weighted_num_examples = 0 for subtask, (sub_metrics, num_ex, weight) in subtask_metrics_dict.items(): if subtask == 'dst': @@ -542,13 +542,13 @@ def computeBITOD(greedy, answer, tgt_lang): bitod_score += weight * (sub_metrics['em'] * num_ex) api_em = sub_metrics['em'] elif subtask == 'response': - bitod_score += weight * (sub_metrics['em'] * num_ex) - response_em = sub_metrics['em'] + bitod_score += weight * (sub_metrics['casedbleu'] * num_ex) + response_bleu = sub_metrics['casedbleu'] weighted_num_examples += weight * num_ex bitod_score /= weighted_num_examples - return bitod_score, JGA, response_em, api_em + return bitod_score, JGA, response_bleu, api_em def compute_metrics(greedy, answer, requested_metrics: Iterable, lang): @@ -571,10 +571,10 @@ def compute_metrics(greedy, answer, requested_metrics: Iterable, lang): if not isinstance(answer[0], list): answer = [[a] for a in answer] if 'bitod_score' in requested_metrics: - requested_metrics += ['JGA', 'response_em', 'api_em'] - bitod_score, JGA, response_em, api_em = computeBITOD(greedy, answer, lang) - metric_keys += ['bitod_score', 'JGA', 'response_em', 'api_em'] - metric_values += [bitod_score, JGA, response_em, api_em] + requested_metrics += ['JGA', 'response_bleu', 'api_em'] + bitod_score, JGA, response_bleu, api_em = computeBITOD(greedy, answer, lang) + metric_keys += ['bitod_score', 'JGA', 'response_bleu', 'api_em'] + metric_values += [bitod_score, JGA, response_bleu, api_em] if 'lfem' in requested_metrics: lfem, answer = computeLFEM(greedy, answer) metric_keys += ['lfem'] diff --git a/genienlp/validate.py b/genienlp/validate.py index 87803856..1b3f8417 100644 --- a/genienlp/validate.py +++ b/genienlp/validate.py @@ -39,7 +39,8 @@ import ujson from BiToD.evaluate import r_en_API_MAP from BiToD.knowledgebase import api -from BiToD.preprocess import API_MAP, knowledge2span, read_require_slots, state2span +from BiToD.knowledgebase.en_zh_mappings import api_names, required_slots +from BiToD.preprocess import knowledge2span, state2span from BiToD.utils import action2span, span2action, span2state, state2constraints from dateparser.languages import default_loader from transformers import MarianTokenizer @@ -150,10 +151,6 @@ def generate_with_seq2seq_model_for_dialogue( device = model.device - required_slots = read_require_slots() - required_slots = {API_MAP[k]: v for k, v in required_slots.items()} - api_names = list(required_slots.keys()) - special_tokens = numericalizer._tokenizer.all_special_tokens for k, turn in enumerate(progress_bar(data_iterator, desc='Generating', disable=disable_progbar)): From 58f8ce82fe57bf3b3c8aae4bef98cb03d8ace930 Mon Sep 17 00:00:00 2001 From: mehrad Date: Tue, 7 Dec 2021 23:09:06 -0800 Subject: [PATCH 057/102] Minor import fix --- genienlp/dial_validate.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/genienlp/dial_validate.py b/genienlp/dial_validate.py index 4c8fdd89..f3a7e6ab 100644 --- a/genienlp/dial_validate.py +++ b/genienlp/dial_validate.py @@ -8,8 +8,8 @@ import ujson from BiToD.evaluate import r_en_API_MAP from BiToD.knowledgebase import api -from BiToD.knowledgebase.en_zh_mappings import API_MAP -from BiToD.utils import knowledge2span, read_require_slots, span2state, state2constraints, state2span +from BiToD.knowledgebase.en_zh_mappings import api_names, required_slots +from BiToD.utils import knowledge2span, span2state, state2constraints, 
state2span from termcolor import colored from genienlp.data_utils.example import NumericalizedExamples, SequentialField @@ -76,10 +76,6 @@ def generate_with_seq2seq_model_for_dialogue_interactive(e2e_model, nlg_model, e nlg_numericalizer = nlg_model.numericalizer nlg_args = nlg_model.args - required_slots = read_require_slots() - required_slots = {API_MAP[k]: v for k, v in required_slots.items()} - api_names = list(required_slots.keys()) - dial_id = 'none' turn_id = 1 dialogue_state = {} From 083849b689120f21be5930e8c3d02a8467d35f95 Mon Sep 17 00:00:00 2001 From: mehrad Date: Mon, 13 Dec 2021 12:02:53 -0800 Subject: [PATCH 058/102] validate: handle no_state datasets --- genienlp/validate.py | 20 ++++++++++++-------- 1 file changed, 12 insertions(+), 8 deletions(-) diff --git a/genienlp/validate.py b/genienlp/validate.py index 1b3f8417..9329f7bd 100644 --- a/genienlp/validate.py +++ b/genienlp/validate.py @@ -101,10 +101,13 @@ def generate_with_model( def replace_capturing_group(input, re_pattern, replacement): # replace first captured group in the input with replacement using regex re_pattern - whole_match = re_pattern.search(input).group(0).strip() - captured_match = re_pattern.search(input).group(1).strip() - new_whole_match = whole_match.replace(captured_match, replacement) - new_input = re.sub(re_pattern, new_whole_match, input) + if re_pattern.search(input): + whole_match = re_pattern.search(input).group(0).strip() + captured_match = re_pattern.search(input).group(1).strip() + new_whole_match = whole_match.replace(captured_match, replacement) + new_input = re.sub(re_pattern, new_whole_match, input) + else: + new_input = input return new_input @@ -345,10 +348,11 @@ def generate_with_seq2seq_model_for_dialogue( f'Message = No item available for api_name: {api_name}, constraints: {constraints},' f' processed_query: {msg[2]}, for turn: {dial_id}/{turn_id}' ) - gold_dial_state = span2state(state_re.search(contexts[-1]).group(1).strip(), api_names) - logger.warning( - f'state_diff: {list(dictdiffer.diff(dialogue_state[api_name], gold_dial_state[api_name]))}' - ) + if state_re.search(contexts[-1]): + gold_dial_state = span2state(state_re.search(contexts[-1]).group(1).strip(), api_names) + logger.warning( + f'state_diff: {list(dictdiffer.diff(dialogue_state[api_name], gold_dial_state[api_name]))}' + ) new_knowledge_text = f'( {api_name} ) Message = No item available.' 
else: From 938ee2a29687227563be6621929d99c1250ad200 Mon Sep 17 00:00:00 2001 From: mehrad Date: Tue, 18 Jan 2022 22:08:04 -0800 Subject: [PATCH 059/102] generic_task: fix special token for bitod --- genienlp/tasks/generic_task.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/genienlp/tasks/generic_task.py b/genienlp/tasks/generic_task.py index 18e2d1e1..40e62d06 100644 --- a/genienlp/tasks/generic_task.py +++ b/genienlp/tasks/generic_task.py @@ -426,7 +426,7 @@ def __init__(self, name, args): special_tokens_v7 = {'ACTS:'} special_tokens_v9 = {'USER_ACTS:'} special_tokens_v11 = {'', '', ''} - special_tokens_v13 = {'PREV_AGENT_ACTS'} + special_tokens_v13 = {'AGENT_ACTS_PREV'} self.special_tokens = ( special_tokens_v1 | special_tokens_v2 From cc90472953953dab1b1b948cf80fcaf18c2b1523 Mon Sep 17 00:00:00 2001 From: mehrad Date: Tue, 18 Jan 2022 22:09:07 -0800 Subject: [PATCH 060/102] generic_task: use bitod_score as main metric --- genienlp/tasks/generic_task.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/genienlp/tasks/generic_task.py b/genienlp/tasks/generic_task.py index 40e62d06..6a1fe3cd 100644 --- a/genienlp/tasks/generic_task.py +++ b/genienlp/tasks/generic_task.py @@ -436,7 +436,7 @@ def __init__(self, name, args): | special_tokens_v11 | special_tokens_v13 ) - self._metrics = ['em', 'casedbleu'] + self._metrics = ['bitod_score'] def utterance_field(self): return 'context' From 09e4e92e999957e0100b6f0f3455a046f8df418f Mon Sep 17 00:00:00 2001 From: mehrad Date: Thu, 20 Jan 2022 13:21:51 -0800 Subject: [PATCH 061/102] Updates to accommodate bitod's 4-step pipeline --- genienlp/metrics.py | 23 +++++++++++++---------- genienlp/tasks/generic_task.py | 2 ++ genienlp/validate.py | 26 ++++++++++++++++++++++---- 3 files changed, 37 insertions(+), 14 deletions(-) diff --git a/genienlp/metrics.py b/genienlp/metrics.py index b7fd81a6..12e71020 100644 --- a/genienlp/metrics.py +++ b/genienlp/metrics.py @@ -517,13 +517,13 @@ def computeBITOD(greedy, answer, tgt_lang): num_examples = len(answer) subtask_metrics_dict = defaultdict(tuple) - subtasks = ['dst', 'api', 'response'] - subtask_metrics = [['em'], ['em'], ['casedbleu']] - subtask_weights = [1, 1, 1] + subtasks = ['dst', 'api', 'da', 'rg'] + subtask_metrics = [['em'], ['em'], ['em'], ['casedbleu']] + subtask_weights = [1, 1, 1, 1] for t in range(len(subtasks)): preds, golds = [], [] - for i in range(t, num_examples, 3): + for i in range(t, num_examples, len(subtasks)): preds.append(greedy[i]) golds.append(answer[i]) @@ -532,7 +532,7 @@ def computeBITOD(greedy, answer, tgt_lang): subtask_metrics_dict[subtasks[t]] = (sub_metrics, len(golds), subtask_weights[t]) # TODO how should we aggregate? 
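The strided loop above assumes the evaluation data interleaves the four subtasks turn by turn, so example i belongs to subtask i mod 4. A small illustration of that grouping, using made-up predictions in place of real model outputs:

    # illustration only: how range(t, num_examples, len(subtasks)) splits examples
    subtasks = ['dst', 'api', 'da', 'rg']
    # two dialogue turns, each contributing one example per subtask, in order
    greedy = ['lev_1', 'yes', 'act_1', 'resp_1', 'lev_2', 'no', 'act_2', 'resp_2']
    groups = {name: greedy[t::len(subtasks)] for t, name in enumerate(subtasks)}
    # groups == {'dst': ['lev_1', 'lev_2'], 'api': ['yes', 'no'],
    #            'da': ['act_1', 'act_2'], 'rg': ['resp_1', 'resp_2']}

A later patch in this series (077) drops this positional assumption in favor of matching on the subtask suffix of each example id.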
- bitod_score, JGA, response_bleu, api_em = 0.0, 0.0, 0.0, 0.0 + bitod_score, JGA, api_em, da_em, response_bleu = 0.0, 0.0, 0.0, 0.0, 0.0 weighted_num_examples = 0 for subtask, (sub_metrics, num_ex, weight) in subtask_metrics_dict.items(): if subtask == 'dst': @@ -541,14 +541,17 @@ def computeBITOD(greedy, answer, tgt_lang): elif subtask == 'api': bitod_score += weight * (sub_metrics['em'] * num_ex) api_em = sub_metrics['em'] - elif subtask == 'response': + elif subtask == 'da': + bitod_score += weight * (sub_metrics['em'] * num_ex) + da_em = sub_metrics['em'] + elif subtask == 'rg': bitod_score += weight * (sub_metrics['casedbleu'] * num_ex) response_bleu = sub_metrics['casedbleu'] weighted_num_examples += weight * num_ex bitod_score /= weighted_num_examples - return bitod_score, JGA, response_bleu, api_em + return bitod_score, JGA, api_em, da_em, response_bleu def compute_metrics(greedy, answer, requested_metrics: Iterable, lang): @@ -572,9 +575,9 @@ def compute_metrics(greedy, answer, requested_metrics: Iterable, lang): answer = [[a] for a in answer] if 'bitod_score' in requested_metrics: requested_metrics += ['JGA', 'response_bleu', 'api_em'] - bitod_score, JGA, response_bleu, api_em = computeBITOD(greedy, answer, lang) - metric_keys += ['bitod_score', 'JGA', 'response_bleu', 'api_em'] - metric_values += [bitod_score, JGA, response_bleu, api_em] + bitod_score, JGA, api_em, da_em, response_bleu = computeBITOD(greedy, answer, lang) + metric_keys += ['bitod_score', 'JGA', 'api_em', 'da_em', 'response_bleu'] + metric_values += [bitod_score, JGA, api_em, da_em, response_bleu] if 'lfem' in requested_metrics: lfem, answer = computeLFEM(greedy, answer) metric_keys += ['lfem'] diff --git a/genienlp/tasks/generic_task.py b/genienlp/tasks/generic_task.py index 6a1fe3cd..974e347a 100644 --- a/genienlp/tasks/generic_task.py +++ b/genienlp/tasks/generic_task.py @@ -427,6 +427,7 @@ def __init__(self, name, args): special_tokens_v9 = {'USER_ACTS:'} special_tokens_v11 = {'', '', ''} special_tokens_v13 = {'AGENT_ACTS_PREV'} + special_tokens_v2_10 = {'', ''} self.special_tokens = ( special_tokens_v1 | special_tokens_v2 @@ -435,6 +436,7 @@ def __init__(self, name, args): | special_tokens_v9 | special_tokens_v11 | special_tokens_v13 + | special_tokens_v2_10 ) self._metrics = ['bitod_score'] diff --git a/genienlp/validate.py b/genienlp/validate.py index 9329f7bd..d044bc2c 100644 --- a/genienlp/validate.py +++ b/genienlp/validate.py @@ -138,6 +138,7 @@ def generate_with_seq2seq_model_for_dialogue( state_re = re.compile(' (.*?)(?:$|<)') knowledge_re = re.compile(' (.*?)(?:$|<)') + actions_re = re.compile(' (.*?)(?:$|<)') bitod_preds = dict() @@ -173,6 +174,7 @@ def generate_with_seq2seq_model_for_dialogue( dialogue_state = {} new_state_text = 'null' new_knowledge_text = 'null' + new_actions_text = 'null' active_api = None bitod_preds[dial_id] = {"turns": defaultdict(dict), "API": defaultdict(dict)} @@ -215,13 +217,23 @@ def generate_with_seq2seq_model_for_dialogue( # replace state input_text = replace_capturing_group(contexts[-1], state_re, new_state_text) - elif train_target == 'response': + elif train_target == 'da': # replace state input_text = replace_capturing_group(contexts[-1], state_re, new_state_text) # replace knowledge input_text = replace_capturing_group(input_text, knowledge_re, new_knowledge_text) + elif train_target == 'rg': + # replace state + # input_text = replace_capturing_group(contexts[-1], state_re, new_state_text) + + # replace knowledge + # input_text = 
replace_capturing_group(input_text, knowledge_re, new_knowledge_text) + + # replace actions + input_text = replace_capturing_group(contexts[-1], actions_re, new_actions_text) + else: raise ValueError(f'Invalid train_target: {train_target}') @@ -261,7 +273,7 @@ def generate_with_seq2seq_model_for_dialogue( # post-process predictions lang = numericalizer._tokenizer.src_lang[:2] if ( - train_target == 'response' + train_target == 'acts' and re.search(rf'\( HKMTR {lang} \)', partial_batch_prediction) and 'offer shortest_path equal_to' in partial_batch_prediction ): @@ -277,7 +289,7 @@ def generate_with_seq2seq_model_for_dialogue( partial_batch_prediction = action2span(action_dict[domain], domain, lang) if ( - train_target == 'response' + train_target == 'acts' and re.search(r'\( weathers search \)', partial_batch_prediction) and 'offer weather equal_to' in partial_batch_prediction ): @@ -376,7 +388,13 @@ def generate_with_seq2seq_model_for_dialogue( bitod_preds[dial_id]["turns"][str(turn_id)]["api"] = new_knowledge_text #### - if train_target == 'response': + elif train_target == 'da': + new_actions_text = predictions[-1] + #### save latest actions + bitod_preds[dial_id]["turns"][str(turn_id)]["actions"] = predictions[-1] + #### + + elif train_target == 'rg': #### save latest response bitod_preds[dial_id]["turns"][str(turn_id)]["response"] = predictions[-1] #### From fdf707f0d1bbc87f0958315482dd4d547a69fd59 Mon Sep 17 00:00:00 2001 From: mehrad Date: Thu, 20 Jan 2022 15:12:13 -0800 Subject: [PATCH 062/102] Update bitod special tokens --- genienlp/tasks/generic_task.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/genienlp/tasks/generic_task.py b/genienlp/tasks/generic_task.py index 974e347a..36ccd242 100644 --- a/genienlp/tasks/generic_task.py +++ b/genienlp/tasks/generic_task.py @@ -427,7 +427,7 @@ def __init__(self, name, args): special_tokens_v9 = {'USER_ACTS:'} special_tokens_v11 = {'', '', ''} special_tokens_v13 = {'AGENT_ACTS_PREV'} - special_tokens_v2_10 = {'', ''} + special_tokens_v2_10 = {'', '', 'DA:', 'RG:'} self.special_tokens = ( special_tokens_v1 | special_tokens_v2 From 77eb28e370e40a421e3c0bd9751243522af48d8e Mon Sep 17 00:00:00 2001 From: mehrad Date: Thu, 20 Jan 2022 17:20:39 -0800 Subject: [PATCH 063/102] validate: fix syntax bug --- genienlp/validate.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/genienlp/validate.py b/genienlp/validate.py index d044bc2c..3e8826c5 100644 --- a/genienlp/validate.py +++ b/genienlp/validate.py @@ -389,9 +389,9 @@ def generate_with_seq2seq_model_for_dialogue( #### elif train_target == 'da': - new_actions_text = predictions[-1] + new_actions_text = predictions[-1][0] #### save latest actions - bitod_preds[dial_id]["turns"][str(turn_id)]["actions"] = predictions[-1] + bitod_preds[dial_id]["turns"][str(turn_id)]["actions"] = predictions[-1][0] #### elif train_target == 'rg': From 6929b0fa63a43e3fb3dba37839bf8ff1d1f9ab09 Mon Sep 17 00:00:00 2001 From: mehrad Date: Fri, 21 Jan 2022 15:00:51 -0800 Subject: [PATCH 064/102] Do not sort valid/ test for bitod --- genienlp/tasks/generic_dataset.py | 5 ++++- genienlp/validate.py | 2 +- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/genienlp/tasks/generic_dataset.py b/genienlp/tasks/generic_dataset.py index 8bbd351d..a6857359 100644 --- a/genienlp/tasks/generic_dataset.py +++ b/genienlp/tasks/generic_dataset.py @@ -1988,8 +1988,11 @@ def __init__(self, path, *, make_example, **kwargs): super().__init__(examples, **kwargs) + # do not sort eval/ 
test set so we can compute individual scores for each subtask (bitod_score) + self.eval_sort_key_fn = None + + # in e2e evaluation use 1 batch at a time if kwargs.get('e2e_evaluation', False): - self.eval_sort_key_fn = None self.eval_batch_size_fn = default_batch_fn @classmethod diff --git a/genienlp/validate.py b/genienlp/validate.py index 3e8826c5..5c75ed4a 100644 --- a/genienlp/validate.py +++ b/genienlp/validate.py @@ -381,7 +381,7 @@ def generate_with_seq2seq_model_for_dialogue( pass else: logger.error( - f'API call should be either yes or no but got {do_api_call}. Seems model is still training so we assume it\'s a no' + f'API call should be either yes or no but got {do_api_call}. Seems model is not trained for enough steps. For now we assume it\'s a no' ) #### save latest api results From 7d01e870747b53847b9f1f09e0413d711abaed52 Mon Sep 17 00:00:00 2001 From: mehrad Date: Fri, 21 Jan 2022 15:08:56 -0800 Subject: [PATCH 065/102] metrics: update names --- genienlp/metrics.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/genienlp/metrics.py b/genienlp/metrics.py index 12e71020..eaccbbef 100644 --- a/genienlp/metrics.py +++ b/genienlp/metrics.py @@ -532,7 +532,7 @@ def computeBITOD(greedy, answer, tgt_lang): subtask_metrics_dict[subtasks[t]] = (sub_metrics, len(golds), subtask_weights[t]) # TODO how should we aggregate? - bitod_score, JGA, api_em, da_em, response_bleu = 0.0, 0.0, 0.0, 0.0, 0.0 + bitod_score, JGA, API_em, DA_em, BLEU = 0.0, 0.0, 0.0, 0.0, 0.0 weighted_num_examples = 0 for subtask, (sub_metrics, num_ex, weight) in subtask_metrics_dict.items(): if subtask == 'dst': @@ -540,18 +540,18 @@ def computeBITOD(greedy, answer, tgt_lang): JGA = sub_metrics['em'] elif subtask == 'api': bitod_score += weight * (sub_metrics['em'] * num_ex) - api_em = sub_metrics['em'] + API_em = sub_metrics['em'] elif subtask == 'da': bitod_score += weight * (sub_metrics['em'] * num_ex) - da_em = sub_metrics['em'] + DA_em = sub_metrics['em'] elif subtask == 'rg': bitod_score += weight * (sub_metrics['casedbleu'] * num_ex) - response_bleu = sub_metrics['casedbleu'] + BLEU = sub_metrics['casedbleu'] weighted_num_examples += weight * num_ex bitod_score /= weighted_num_examples - return bitod_score, JGA, api_em, da_em, response_bleu + return bitod_score, JGA, API_em, DA_em, BLEU def compute_metrics(greedy, answer, requested_metrics: Iterable, lang): @@ -574,10 +574,10 @@ def compute_metrics(greedy, answer, requested_metrics: Iterable, lang): if not isinstance(answer[0], list): answer = [[a] for a in answer] if 'bitod_score' in requested_metrics: - requested_metrics += ['JGA', 'response_bleu', 'api_em'] - bitod_score, JGA, api_em, da_em, response_bleu = computeBITOD(greedy, answer, lang) - metric_keys += ['bitod_score', 'JGA', 'api_em', 'da_em', 'response_bleu'] - metric_values += [bitod_score, JGA, api_em, da_em, response_bleu] + requested_metrics += ['JGA', 'API_em', 'DA_em', 'BLEU'] + bitod_score, JGA, API_em, DA_em, BLEU = computeBITOD(greedy, answer, lang) + metric_keys += ['bitod_score', 'JGA', 'API_em', 'DA_em', 'BLEU'] + metric_values += [bitod_score, JGA, API_em, DA_em, BLEU] if 'lfem' in requested_metrics: lfem, answer = computeLFEM(greedy, answer) metric_keys += ['lfem'] From a1c54df773587be38aed2d155df08a95df853537 Mon Sep 17 00:00:00 2001 From: mehrad Date: Fri, 21 Jan 2022 15:27:57 -0800 Subject: [PATCH 066/102] validate: fix bug (acts --> da) --- genienlp/validate.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git 
a/genienlp/validate.py b/genienlp/validate.py
index 5c75ed4a..633fffe4 100644
--- a/genienlp/validate.py
+++ b/genienlp/validate.py
@@ -273,7 +273,7 @@ def generate_with_seq2seq_model_for_dialogue(
                 # post-process predictions
                 lang = numericalizer._tokenizer.src_lang[:2]
                 if (
-                    train_target == 'acts'
+                    train_target == 'da'
                     and re.search(rf'\( HKMTR {lang} \)', partial_batch_prediction)
                     and 'offer shortest_path equal_to' in partial_batch_prediction
                 ):
@@ -289,7 +289,7 @@
                     partial_batch_prediction = action2span(action_dict[domain], domain, lang)
 
                 if (
-                    train_target == 'acts'
+                    train_target == 'da'
                     and re.search(r'\( weathers search \)', partial_batch_prediction)
                     and 'offer weather equal_to' in partial_batch_prediction
                 ):

From 88eb5e28709722be130dc56c15d069b138891dd6 Mon Sep 17 00:00:00 2001
From: mehrad
Date: Tue, 1 Feb 2022 16:33:33 -0800
Subject: [PATCH 067/102] almond_translate: remove unnecessary data processing if alignment is not used

---
 genienlp/tasks/almond_task.py | 12 ++++++++----
 1 file changed, 8 insertions(+), 4 deletions(-)

diff --git a/genienlp/tasks/almond_task.py b/genienlp/tasks/almond_task.py
index a2ed971c..e1de6b81 100644
--- a/genienlp/tasks/almond_task.py
+++ b/genienlp/tasks/almond_task.py
@@ -383,7 +383,7 @@ def _make_example(self, parts, dir_name=None, **kwargs):
         contexts = []
         src_char_spans = None
         if split_sentence:
-            if self.args.do_alignment:
+            if self.need_attention_scores:
                 src_quotation_symbol = '"'
                 src_char_spans_ind = [index for index, char in enumerate(context) if char == src_quotation_symbol]
                 src_char_spans = [
@@ -394,7 +394,9 @@
         if len(contexts) > 1:
             examples = []
             for i, text in enumerate(contexts):
-                ex_id, text = self.construct_id2span_mapping(self.name + '/' + example_id + f'@{i}', text, 'context')
+                ex_id = self.name + '/' + example_id + f'@{i}'
+                if self.need_attention_scores:
+                    ex_id, text = self.construct_id2span_mapping(ex_id, text, 'context')
                 examples.append(
                     Example.from_raw(
                         ex_id,
@@ -406,7 +408,9 @@
                 )
             )
         else:
-            ex_id, context = self.construct_id2span_mapping(self.name + '/' + example_id, context, 'context')
+            ex_id = self.name + '/' + example_id
+            if self.need_attention_scores:
+                ex_id, context = self.construct_id2span_mapping(ex_id, context, 'context')
             examples = Example.from_raw(ex_id, context, question, answer, preprocess=self.preprocess_field, lower=False)
 
         return examples
@@ -480,7 +484,7 @@ def batch_postprocess_prediction_ids(self, batch_example_ids, batch_src_ids, bat
                 )
                 plt.show()
 
-        if self.args.do_alignment:
+        if self.need_attention_scores:
             src_spans = self.input_spans[example_id]
             try:
                 text = align_and_replace(

From bd18f3b4ea13b01227918f6258d9e0124e11a26e Mon Sep 17 00:00:00 2001
From: mehrad
Date: Mon, 7 Feb 2022 12:42:35 -0800
Subject: [PATCH 068/102] Switch to dialogues library

---
 genienlp/validate.py | 108 ++++++-----------------
 setup.py             |   4 +-
 2 files changed, 20 insertions(+), 92 deletions(-)

diff --git a/genienlp/validate.py b/genienlp/validate.py
index 633fffe4..caecd226 100644
--- a/genienlp/validate.py
+++ b/genienlp/validate.py
@@ -34,15 +34,10 @@
 import sys
 from collections import defaultdict
 
-import dictdiffer
 import torch
 import ujson
-from BiToD.evaluate import r_en_API_MAP
-from BiToD.knowledgebase import api
-from BiToD.knowledgebase.en_zh_mappings import api_names, required_slots
-from BiToD.preprocess import knowledge2span, state2span
-from BiToD.utils import action2span, span2action, span2state, state2constraints from dateparser.languages import default_loader +from dialogues import Bitod from transformers import MarianTokenizer from .data_utils.example import NumericalizedExamples, SequentialField @@ -133,13 +128,7 @@ def generate_with_seq2seq_model_for_dialogue( contexts """ - # history_re = re.compile(' (.*?)(?:$|<)') - # last_system_re = re.compile('SYSTEM: (.*?)(?:USER:|$)') - - state_re = re.compile(' (.*?)(?:$|<)') - knowledge_re = re.compile(' (.*?)(?:$|<)') - actions_re = re.compile(' (.*?)(?:$|<)') - + dataset = Bitod() bitod_preds = dict() predictions = [] @@ -204,7 +193,7 @@ def generate_with_seq2seq_model_for_dialogue( answers += batch_answer if train_target == 'dst': - input_text = replace_capturing_group(contexts[-1], state_re, new_state_text) + input_text = replace_capturing_group(contexts[-1], dataset.state_re, new_state_text) ## we always use gold history following common practice ## if you want to use predicted response instead of gold uncomment the following @@ -212,17 +201,17 @@ def generate_with_seq2seq_model_for_dialogue( # input_text = replace_match(input_text, last_system_re, last_sys_pred) elif train_target == 'api': - new_state_text = state2span(dialogue_state, required_slots) + new_state_text = dataset.state2span(dialogue_state) # replace state - input_text = replace_capturing_group(contexts[-1], state_re, new_state_text) + input_text = replace_capturing_group(contexts[-1], dataset.state_re, new_state_text) elif train_target == 'da': # replace state - input_text = replace_capturing_group(contexts[-1], state_re, new_state_text) + input_text = replace_capturing_group(contexts[-1], dataset.state_re, new_state_text) # replace knowledge - input_text = replace_capturing_group(input_text, knowledge_re, new_knowledge_text) + input_text = replace_capturing_group(input_text, dataset.knowledge_re, new_knowledge_text) elif train_target == 'rg': # replace state @@ -232,7 +221,7 @@ def generate_with_seq2seq_model_for_dialogue( # input_text = replace_capturing_group(input_text, knowledge_re, new_knowledge_text) # replace actions - input_text = replace_capturing_group(contexts[-1], actions_re, new_actions_text) + input_text = replace_capturing_group(contexts[-1], dataset.actions_re, new_actions_text) else: raise ValueError(f'Invalid train_target: {train_target}') @@ -270,39 +259,10 @@ def generate_with_seq2seq_model_for_dialogue( partial_batch_prediction = numericalizer.reverse(partial_batch_prediction_ids, 'answer')[0] - # post-process predictions - lang = numericalizer._tokenizer.src_lang[:2] - if ( - train_target == 'da' - and re.search(rf'\( HKMTR {lang} \)', partial_batch_prediction) - and 'offer shortest_path equal_to' in partial_batch_prediction - ): - action_dict = span2action(partial_batch_prediction, api_names) - domain = f'HKMTR {lang}' - metro_slots = set(item['slot'] for item in action_dict[domain]) - for slot in ['estimated_time', 'price']: - if knowledge and slot in knowledge[domain] and slot not in metro_slots: - action_dict[domain].append( - {'act': 'offer', 'slot': slot, 'relation': 'equal_to', 'value': [knowledge[domain][slot]]} - ) - - partial_batch_prediction = action2span(action_dict[domain], domain, lang) - - if ( - train_target == 'da' - and re.search(r'\( weathers search \)', partial_batch_prediction) - and 'offer weather equal_to' in partial_batch_prediction - ): - action_dict = span2action(partial_batch_prediction, api_names) - domain = 'weathers search' - weather_slots = 
set(item['slot'] for item in action_dict[domain]) - for slot in ['max_temp', 'min_temp']: - if knowledge and slot in knowledge[domain] and slot not in weather_slots: - action_dict[domain].append( - {'act': 'offer', 'slot': slot, 'relation': 'equal_to', 'value': [knowledge[domain][slot]]} - ) - - partial_batch_prediction = action2span(action_dict[domain], domain, lang) + if train_target == 'da': + partial_batch_prediction = dataset.postprocess_prediction( + partial_batch_prediction, knowledge, lang=numericalizer._tokenizer.src_lang[:2] + ) partial_batch_prediction = task.postprocess_prediction(batch_example_ids[0], partial_batch_prediction) @@ -314,7 +274,7 @@ def generate_with_seq2seq_model_for_dialogue( if train_target == 'dst': # update dialogue_state lev = predictions[-1][0].strip() - state_update = span2state(lev, api_names) + state_update = dataset.span2state(lev) for api_name in state_update: active_api = api_name if api_name not in dialogue_state: @@ -324,7 +284,7 @@ def generate_with_seq2seq_model_for_dialogue( #### save latest state state_to_record = copy.deepcopy(dialogue_state) - state_to_record = {r_en_API_MAP.get(k, k): v for k, v in state_to_record.items()} + state_to_record = {dataset.domain2api_name(k): v for k, v in state_to_record.items()} bitod_preds[dial_id]["turns"][str(turn_id)]["state"] = state_to_record #### @@ -337,43 +297,11 @@ def generate_with_seq2seq_model_for_dialogue( api_name = active_api if api_name in dialogue_state: - constraints = state2constraints(dialogue_state[api_name]) - # domain = api_name.split(" ")[0] - knowledge = defaultdict(dict) - - msg = [0, 0, 0] - try: - msg = api.call_api( - r_en_API_MAP.get(api_name, api_name), - constraints=[constraints], - lang=numericalizer._tokenizer.src_lang, - ) - except Exception as e: - logger.error(f'Error: {e}') - logger.error( - f'Failed API call with api_name: {api_name}, constraints: {constraints},' - f' processed_query: {msg[2]}, for turn: {dial_id}/{turn_id}' - ) - - if int(msg[1]) <= 0: - logger.warning( - f'Message = No item available for api_name: {api_name}, constraints: {constraints},' - f' processed_query: {msg[2]}, for turn: {dial_id}/{turn_id}' - ) - if state_re.search(contexts[-1]): - gold_dial_state = span2state(state_re.search(contexts[-1]).group(1).strip(), api_names) - logger.warning( - f'state_diff: {list(dictdiffer.diff(dialogue_state[api_name], gold_dial_state[api_name]))}' - ) - - new_knowledge_text = f'( {api_name} ) Message = No item available.' 
- else: - # always choose highest ranking results (having deterministic api results) - knowledge[api_name].update(msg[0]) - new_knowledge_text = knowledge2span(knowledge) - + constraints, new_knowledge_text = dataset.make_api_call( + dialogue_state, api_name, numericalizer._tokenizer.src_lang, dial_id, turn_id + ) #### save latest api constraints - bitod_preds[dial_id]["API"][r_en_API_MAP.get(api_name, api_name)] = copy.deepcopy(constraints) + bitod_preds[dial_id]["API"][dataset.domain2api_name(api_name)] = copy.deepcopy(constraints) #### elif do_api_call == 'no': diff --git a/setup.py b/setup.py index 2e01089e..456e76b1 100644 --- a/setup.py +++ b/setup.py @@ -71,7 +71,7 @@ 'scikit-learn>=0.23,<2.0', 'dill~=0.3', 'xgboost~=1.3', - # BiToD - 'BiToD @ git+ssh://git@github.com/Mehrad0711/BiToD@main', + # dialogues + 'dialogues @ git+https://github.com/Mehrad0711/dialogues', ], ) From 8fb0797b32dff577ad0a879692381b0d66333049 Mon Sep 17 00:00:00 2001 From: mehrad Date: Mon, 7 Feb 2022 12:37:18 -0800 Subject: [PATCH 069/102] delete dialogue loop --- genienlp/__main__.py | 16 +-- genienlp/dial_validate.py | 243 ---------------------------------- genienlp/run_dialogue_loop.py | 172 ------------------------ genienlp/validate.py | 7 - 4 files changed, 1 insertion(+), 437 deletions(-) delete mode 100644 genienlp/dial_validate.py delete mode 100644 genienlp/run_dialogue_loop.py diff --git a/genienlp/__main__.py b/genienlp/__main__.py index 60e38d7a..aef196dc 100755 --- a/genienlp/__main__.py +++ b/genienlp/__main__.py @@ -30,19 +30,7 @@ import argparse -from . import ( - arguments, - cache_embeddings, - calibrate, - export, - kfserver, - predict, - run_bootleg, - run_dialogue_loop, - server, - train, - write_kf_metrics, -) +from . import arguments, cache_embeddings, calibrate, export, kfserver, predict, run_bootleg, server, train, write_kf_metrics from .ned.scripts import analyze_bootleg_results, oracle_vs_bootleg from .paraphrase import run_generation, run_lm_finetuning from .paraphrase.scripts import clean_paraphrasing_dataset, dialog_to_tsv, split_dataset, transform_dataset @@ -106,8 +94,6 @@ # kf commands 'kfserver': ('Export KFServing interface to predict', server.parse_argv, kfserver.main), 'write-kf-metrics': ('Write KF evaluation metrics', write_kf_metrics.parse_argv, write_kf_metrics.main), - # bitod - 'run-dialogue-loop': ('Interact with dialogue agent', run_dialogue_loop.parse_argv, run_dialogue_loop.main), } diff --git a/genienlp/dial_validate.py b/genienlp/dial_validate.py deleted file mode 100644 index f3a7e6ab..00000000 --- a/genienlp/dial_validate.py +++ /dev/null @@ -1,243 +0,0 @@ -import copy -import logging -import sys -import time -from collections import defaultdict - -import torch -import ujson -from BiToD.evaluate import r_en_API_MAP -from BiToD.knowledgebase import api -from BiToD.knowledgebase.en_zh_mappings import api_names, required_slots -from BiToD.utils import knowledge2span, span2state, state2constraints, state2span -from termcolor import colored - -from genienlp.data_utils.example import NumericalizedExamples, SequentialField - -logger = logging.getLogger(__name__) - -INIT_SYS_MESSAGE = { - 'en': 'Hello! How can I help you today?', - 'fa': 'Ų³Ł„Ų§Ł…! 
Ų§Ł…Ų±ŁˆŲ² Ś†Ų·ŁˆŲ± Ł…ŪŒ ŲŖŁˆŲ§Ł†Ł… ŲØŁ‡ Ų“Ł…Ų§ Ś©Ł…Ś© Ś©Ł†Ł…ŲŸ', - 'zh': '你儽ļ¼ ęˆ‘ä»Šå¤©čƒ½åø®åˆ°ä½ ä»€ä¹ˆļ¼Ÿ', -} - - -def numericalize_example(input_text, numericalizer, turn_id, device): - if isinstance(input_text, str): - input_text = [input_text] - tokenized_contexts = numericalizer.encode_batch(input_text, field_name='context', features=None)[0] - - numericalized_turn = NumericalizedExamples( - example_id=[str(turn_id)], - context=SequentialField( - value=torch.tensor([tokenized_contexts.value], device=device), - length=torch.tensor([tokenized_contexts.length], device=device), - limited=torch.tensor([tokenized_contexts.limited], device=device), - feature=None, - ), - answer=SequentialField(value=None, length=None, limited=None, feature=None), - ) - - return numericalized_turn - - -def generate(model, args, numericalized_turn, hyperparameter_idx): - return model.generate( - numericalized_turn, - max_output_length=args.max_output_length, - min_output_length=args.min_output_length, - num_outputs=args.num_outputs[hyperparameter_idx], - temperature=args.temperature[hyperparameter_idx] if args.temperature[hyperparameter_idx] > 0 else 1.0, - repetition_penalty=args.repetition_penalty[hyperparameter_idx], - top_k=args.top_k[hyperparameter_idx], - top_p=args.top_p[hyperparameter_idx], - num_beams=args.num_beams[hyperparameter_idx], - num_beam_groups=args.num_beam_groups[hyperparameter_idx], - diversity_penalty=args.diversity_penalty[hyperparameter_idx], - no_repeat_ngram_size=args.no_repeat_ngram_size[hyperparameter_idx], - do_sample=args.temperature[hyperparameter_idx] != 0, - ) - - -def generate_with_seq2seq_model_for_dialogue_interactive(e2e_model, nlg_model, e2e_task, nlg_task): - - bitod_preds = dict() - - predictions = [] - - e2e_numericalizer = e2e_model.numericalizer - e2e_args = e2e_model.args - device = e2e_model.device - - tgt_lang = e2e_model.tgt_lang - - if e2e_args.nlg_type == 'neural': - nlg_numericalizer = nlg_model.numericalizer - nlg_args = nlg_model.args - - dial_id = 'none' - turn_id = 1 - dialogue_state = {} - new_state_text = 'null' - new_knowledge_text = 'null' - active_api = None - bitod_preds[dial_id] = {"turns": defaultdict(dict), "API": defaultdict(dict)} - - convo_history = [] - nlg_responses = [] - convo_window = 3 - - hyperparameter_idx = 0 - - train_target = 'response' - - next_target = {'dst': 'api', 'api': 'response', 'response': 'dst'} - - while True: - try: - batch_prediction = [] - - # becomes dst for first turn - train_target = next_target[train_target] - - if train_target == 'dst': - if convo_history: - print(colored(f'SYSTEM: {nlg_responses[-1]}', 'red', attrs=['bold'])) - else: - tgt_lang = tgt_lang[:2] - print(colored(f'SYSTEM: {INIT_SYS_MESSAGE[tgt_lang]}', 'red', attrs=['bold'])) - - # construct new input - raw_user_input = input(colored('USER: ', 'green', attrs=['bold'])) - if raw_user_input == 'RESET': - generate_with_seq2seq_model_for_dialogue_interactive(e2e_model, nlg_model, e2e_task, nlg_task) - break - elif raw_user_input == 'END': - sys.exit(0) - elif raw_user_input == 'STATE': - print(f'dialogue state: {dialogue_state}') - continue - - raw_user_input = 'USER: ' + raw_user_input.strip() - - convo_history.append(raw_user_input) - - input_text = f'DST: {new_state_text} {" ".join(convo_history[-convo_window:])}' - - elif train_target == 'api': - new_state_text = state2span(dialogue_state, required_slots) - - # replace state - input_text = f'API: {new_state_text} {" ".join(convo_history[-convo_window:])}' - - elif train_target == 'response': - 
- input_text = f'Response: {new_knowledge_text} {new_state_text} {" ".join(convo_history[-convo_window:])}' - - else: - raise ValueError(f'Invalid train_target: {train_target}') - - numericalized_turn = numericalize_example(input_text, e2e_numericalizer, turn_id, device) - generated = generate(e2e_model, e2e_args, numericalized_turn, hyperparameter_idx) - - partial_batch_prediction_ids = generated.sequences - - partial_batch_prediction = e2e_numericalizer.reverse(partial_batch_prediction_ids, 'answer')[0] - - # post-process predictions - partial_batch_prediction = e2e_task.postprocess_prediction(turn_id, partial_batch_prediction) - - # put them into the right array - batch_prediction.append([partial_batch_prediction]) - - predictions += batch_prediction - - if train_target == 'dst': - # update dialogue_state - lev = predictions[-1][0].strip() - state_update = span2state(lev, api_names) - for api_name in state_update: - active_api = api_name - if api_name not in dialogue_state: - dialogue_state[api_name] = state_update[api_name] - else: - dialogue_state[api_name].update(state_update[api_name]) - - #### save latest state - state_to_record = copy.deepcopy(dialogue_state) - state_to_record = {r_en_API_MAP.get(k, k): v for k, v in state_to_record.items()} - bitod_preds[dial_id]["turns"][str(turn_id)]["state"] = state_to_record - #### - - elif train_target == 'api': - new_knowledge_text = 'null' - do_api_call = predictions[-1][0].strip() - - if do_api_call == 'yes': - # make api call - api_name = active_api - - if api_name in dialogue_state: - constraints = state2constraints(dialogue_state[api_name]) - # domain = api_name.split(" ")[0] - knowledge = defaultdict(dict) - - try: - msg = api.call_api(r_en_API_MAP.get(api_name, api_name), constraints=[constraints]) - except Exception as e: - logger.error(f'Error: {e}') - logger.error( - f'Failed API call with api_name: {api_name}, constraints: {constraints},' - f' processed_query: {msg[2]}, for turn: {dial_id}/{turn_id}' - ) - msg = [0, 0, 0] - - if int(msg[1]) <= 0: - logger.warning( - f'Message = No item available for api_name: {api_name}, constraints: {constraints},' - f' processed_query: {msg[2]}, for turn: {dial_id}/{turn_id}' - ) - - new_knowledge_text = f'( {api_name} ) Message = No item available.' 
- else: - # always choose highest ranking results (having deterministic api results) - knowledge[api_name].update(msg[0]) - new_knowledge_text = knowledge2span(knowledge) - - #### save latest api constraints - bitod_preds[dial_id]["API"][r_en_API_MAP.get(api_name, api_name)] = copy.deepcopy(constraints) - #### - - #### save latest api results and constraints - bitod_preds[dial_id]["turns"][str(turn_id)]["api"] = new_knowledge_text - #### - - if train_target == 'response': - # turn dialogue acts into actual responses - if e2e_args.nlg_type == 'neural': - numericalized_turn = numericalize_example(predictions[-1][0], e2e_numericalizer, turn_id, device) - generated = generate(nlg_model, nlg_args, numericalized_turn, hyperparameter_idx) - - partial_batch_prediction_ids = generated.sequences - - partial_batch_prediction = nlg_numericalizer.reverse(partial_batch_prediction_ids, 'answer')[0] - - # post-process predictions - partial_batch_prediction = nlg_task.postprocess_prediction(turn_id, partial_batch_prediction) - else: - partial_batch_prediction = nlg_model.generate(predictions[-1][0]) - - nlg_responses.append(partial_batch_prediction) - - #### save latest response - bitod_preds[dial_id]["turns"][str(turn_id)]["response"] = nlg_responses[-1] - #### - - convo_history.append('SYSTEM: ' + predictions[-1][0]) - - except KeyboardInterrupt: - break - - with open(f"{int(time.time())}_bitod_preds.json", 'w') as fout: - ujson.dump(bitod_preds, fout, indent=2, ensure_ascii=False) diff --git a/genienlp/run_dialogue_loop.py b/genienlp/run_dialogue_loop.py deleted file mode 100644 index d93ccb41..00000000 --- a/genienlp/run_dialogue_loop.py +++ /dev/null @@ -1,172 +0,0 @@ -# -# Copyright (c) 2018, Salesforce, Inc. -# The Board of Trustees of the Leland Stanford Junior University -# All rights reserved. -# -# Redistribution and use in source and binary forms, with or without -# modification, are permitted provided that the following conditions are met: -# -# * Redistributions of source code must retain the above copyright notice, this -# list of conditions and the following disclaimer. -# -# * Redistributions in binary form must reproduce the above copyright notice, -# this list of conditions and the following disclaimer in the documentation -# and/or other materials provided with the distribution. -# -# * Neither the name of the copyright holder nor the names of its -# contributors may be used to endorse or promote products derived from -# this software without specific prior written permission. -# -# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
-import copy -import logging -from pprint import pformat - -import torch -from BiToD.templates.generate_template_response import TemplateResponseGenerator - -from genienlp.dial_validate import generate_with_seq2seq_model_for_dialogue_interactive - -from . import models -from .arguments import check_and_update_generation_args -from .tasks.registry import get_tasks -from .util import get_devices, load_config_json, set_seed - -logger = logging.getLogger(__name__) - - -def parse_argv(parser): - parser.add_argument('--e2e_path', type=str, required=True) - parser.add_argument('--nlg_path', type=str) - parser.add_argument( - '--nlg_type', type=str, choices=['neural', 'template-translated', 'template-human'], default='template' - ) - - parser.add_argument( - '--devices', default=[0], nargs='+', type=int, help='a list of devices that can be used (multi-gpu currently WIP)' - ) - parser.add_argument('--seed', default=123, type=int, help='Random seed.') - parser.add_argument('--embeddings', default='.embeddings', type=str, help='where to save embeddings.') - parser.add_argument( - '--checkpoint_name', default='best.pth', help='Checkpoint file to use (relative to --path, defaults to best.pth)' - ) - - parser.add_argument('--database_dir', type=str, help='Database folder containing all relevant files') - parser.add_argument('--src_locale', default='en', help='locale tag of the input language to parse') - parser.add_argument('--tgt_locale', default='en', help='locale tag of the target language to generate') - parser.add_argument('--inference_name', default='nlp', help='name used by kfserving inference service, alphanumeric only') - - # These are generation hyperparameters. Each one can be a list of values in which case, we generate `num_outputs` outputs for each set of hyperparameters. - parser.add_argument("--num_outputs", type=int, nargs='+', default=[1], help='number of sequences to output per input') - parser.add_argument("--temperature", type=float, nargs='+', default=[0.0], help="temperature of 0 implies greedy sampling") - parser.add_argument( - "--repetition_penalty", - type=float, - nargs='+', - default=[1.0], - help="primarily useful for CTRL model; in that case, use 1.2", - ) - parser.add_argument("--top_k", type=int, nargs='+', default=[0], help='0 disables top-k filtering') - parser.add_argument("--top_p", type=float, nargs='+', default=[1.0], help='1.0 disables top-p filtering') - parser.add_argument("--num_beams", type=int, nargs='+', default=[1], help='1 disables beam seach') - parser.add_argument("--num_beam_groups", type=int, nargs='+', default=[1], help='1 disables diverse beam seach') - parser.add_argument("--diversity_penalty", type=float, nargs='+', default=[0.0], help='0 disables diverse beam seach') - parser.add_argument( - "--no_repeat_ngram_size", - type=int, - nargs='+', - default=[0], - help='ngrams of this size cannot be repeated in the output. 
0 disables it.', - ) - parser.add_argument('--max_output_length', default=150, type=int, help='maximum output length for generation') - - -class DialogueLoop(object): - def __init__(self, e2e_model, nlg_model): - self.e2e_model = e2e_model - self.nlg_model = nlg_model - - def run(self): - e2e_task = list(get_tasks(['bitod'], self.e2e_model.args).values())[0] - self.e2e_model.add_new_vocab_from_data([e2e_task]) - self.e2e_model.set_generation_output_options([e2e_task]) - - if self.e2e_model.args.nlg_type == 'neural': - nlg_task = list(get_tasks(['bitod_nlg'], self.nlg_model.args).values())[0] - self.nlg_model.add_new_vocab_from_data([nlg_task]) - self.nlg_model.set_generation_output_options([nlg_task]) - else: - nlg_task = None - - with torch.no_grad(): - generate_with_seq2seq_model_for_dialogue_interactive( - self.e2e_model, - self.nlg_model, - e2e_task, - nlg_task, - ) - - -def init(args): - set_seed(args) - - devices = get_devices() - device = devices[0] # server only runs on a single device - - e2e_args = copy.deepcopy(args) - e2e_args.path = args.e2e_path - load_config_json(e2e_args) - check_and_update_generation_args(e2e_args) - - E2EModel = getattr(models, e2e_args.model) - e2e_model, _ = E2EModel.load( - e2e_args.path, - model_checkpoint_file=e2e_args.checkpoint_name, - args=e2e_args, - device=device, - src_lang=e2e_args.src_locale, - tgt_lang=e2e_args.tgt_locale, - ) - e2e_model.to(device) - e2e_model.eval() - logger.info(f'Arguments:\n{pformat(vars(e2e_args))}') - logger.info(f'Loading from {e2e_args.best_checkpoint}') - - if args.nlg_type == 'neural': - nlg_args = copy.deepcopy(args) - nlg_args.path = args.nlg_path - load_config_json(nlg_args) - check_and_update_generation_args(nlg_args) - NLGModel = getattr(models, nlg_args.model) - nlg_model, _ = NLGModel.load( - nlg_args.path, - model_checkpoint_file=nlg_args.checkpoint_name, - args=nlg_args, - device=device, - src_lang=nlg_args.src_locale, - tgt_lang=nlg_args.tgt_locale, - ) - nlg_model.to(device) - nlg_model.eval() - logger.info(f'Arguments:\n{pformat(vars(nlg_args))}') - logger.info(f'Loading from {nlg_args.best_checkpoint}') - else: - _, filename = args.nlg_type.split('-') - nlg_model = TemplateResponseGenerator(args.tgt_locale, filename=filename) - - return e2e_model, nlg_model - - -def main(args): - e2e_model, nlg_model = init(args) - loop = DialogueLoop(e2e_model, nlg_model) - loop.run() diff --git a/genienlp/validate.py b/genienlp/validate.py index caecd226..77cf542f 100644 --- a/genienlp/validate.py +++ b/genienlp/validate.py @@ -214,11 +214,6 @@ def generate_with_seq2seq_model_for_dialogue( input_text = replace_capturing_group(input_text, dataset.knowledge_re, new_knowledge_text) elif train_target == 'rg': - # replace state - # input_text = replace_capturing_group(contexts[-1], state_re, new_state_text) - - # replace knowledge - # input_text = replace_capturing_group(input_text, knowledge_re, new_knowledge_text) # replace actions input_text = replace_capturing_group(contexts[-1], dataset.actions_re, new_actions_text) @@ -295,7 +290,6 @@ def generate_with_seq2seq_model_for_dialogue( if do_api_call == 'yes': # make api call api_name = active_api - if api_name in dialogue_state: constraints, new_knowledge_text = dataset.make_api_call( dialogue_state, api_name, numericalizer._tokenizer.src_lang, dial_id, turn_id @@ -303,7 +297,6 @@ def generate_with_seq2seq_model_for_dialogue( #### save latest api constraints bitod_preds[dial_id]["API"][dataset.domain2api_name(api_name)] = copy.deepcopy(constraints) #### - elif 
do_api_call == 'no': # do nothing pass From 2ac2b474753120c37ba149f26d47ac7e4e2ac79a Mon Sep 17 00:00:00 2001 From: mehrad Date: Mon, 14 Feb 2022 16:39:25 -0800 Subject: [PATCH 070/102] Add DST task for bitod --- genienlp/tasks/generic_task.py | 41 +++++++++++++--------------------- 1 file changed, 15 insertions(+), 26 deletions(-) diff --git a/genienlp/tasks/generic_task.py b/genienlp/tasks/generic_task.py index 36ccd242..863e78d4 100644 --- a/genienlp/tasks/generic_task.py +++ b/genienlp/tasks/generic_task.py @@ -444,19 +444,21 @@ def utterance_field(self): return 'context' def _make_example(self, turn, **kwargs): - dial_id, turn_id, context, answer, train_target = ( + dial_id, turn_id, input_text, output_text, train_target = ( turn['dial_id'], turn['turn_id'], turn['input_text'], turn['output_text'], turn['train_target'], ) - question = '' + + if kwargs.get('train_target', False) and train_target != kwargs['train_target']: + return None example_id = '/'.join([dial_id, str(turn_id), train_target]) return Example.from_raw( - self.name + '/' + str(example_id), context, question, answer, preprocess=self.preprocess_field, lower=False + self.name + '/' + str(example_id), input_text, '', output_text, preprocess=self.preprocess_field, lower=False ) def get_splits(self, root, **kwargs): @@ -465,35 +467,22 @@ def get_splits(self, root, **kwargs): @register_task('bitod_nlg') -class BiTODNLG(BaseTask): +class BiTODNLG(BiTOD): def __init__(self, name, args): super().__init__(name, args) - self.special_tokens = {} self._metrics = ['casedbleu'] - def utterance_field(self): - return 'context' - - def _make_example(self, turn, **kwargs): - if 'response' not in turn: - return None - - dial_id, turn_id, answer, train_target, response = ( - turn['dial_id'], - turn['turn_id'], - turn['output_text'], - turn['train_target'], - turn['response'], - ) - question = '' - - assert train_target == 'response' + def get_splits(self, root, **kwargs): + kwargs['train_target'] = 'rg' + return BiTODDataset.return_splits(path=root, make_example=self._make_example, **kwargs) - example_id = '/'.join([dial_id, str(turn_id), train_target]) - return Example.from_raw( - self.name + '/' + str(example_id), answer, question, response, preprocess=self.preprocess_field, lower=False - ) +@register_task('bitod_dst') +class BiTODDST(BiTOD): + def __init__(self, name, args): + super().__init__(name, args) + self._metrics = ['em'] def get_splits(self, root, **kwargs): + kwargs['train_target'] = 'dst' return BiTODDataset.return_splits(path=root, make_example=self._make_example, **kwargs) From d0eaafcbf5aa378f7024df55a4a002452684960f Mon Sep 17 00:00:00 2001 From: mehrad Date: Wed, 16 Feb 2022 16:30:02 -0800 Subject: [PATCH 071/102] Compute jga instead of em for DST --- genienlp/metrics.py | 30 +++++++++++++++++++++++++++++- genienlp/validate.py | 8 ++------ 2 files changed, 31 insertions(+), 7 deletions(-) diff --git a/genienlp/metrics.py b/genienlp/metrics.py index eaccbbef..6073f12b 100644 --- a/genienlp/metrics.py +++ b/genienlp/metrics.py @@ -39,6 +39,8 @@ import numpy as np import sacrebleu from datasets import load_metric +from dialogues import Bitod +from dialogues.bitod.src.evaluate import convert_lists_to_set from pyrouge import Rouge155 from seqeval import metrics as seq_metrics from seqeval import scheme as seq_scheme @@ -518,7 +520,7 @@ def computeBITOD(greedy, answer, tgt_lang): subtask_metrics_dict = defaultdict(tuple) subtasks = ['dst', 'api', 'da', 'rg'] - subtask_metrics = [['em'], ['em'], ['em'], ['casedbleu']] + 
subtask_metrics = [['jga'], ['em'], ['em'], ['casedbleu']] subtask_weights = [1, 1, 1, 1] for t in range(len(subtasks)): @@ -554,6 +556,28 @@ def computeBITOD(greedy, answer, tgt_lang): return bitod_score, JGA, API_em, DA_em, BLEU +def computeJGA(greedy, answer): + dataset = Bitod() + hit = 0 + greedy_state = defaultdict() + answer_state = defaultdict() + for g, a in zip(greedy, answer): + + dataset.update_state(a, answer_state) + dataset.update_state(g, greedy_state) + + answer_state = dataset.span2state(answer_state) + greedy_state = dataset.span2state(greedy_state) + + convert_lists_to_set(answer_state) + convert_lists_to_set(greedy_state) + + if answer_state == greedy_state: + hit += 1 + + return hit / len(greedy) + + def compute_metrics(greedy, answer, requested_metrics: Iterable, lang): """ Inputs: @@ -578,6 +602,10 @@ def compute_metrics(greedy, answer, requested_metrics: Iterable, lang): bitod_score, JGA, API_em, DA_em, BLEU = computeBITOD(greedy, answer, lang) metric_keys += ['bitod_score', 'JGA', 'API_em', 'DA_em', 'BLEU'] metric_values += [bitod_score, JGA, API_em, DA_em, BLEU] + if 'jga' in requested_metrics: + JGA = computeJGA(greedy, answer) + metric_keys += ['JGA'] + metric_values += [JGA] if 'lfem' in requested_metrics: lfem, answer = computeLFEM(greedy, answer) metric_keys += ['lfem'] diff --git a/genienlp/validate.py b/genienlp/validate.py index 77cf542f..d4b8f500 100644 --- a/genienlp/validate.py +++ b/genienlp/validate.py @@ -270,12 +270,8 @@ def generate_with_seq2seq_model_for_dialogue( # update dialogue_state lev = predictions[-1][0].strip() state_update = dataset.span2state(lev) - for api_name in state_update: - active_api = api_name - if api_name not in dialogue_state: - dialogue_state[api_name] = state_update[api_name] - else: - dialogue_state[api_name].update(state_update[api_name]) + active_api = list(state_update.keys())[-1] + dataset.update_state(state_update, dialogue_state) #### save latest state state_to_record = copy.deepcopy(dialogue_state) From 91650871c8d65e49407f9049999cad3ab8abb797 Mon Sep 17 00:00:00 2001 From: mehrad Date: Wed, 16 Feb 2022 16:46:54 -0800 Subject: [PATCH 072/102] refactor computeBITOD --- genienlp/metrics.py | 50 ++++++++++++++++------------------ genienlp/tasks/generic_task.py | 2 +- 2 files changed, 24 insertions(+), 28 deletions(-) diff --git a/genienlp/metrics.py b/genienlp/metrics.py index 6073f12b..6bc656d7 100644 --- a/genienlp/metrics.py +++ b/genienlp/metrics.py @@ -519,41 +519,37 @@ def computeBITOD(greedy, answer, tgt_lang): num_examples = len(answer) subtask_metrics_dict = defaultdict(tuple) - subtasks = ['dst', 'api', 'da', 'rg'] - subtask_metrics = [['jga'], ['em'], ['em'], ['casedbleu']] - subtask_weights = [1, 1, 1, 1] + subtask2metrics = OrderedDict({'dst': 'jga', 'api': 'em', 'da': 'em', 'rg': 'casedbleu'}) + subtask2weights = OrderedDict({'dst': 1.0, 'api': 1.0, 'da': 1.0, 'rg': 1.0}) - for t in range(len(subtasks)): + results = OrderedDict({'bitod_score': 0.0, 'JGA': 0.0, 'API_em': 0.0, 'DA_em': 0.0, 'BLEU': 0.0}) + subtask2result_key = OrderedDict({'dst': 'JGA', 'api': 'API_em', 'da': 'DA_em', 'rg': 'BLEU'}) + + for k, task in enumerate(subtask2metrics): preds, golds = [], [] - for i in range(t, num_examples, len(subtasks)): + for i in range(k, num_examples, len(subtask2metrics)): preds.append(greedy[i]) golds.append(answer[i]) - metrics_to_compute = subtask_metrics[t] - sub_metrics, _ = compute_metrics(preds, golds, metrics_to_compute, tgt_lang) - subtask_metrics_dict[subtasks[t]] = (sub_metrics, len(golds), 
subtask_weights[t]) + metrics_to_compute = subtask2metrics[task] + sub_metrics, _ = compute_metrics(preds, golds, [metrics_to_compute], tgt_lang) + subtask_metrics_dict[task] = (sub_metrics, len(golds)) # TODO how should we aggregate? - bitod_score, JGA, API_em, DA_em, BLEU = 0.0, 0.0, 0.0, 0.0, 0.0 weighted_num_examples = 0 - for subtask, (sub_metrics, num_ex, weight) in subtask_metrics_dict.items(): - if subtask == 'dst': - bitod_score += weight * (sub_metrics['em'] * num_ex) - JGA = sub_metrics['em'] - elif subtask == 'api': - bitod_score += weight * (sub_metrics['em'] * num_ex) - API_em = sub_metrics['em'] - elif subtask == 'da': - bitod_score += weight * (sub_metrics['em'] * num_ex) - DA_em = sub_metrics['em'] - elif subtask == 'rg': - bitod_score += weight * (sub_metrics['casedbleu'] * num_ex) - BLEU = sub_metrics['casedbleu'] + for subtask, (sub_metrics, num_ex) in subtask_metrics_dict.items(): + metric = subtask2metrics[subtask] + weight = subtask2weights[subtask] + result_key = subtask2result_key[subtask] + + results[result_key] += sub_metrics[metric] + results['bitod_score'] += weight * (sub_metrics[metric] * num_ex) + weighted_num_examples += weight * num_ex - bitod_score /= weighted_num_examples + results['bitod_score'] /= weighted_num_examples - return bitod_score, JGA, API_em, DA_em, BLEU + return results def computeJGA(greedy, answer): @@ -599,9 +595,9 @@ def compute_metrics(greedy, answer, requested_metrics: Iterable, lang): answer = [[a] for a in answer] if 'bitod_score' in requested_metrics: requested_metrics += ['JGA', 'API_em', 'DA_em', 'BLEU'] - bitod_score, JGA, API_em, DA_em, BLEU = computeBITOD(greedy, answer, lang) - metric_keys += ['bitod_score', 'JGA', 'API_em', 'DA_em', 'BLEU'] - metric_values += [bitod_score, JGA, API_em, DA_em, BLEU] + results = computeBITOD(greedy, answer, lang) + metric_keys += results.keys() + metric_values += results.values() if 'jga' in requested_metrics: JGA = computeJGA(greedy, answer) metric_keys += ['JGA'] diff --git a/genienlp/tasks/generic_task.py b/genienlp/tasks/generic_task.py index 863e78d4..d9fad202 100644 --- a/genienlp/tasks/generic_task.py +++ b/genienlp/tasks/generic_task.py @@ -481,7 +481,7 @@ def get_splits(self, root, **kwargs): class BiTODDST(BiTOD): def __init__(self, name, args): super().__init__(name, args) - self._metrics = ['em'] + self._metrics = ['jga'] def get_splits(self, root, **kwargs): kwargs['train_target'] = 'dst' From 016de325222dbdcc8e0121cef169c3cd8f73bf54 Mon Sep 17 00:00:00 2001 From: mehrad Date: Wed, 16 Feb 2022 19:49:45 -0800 Subject: [PATCH 073/102] BiTODDataset: pass train_target to make_example --- genienlp/tasks/generic_dataset.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/genienlp/tasks/generic_dataset.py b/genienlp/tasks/generic_dataset.py index a6857359..f77bd000 100644 --- a/genienlp/tasks/generic_dataset.py +++ b/genienlp/tasks/generic_dataset.py @@ -1979,7 +1979,7 @@ def __init__(self, path, *, make_example, **kwargs): with open(path) as fin: data = ujson.load(fin)['data'] for turn in data: - processed = make_example(turn) + processed = make_example(turn, train_target=kwargs['train_target']) if processed: examples.append(processed) From f88fe3cc2bbb98d4473b70b0d78d84f7b37313da Mon Sep 17 00:00:00 2001 From: mehrad Date: Wed, 16 Feb 2022 17:49:32 -0800 Subject: [PATCH 074/102] metrics: fix computeJGA --- genienlp/metrics.py | 37 +++++++++++++++++++++---------------- genienlp/predict.py | 7 ++++++- genienlp/validate.py | 2 +- 3 files changed, 28 insertions(+), 18 
deletions(-) diff --git a/genienlp/metrics.py b/genienlp/metrics.py index 6bc656d7..226644e9 100644 --- a/genienlp/metrics.py +++ b/genienlp/metrics.py @@ -515,7 +515,7 @@ def computeDialogue(greedy, answer): return joint_goal_em, turn_request_em, turn_goal_em, answer -def computeBITOD(greedy, answer, tgt_lang): +def computeBITOD(greedy, answer, tgt_lang, example_ids): num_examples = len(answer) subtask_metrics_dict = defaultdict(tuple) @@ -532,7 +532,7 @@ def computeBITOD(greedy, answer, tgt_lang): golds.append(answer[i]) metrics_to_compute = subtask2metrics[task] - sub_metrics, _ = compute_metrics(preds, golds, [metrics_to_compute], tgt_lang) + sub_metrics, _ = compute_metrics(preds, golds, [metrics_to_compute], tgt_lang, example_ids) subtask_metrics_dict[task] = (sub_metrics, len(golds)) # TODO how should we aggregate? @@ -552,29 +552,34 @@ def computeBITOD(greedy, answer, tgt_lang): return results -def computeJGA(greedy, answer): +def computeJGA(greedy, answer, example_ids): dataset = Bitod() hit = 0 - greedy_state = defaultdict() - answer_state = defaultdict() - for g, a in zip(greedy, answer): + cur_dial_id = None + for id_, g, a in zip(example_ids, greedy, answer): + dial_id = id_.split('/')[1] + if dial_id != cur_dial_id: + cur_dial_id = dial_id + greedy_state = defaultdict() + answer_state = defaultdict() + + a = a[0] + a = dataset.span2state(a) + g = dataset.span2state(g) dataset.update_state(a, answer_state) dataset.update_state(g, greedy_state) - answer_state = dataset.span2state(answer_state) - greedy_state = dataset.span2state(greedy_state) - convert_lists_to_set(answer_state) convert_lists_to_set(greedy_state) if answer_state == greedy_state: hit += 1 - return hit / len(greedy) + return hit / len(greedy) * 100 -def compute_metrics(greedy, answer, requested_metrics: Iterable, lang): +def compute_metrics(greedy, answer, requested_metrics: Iterable, lang, example_ids): """ Inputs: requested_metrics: contains a subset of the following metrics @@ -595,12 +600,12 @@ def compute_metrics(greedy, answer, requested_metrics: Iterable, lang): answer = [[a] for a in answer] if 'bitod_score' in requested_metrics: requested_metrics += ['JGA', 'API_em', 'DA_em', 'BLEU'] - results = computeBITOD(greedy, answer, lang) + results = computeBITOD(greedy, answer, lang, example_ids) metric_keys += results.keys() metric_values += results.values() if 'jga' in requested_metrics: - JGA = computeJGA(greedy, answer) - metric_keys += ['JGA'] + JGA = computeJGA(greedy, answer, example_ids) + metric_keys += ['jga'] metric_values += [JGA] if 'lfem' in requested_metrics: lfem, answer = computeLFEM(greedy, answer) @@ -719,10 +724,10 @@ def convert_IOB2_to_IOB1(labels): return metric_dict, answer -def calculate_and_reduce_metrics(predictions, answers, metrics_to_compute, reduce_metrics, lang): +def calculate_and_reduce_metrics(predictions, answers, metrics_to_compute, reduce_metrics, lang, example_ids): metrics = OrderedDict() for i in range(len(predictions[0])): - partial_metrics, _ = compute_metrics([p[i] for p in predictions], answers, metrics_to_compute, lang) + partial_metrics, _ = compute_metrics([p[i] for p in predictions], answers, metrics_to_compute, lang, example_ids) for k, v in partial_metrics.items(): if reduce_metrics == 'max': metrics[k] = max(metrics.get(k, 0), v) diff --git a/genienlp/predict.py b/genienlp/predict.py index e44a2c1f..e700a076 100644 --- a/genienlp/predict.py +++ b/genienlp/predict.py @@ -502,7 +502,12 @@ def run(args, device): if args.main_metric_only: metrics_to_compute = 
[metrics_to_compute[0]] metrics = calculate_and_reduce_metrics( - generation_output.predictions, generation_output.answers, metrics_to_compute, args.reduce_metrics, tgt_lang + generation_output.predictions, + generation_output.answers, + metrics_to_compute, + args.reduce_metrics, + tgt_lang, + generation_output.example_ids, ) with open(results_file_name, 'w' + ('' if args.overwrite else '+')) as results_file: diff --git a/genienlp/validate.py b/genienlp/validate.py index d4b8f500..a28fc1cf 100644 --- a/genienlp/validate.py +++ b/genienlp/validate.py @@ -672,7 +672,7 @@ def validate(task, val_iter, model, numericalizer, args, num_print=10): metrics_to_return = [metric for metric in task.metrics if metric != 'loss'] metrics = calculate_and_reduce_metrics( - output.predictions, output.answers, metrics_to_return, args.reduce_metrics, model.tgt_lang + output.predictions, output.answers, metrics_to_return, args.reduce_metrics, model.tgt_lang, output.example_ids ) results = {'beam search': output.predictions, 'answer': output.answers, 'context': output.contexts} From 5bc6a65dad76ae7cf210c4e5693e8b632af544fd Mon Sep 17 00:00:00 2001 From: mehrad Date: Thu, 17 Feb 2022 16:26:22 -0800 Subject: [PATCH 075/102] pass kwargs to other bitod tasks --- genienlp/tasks/generic_task.py | 2 ++ genienlp/util.py | 2 +- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/genienlp/tasks/generic_task.py b/genienlp/tasks/generic_task.py index d9fad202..30fcd4b5 100644 --- a/genienlp/tasks/generic_task.py +++ b/genienlp/tasks/generic_task.py @@ -474,6 +474,7 @@ def __init__(self, name, args): def get_splits(self, root, **kwargs): kwargs['train_target'] = 'rg' + kwargs['e2e_evaluation'] = self.args.bitod_e2e_evaluation return BiTODDataset.return_splits(path=root, make_example=self._make_example, **kwargs) @@ -485,4 +486,5 @@ def __init__(self, name, args): def get_splits(self, root, **kwargs): kwargs['train_target'] = 'dst' + kwargs['e2e_evaluation'] = self.args.bitod_e2e_evaluation return BiTODDataset.return_splits(path=root, make_example=self._make_example, **kwargs) diff --git a/genienlp/util.py b/genienlp/util.py index da7449ff..a1d1d482 100644 --- a/genienlp/util.py +++ b/genienlp/util.py @@ -618,7 +618,7 @@ def make_data_loader(dataset, numericalizer, batch_size, device=None, train=Fals if min_batch_length > batch_size: raise ValueError( - f'The minimum example length in your dataset is {np.min(context_lengths) + np.min(answer_lengths)} but your batch size is {batch_size}.' + f'The minimum batch length in your dataset is {min_batch_length} but your batch size is {batch_size}.' f' Thus no examples will be processed. 
Consider increasing batch_size'
         )
     if np.min(answer_lengths) < min_output_length:

From f653f35c17d2dda3749352fb5f5d696419ebadc2 Mon Sep 17 00:00:00 2001
From: mehrad
Date: Thu, 17 Feb 2022 16:55:09 -0800
Subject: [PATCH 076/102] validate: fix bug when state_update is null

---
 genienlp/validate.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/genienlp/validate.py b/genienlp/validate.py
index a28fc1cf..7c3eb8aa 100644
--- a/genienlp/validate.py
+++ b/genienlp/validate.py
@@ -270,7 +270,8 @@ def generate_with_seq2seq_model_for_dialogue(
             # update dialogue_state
             lev = predictions[-1][0].strip()
             state_update = dataset.span2state(lev)
-            active_api = list(state_update.keys())[-1]
+            if state_update:
+                active_api = list(state_update.keys())[-1]
             dataset.update_state(state_update, dialogue_state)

             #### save latest state

From fef7611472d9af86fc1271613284cfcbe57cdaf4 Mon Sep 17 00:00:00 2001
From: mehrad
Date: Thu, 17 Feb 2022 17:23:30 -0800
Subject: [PATCH 077/102] metrics: check examples based on ids in computeBITOD

---
 genienlp/metrics.py | 17 ++++++++++-------
 1 file changed, 10 insertions(+), 7 deletions(-)

diff --git a/genienlp/metrics.py b/genienlp/metrics.py
index 226644e9..d1805b40 100644
--- a/genienlp/metrics.py
+++ b/genienlp/metrics.py
@@ -527,13 +527,16 @@ def computeBITOD(greedy, answer, tgt_lang, example_ids):

     for k, task in enumerate(subtask2metrics):
         preds, golds = [], []
-        for i in range(k, num_examples, len(subtask2metrics)):
-            preds.append(greedy[i])
-            golds.append(answer[i])
+        for i in range(num_examples):
+            id_ = example_ids[i]
+            if id_.endswith(f'/{task}'):
+                preds.append(greedy[i])
+                golds.append(answer[i])

-        metrics_to_compute = subtask2metrics[task]
-        sub_metrics, _ = compute_metrics(preds, golds, [metrics_to_compute], tgt_lang, example_ids)
-        subtask_metrics_dict[task] = (sub_metrics, len(golds))
+        if golds:
+            metrics_to_compute = subtask2metrics[task]
+            sub_metrics, _ = compute_metrics(preds, golds, [metrics_to_compute], tgt_lang, example_ids)
+            subtask_metrics_dict[task] = (sub_metrics, len(golds))

     # TODO how should we aggregate?
weighted_num_examples = 0 @@ -579,7 +582,7 @@ def computeJGA(greedy, answer, example_ids): return hit / len(greedy) * 100 -def compute_metrics(greedy, answer, requested_metrics: Iterable, lang, example_ids): +def compute_metrics(greedy, answer, requested_metrics: Iterable, lang, example_ids=None): """ Inputs: requested_metrics: contains a subset of the following metrics From 187d05a381bc21da3f092df209ecf4e03b5c56be Mon Sep 17 00:00:00 2001 From: mehrad Date: Thu, 17 Feb 2022 20:22:50 -0800 Subject: [PATCH 078/102] Use em as second metric for bitod_dst task --- genienlp/tasks/generic_task.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/genienlp/tasks/generic_task.py b/genienlp/tasks/generic_task.py index 30fcd4b5..386a7496 100644 --- a/genienlp/tasks/generic_task.py +++ b/genienlp/tasks/generic_task.py @@ -482,7 +482,7 @@ def get_splits(self, root, **kwargs): class BiTODDST(BiTOD): def __init__(self, name, args): super().__init__(name, args) - self._metrics = ['jga'] + self._metrics = ['jga', 'em'] def get_splits(self, root, **kwargs): kwargs['train_target'] = 'dst' From 6e8142c7f8790e038bf875be334f5dd0c763f471 Mon Sep 17 00:00:00 2001 From: mehrad Date: Fri, 18 Feb 2022 00:28:12 -0800 Subject: [PATCH 079/102] Add argument to specify train_set_name --- genienlp/arguments.py | 1 + genienlp/train.py | 1 + 2 files changed, 2 insertions(+) diff --git a/genienlp/arguments.py b/genienlp/arguments.py index efd143c8..dc8447a9 100644 --- a/genienlp/arguments.py +++ b/genienlp/arguments.py @@ -183,6 +183,7 @@ def parse_argv(parser): default=0.1, help='multiplicative constant choosing the weight of encoder_loss in total loss', ) + parser.add_argument('--train_set_name', type=str, default='train', help='Training dataset name to use during training') parser.add_argument('--eval_set_name', type=str, help='Evaluation dataset name to use during training') parser.add_argument('--max_output_length', default=150, type=int, help='maximum output length for generation') diff --git a/genienlp/train.py b/genienlp/train.py index 7b8f283c..3f23b767 100644 --- a/genienlp/train.py +++ b/genienlp/train.py @@ -96,6 +96,7 @@ def prepare_data(args, logger): for task in args.train_tasks: logger.info(f'Loading {task.name}') kwargs = {'test': None, 'validation': None} + kwargs['train'] = args.train_set_name kwargs.update(train_eval_shared_kwargs) kwargs['all_dirs'] = args.train_src_languages kwargs['cached_path'] = os.path.join(args.cache, task.name) From 9d8c7036f3428c615648387dc3fe4b011f78a835 Mon Sep 17 00:00:00 2001 From: mehrad Date: Fri, 18 Feb 2022 09:32:46 -0800 Subject: [PATCH 080/102] predict: output contexts too --- genienlp/predict.py | 26 ++++++++++++++------------ genienlp/validate.py | 3 +++ 2 files changed, 17 insertions(+), 12 deletions(-) diff --git a/genienlp/predict.py b/genienlp/predict.py index e700a076..f6bbfba5 100644 --- a/genienlp/predict.py +++ b/genienlp/predict.py @@ -471,12 +471,13 @@ def run(args, device): # TODO change to jsonl format with open(prediction_file_name, 'w' + ('' if args.overwrite else '+')) as prediction_file: for i in range(len(generation_output.example_ids)): - line = ( - generation_output.example_ids[i] - + '\t' - + '\t'.join(generation_output.predictions[i]) - + '\t' - + generation_output.answers[i] + line = '\t'.join( + [ + generation_output.example_ids[i], + '\t'.join(generation_output.predictions[i]), + generation_output.answers[i], + generation_output.contexts[i], + ] ) # all outputs separated by '\t' if args.calibrator_paths is not None: for score in 
generation_output.confidence_scores: @@ -486,12 +487,13 @@ def run(args, device): if args.translate_return_raw_outputs: with open(raw_prediction_file_name, 'w' + ('' if args.overwrite else '+')) as prediction_file: for i in range(len(generation_output.example_ids)): - line = ( - generation_output.example_ids[i] - + '\t' - + '\t'.join(generation_output.raw_predictions[i]) - + '\t' - + generation_output.answers[i] + line = '\t'.join( + [ + generation_output.example_ids[i], + '\t'.join(generation_output.raw_predictions[i]), + generation_output.answers[i], + generation_output.contexts[i], + ] ) # all outputs separated by '\t' prediction_file.write(line + '\n') diff --git a/genienlp/validate.py b/genienlp/validate.py index 7c3eb8aa..3a63d95c 100644 --- a/genienlp/validate.py +++ b/genienlp/validate.py @@ -221,6 +221,9 @@ def generate_with_seq2seq_model_for_dialogue( else: raise ValueError(f'Invalid train_target: {train_target}') + # replace old context with updated + contexts[-1] = input_text + tokenized_contexts = numericalizer.encode_batch([input_text], field_name='context', features=None)[0] numericalized_turn = NumericalizedExamples( From 638df872d8be13ebb54d8e6f02e5c80386de6267 Mon Sep 17 00:00:00 2001 From: mehrad Date: Fri, 18 Feb 2022 10:27:50 -0800 Subject: [PATCH 081/102] Fix bug in validation for bitod_dst --- genienlp/validate.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/genienlp/validate.py b/genienlp/validate.py index 3a63d95c..fe110ece 100644 --- a/genienlp/validate.py +++ b/genienlp/validate.py @@ -161,7 +161,7 @@ def generate_with_seq2seq_model_for_dialogue( # new dialogue cur_dial_id = dial_id dialogue_state = {} - new_state_text = 'null' + # new_state_text = 'null' new_knowledge_text = 'null' new_actions_text = 'null' active_api = None @@ -192,6 +192,8 @@ def generate_with_seq2seq_model_for_dialogue( ] answers += batch_answer + new_state_text = dataset.state2span(dialogue_state) + if train_target == 'dst': input_text = replace_capturing_group(contexts[-1], dataset.state_re, new_state_text) @@ -201,7 +203,6 @@ def generate_with_seq2seq_model_for_dialogue( # input_text = replace_match(input_text, last_system_re, last_sys_pred) elif train_target == 'api': - new_state_text = dataset.state2span(dialogue_state) # replace state input_text = replace_capturing_group(contexts[-1], dataset.state_re, new_state_text) From e356976e5fc26155ac871846f2895b7f37e2e181 Mon Sep 17 00:00:00 2001 From: mehrad Date: Fri, 18 Feb 2022 10:45:56 -0800 Subject: [PATCH 082/102] Add bitod_valid_subtasks argument + cleanup metrics --- genienlp/arguments.py | 9 ++++++--- genienlp/metrics.py | 25 +++++++++++++++++-------- genienlp/predict.py | 14 +++++++------- genienlp/tasks/generic_dataset.py | 2 +- genienlp/util.py | 7 ++++++- genienlp/validate.py | 4 +--- 6 files changed, 38 insertions(+), 23 deletions(-) diff --git a/genienlp/arguments.py b/genienlp/arguments.py index dc8447a9..3aa79546 100644 --- a/genienlp/arguments.py +++ b/genienlp/arguments.py @@ -543,10 +543,13 @@ def parse_argv(parser): help='Debugging flag for hf datasets where validation will be performed on train set', ) + parser.add_argument('--bitod_e2e_evaluation', action='store_true', help='Evaluate model end-to-end') parser.add_argument( - '--bitod_e2e_evaluation', - action='store_true', - help='', + '--bitod_valid_subtasks', + nargs='+', + type=str, + default=['dst', 'api', 'da'], + help='Evaluate only on these subtasks when calculating bitod_score; rg is not included by default', ) diff --git 
a/genienlp/metrics.py b/genienlp/metrics.py index d1805b40..17ba5a3b 100644 --- a/genienlp/metrics.py +++ b/genienlp/metrics.py @@ -34,7 +34,6 @@ from contextlib import closing from multiprocessing import Pool, cpu_count from subprocess import PIPE, Popen -from typing import Iterable import numpy as np import sacrebleu @@ -515,7 +514,7 @@ def computeDialogue(greedy, answer): return joint_goal_em, turn_request_em, turn_goal_em, answer -def computeBITOD(greedy, answer, tgt_lang, example_ids): +def computeBITOD(greedy, answer, tgt_lang, args, example_ids): num_examples = len(answer) subtask_metrics_dict = defaultdict(tuple) @@ -526,6 +525,8 @@ def computeBITOD(greedy, answer, tgt_lang, example_ids): subtask2result_key = OrderedDict({'dst': 'JGA', 'api': 'API_em', 'da': 'DA_em', 'rg': 'BLEU'}) for k, task in enumerate(subtask2metrics): + if task not in args.bitod_valid_subtasks: + continue preds, golds = [], [] for i in range(num_examples): id_ = example_ids[i] @@ -535,7 +536,7 @@ def computeBITOD(greedy, answer, tgt_lang, example_ids): if golds: metrics_to_compute = subtask2metrics[task] - sub_metrics, _ = compute_metrics(preds, golds, [metrics_to_compute], tgt_lang, example_ids) + sub_metrics, _ = compute_metrics(preds, golds, [metrics_to_compute], tgt_lang, args, example_ids) subtask_metrics_dict[task] = (sub_metrics, len(golds)) # TODO how should we aggregate? @@ -582,7 +583,7 @@ def computeJGA(greedy, answer, example_ids): return hit / len(greedy) * 100 -def compute_metrics(greedy, answer, requested_metrics: Iterable, lang, example_ids=None): +def compute_metrics(greedy, answer, requested_metrics, lang, args, example_ids=None): """ Inputs: requested_metrics: contains a subset of the following metrics @@ -603,7 +604,7 @@ def compute_metrics(greedy, answer, requested_metrics: Iterable, lang, example_i answer = [[a] for a in answer] if 'bitod_score' in requested_metrics: requested_metrics += ['JGA', 'API_em', 'DA_em', 'BLEU'] - results = computeBITOD(greedy, answer, lang, example_ids) + results = computeBITOD(greedy, answer, lang, args, example_ids) metric_keys += results.keys() metric_values += results.values() if 'jga' in requested_metrics: @@ -727,12 +728,20 @@ def convert_IOB2_to_IOB1(labels): return metric_dict, answer -def calculate_and_reduce_metrics(predictions, answers, metrics_to_compute, reduce_metrics, lang, example_ids): +def calculate_and_reduce_metrics(generation_output, metrics_to_compute, args, lang): metrics = OrderedDict() + predictions = generation_output.predictions for i in range(len(predictions[0])): - partial_metrics, _ = compute_metrics([p[i] for p in predictions], answers, metrics_to_compute, lang, example_ids) + partial_metrics, _ = compute_metrics( + [p[i] for p in predictions], + generation_output.answers, + metrics_to_compute, + lang, + args, + generation_output.example_ids, + ) for k, v in partial_metrics.items(): - if reduce_metrics == 'max': + if args.reduce_metrics == 'max': metrics[k] = max(metrics.get(k, 0), v) else: raise ValueError('Invalid reduce_metrics argument') diff --git a/genienlp/predict.py b/genienlp/predict.py index f6bbfba5..cf611542 100644 --- a/genienlp/predict.py +++ b/genienlp/predict.py @@ -246,10 +246,12 @@ def parse_argv(parser): help='do not preserve quotation marks in the output. 
Useful if using alignment for semantic parsing or NLG', ) + parser.add_argument('--bitod_e2e_evaluation', action='store_true', help='Evaluate model end-to-end') parser.add_argument( - '--bitod_e2e_evaluation', - action='store_true', - help='', + '--bitod_valid_subtasks', + nargs='+', + type=str, + help='Evaluate only on these subtasks when calculating bitod_score; rg is not included by default', ) @@ -504,12 +506,10 @@ def run(args, device): if args.main_metric_only: metrics_to_compute = [metrics_to_compute[0]] metrics = calculate_and_reduce_metrics( - generation_output.predictions, - generation_output.answers, + generation_output, metrics_to_compute, - args.reduce_metrics, + args, tgt_lang, - generation_output.example_ids, ) with open(results_file_name, 'w' + ('' if args.overwrite else '+')) as results_file: diff --git a/genienlp/tasks/generic_dataset.py b/genienlp/tasks/generic_dataset.py index f77bd000..b92a5844 100644 --- a/genienlp/tasks/generic_dataset.py +++ b/genienlp/tasks/generic_dataset.py @@ -1979,7 +1979,7 @@ def __init__(self, path, *, make_example, **kwargs): with open(path) as fin: data = ujson.load(fin)['data'] for turn in data: - processed = make_example(turn, train_target=kwargs['train_target']) + processed = make_example(turn, train_target=kwargs.get('train_target', False)) if processed: examples.append(processed) diff --git a/genienlp/util.py b/genienlp/util.py index a1d1d482..4c1aaca7 100644 --- a/genienlp/util.py +++ b/genienlp/util.py @@ -868,7 +868,8 @@ def load_config_json(args): 'eval_tgt_languages', ] - # train and predict scripts have these arguments in common. We use the values from train only if they are not provided in predict + # train and predict scripts have these arguments in common. We use the values from train only if they are not provided in predict. 
+    # NOTE: do not set default values for these arguments in predict because the defaults will always override training arguments
     overwrite = [
         'val_batch_size',
         'num_beams',
@@ -884,6 +885,7 @@
         'min_output_length',
         'reduce_metrics',
         'database_dir',
+        'bitod_valid_subtasks',
     ]
     for o in overwrite:
         if o not in args or getattr(args, o) is None:
@@ -960,6 +962,9 @@
             # use default value
             setattr(args, r, None)

+    if args.bitod_valid_subtasks is None:
+        setattr(args, 'bitod_valid_subtasks', ['dst', 'api', 'da'])
+
     # backward compatibility for models trained with genienlp before NED Refactoring (2)
     if args.max_features_size is None:
         if hasattr(args, 'ned_features_size'):
diff --git a/genienlp/validate.py b/genienlp/validate.py
index fe110ece..0c5bd768 100644
--- a/genienlp/validate.py
+++ b/genienlp/validate.py
@@ -676,9 +676,7 @@ def validate(task, val_iter, model, numericalizer, args, num_print=10):
     # loss is already calculated
     metrics_to_return = [metric for metric in task.metrics if metric != 'loss']

-    metrics = calculate_and_reduce_metrics(
-        output.predictions, output.answers, metrics_to_return, args.reduce_metrics, model.tgt_lang, output.example_ids
-    )
+    metrics = calculate_and_reduce_metrics(output, metrics_to_return, args, model.tgt_lang)

     results = {'beam search': output.predictions, 'answer': output.answers, 'context': output.contexts}

From 68ddc3e8ac0c2cb193451a7acdf35e8f03fca4bf Mon Sep 17 00:00:00 2001
From: mehrad
Date: Fri, 18 Feb 2022 12:38:54 -0800
Subject: [PATCH 083/102] generic_dataset: use train_set_name argument when
 constructing dataset path

---
 genienlp/tasks/generic_dataset.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/genienlp/tasks/generic_dataset.py b/genienlp/tasks/generic_dataset.py
index b92a5844..0adb3784 100644
--- a/genienlp/tasks/generic_dataset.py
+++ b/genienlp/tasks/generic_dataset.py
@@ -1999,7 +1999,7 @@ def return_splits(cls, path='.data', train='train', validation='valid', test='test', **kwargs):
         train_path, validation_path, test_path = None, None, None
         if train:
-            train_path = os.path.join(path, 'train.json')
+            train_path = os.path.join(path, f'{train}.json')
         if validation:
             validation_path = os.path.join(path, f'{validation}.json')
         if test:

From db5e03140e04f4636273dc165e3e5d5c295eada1 Mon Sep 17 00:00:00 2001
From: mehrad
Date: Tue, 22 Feb 2022 18:39:13 -0800
Subject: [PATCH 084/102] Add bitod_valid_submetrics and bitod_valid_subweights
 args

---
 genienlp/arguments.py | 24 ++++++++++++++++++++++++
 genienlp/metrics.py   | 23 ++++++++---------------
 genienlp/predict.py   | 12 ++++++++++++
 genienlp/util.py      |  5 +++++
 4 files changed, 49 insertions(+), 15 deletions(-)

diff --git a/genienlp/arguments.py b/genienlp/arguments.py
index 3aa79546..41a9caee 100644
--- a/genienlp/arguments.py
+++ b/genienlp/arguments.py
@@ -551,6 +551,20 @@ def parse_argv(parser):
+    parser.add_argument(
+        '--bitod_valid_submetrics',
+        nargs='+',
+        type=str,
+        default=['jga', 'em', 'em'],
+        help='Specify metrics to use for each of subtasks in bitod_valid_subtasks.',
+    )
+    parser.add_argument(
+        '--bitod_valid_subweights',
+        nargs='+',
+        type=float,
+        default=[1.0, 1.0, 1.0],
+        help='Specify weights to use for each of subtasks in bitod_valid_subtasks.',
+    )


 def check_and_update_generation_args(args):
@@ -644,6 +658,16 @@ def post_parse_train_specific(args):
logger.warning('When evaluating bitod end2end val_batch_size should be 1 so we load data turn by turn') args.val_batch_size = [1] + if len(args.bitod_valid_subtasks) != len(args.bitod_valid_submetrics): + raise ValueError( + 'Length of bitod_valid_subtasks and bitod_valid_submetrics arguments should be equal (i.e. one metric per subtask)' + ) + + if len(args.bitod_valid_subtasks) != len(args.bitod_valid_subweights): + raise ValueError( + 'Length of bitod_valid_subtasks and bitod_valid_subweights arguments should be equal (i.e. one weight per subtask)' + ) + if len(args.val_batch_size) < len(args.val_task_names): args.val_batch_size = len(args.val_task_names) * args.val_batch_size diff --git a/genienlp/metrics.py b/genienlp/metrics.py index 17ba5a3b..78e863d7 100644 --- a/genienlp/metrics.py +++ b/genienlp/metrics.py @@ -518,37 +518,30 @@ def computeBITOD(greedy, answer, tgt_lang, args, example_ids): num_examples = len(answer) subtask_metrics_dict = defaultdict(tuple) - subtask2metrics = OrderedDict({'dst': 'jga', 'api': 'em', 'da': 'em', 'rg': 'casedbleu'}) - subtask2weights = OrderedDict({'dst': 1.0, 'api': 1.0, 'da': 1.0, 'rg': 1.0}) - results = OrderedDict({'bitod_score': 0.0, 'JGA': 0.0, 'API_em': 0.0, 'DA_em': 0.0, 'BLEU': 0.0}) subtask2result_key = OrderedDict({'dst': 'JGA', 'api': 'API_em', 'da': 'DA_em', 'rg': 'BLEU'}) - for k, task in enumerate(subtask2metrics): - if task not in args.bitod_valid_subtasks: - continue + for k, subtask in enumerate(args.bitod_valid_subtasks): preds, golds = [], [] for i in range(num_examples): id_ = example_ids[i] - if id_.endswith(f'/{task}'): + if id_.endswith(f'/{subtask}'): preds.append(greedy[i]) golds.append(answer[i]) if golds: - metrics_to_compute = subtask2metrics[task] + metrics_to_compute = args.bitod_valid_submetrics[subtask] sub_metrics, _ = compute_metrics(preds, golds, [metrics_to_compute], tgt_lang, args, example_ids) - subtask_metrics_dict[task] = (sub_metrics, len(golds)) + subtask_metrics_dict[subtask] = (sub_metrics[metrics_to_compute], len(golds)) # TODO how should we aggregate? 
weighted_num_examples = 0 - for subtask, (sub_metrics, num_ex) in subtask_metrics_dict.items(): - metric = subtask2metrics[subtask] - weight = subtask2weights[subtask] + for subtask, (sub_result, num_ex) in subtask_metrics_dict.items(): + weight = args.bitod_valid_subweights[subtask] result_key = subtask2result_key[subtask] - results[result_key] += sub_metrics[metric] - results['bitod_score'] += weight * (sub_metrics[metric] * num_ex) - + results[result_key] += sub_result + results['bitod_score'] += weight * (sub_result * num_ex) weighted_num_examples += weight * num_ex results['bitod_score'] /= weighted_num_examples diff --git a/genienlp/predict.py b/genienlp/predict.py index cf611542..c185f864 100644 --- a/genienlp/predict.py +++ b/genienlp/predict.py @@ -253,6 +253,18 @@ def parse_argv(parser): type=str, help='Evaluate only on these subtasks when calculating bitod_score; rg is not included by default', ) + parser.add_argument( + '--bitod_valid_submetrics', + nargs='+', + type=str, + help='Specify metrics to use for each of subtasks in bitod_valid_subtasks.', + ) + parser.add_argument( + '--bitod_valid_subweights', + nargs='+', + type=float, + help='Specify weights to use for each of subtasks in bitod_valid_subtasks.', + ) def set_default_values(args): diff --git a/genienlp/util.py b/genienlp/util.py index 4c1aaca7..0bca6fbe 100644 --- a/genienlp/util.py +++ b/genienlp/util.py @@ -886,6 +886,7 @@ def load_config_json(args): 'reduce_metrics', 'database_dir', 'bitod_valid_subtasks', + 'bitod_valid_submetrics' 'bitod_valid_subweights', ] for o in overwrite: if o not in args or getattr(args, o) is None: @@ -964,6 +965,10 @@ def load_config_json(args): if args.bitod_valid_subtasks is None: setattr(args, 'bitod_valid_subtasks', ['dst', 'api', 'da']) + if args.bitod_valid_submetrics is None: + setattr(args, 'bitod_valid_submetrics', ['jga', 'em', 'em']) + if args.bitod_valid_subweights is None: + setattr(args, 'bitod_valid_subweights', [1.0, 1.0, 1.0]) # backward compatibility for models trained with genienlp before NED Refactoring (2) if args.max_features_size is None: From 99845f01fbe359a43a972592d5fbe9fe50f687b2 Mon Sep 17 00:00:00 2001 From: mehrad Date: Tue, 22 Feb 2022 13:12:59 -0800 Subject: [PATCH 085/102] validate: carry api results for metro --- genienlp/validate.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/genienlp/validate.py b/genienlp/validate.py index 0c5bd768..55a1a9ad 100644 --- a/genienlp/validate.py +++ b/genienlp/validate.py @@ -162,6 +162,7 @@ def generate_with_seq2seq_model_for_dialogue( cur_dial_id = dial_id dialogue_state = {} # new_state_text = 'null' + knowledge = defaultdict(dict) new_knowledge_text = 'null' new_actions_text = 'null' active_api = None @@ -285,7 +286,10 @@ def generate_with_seq2seq_model_for_dialogue( #### elif train_target == 'api': - new_knowledge_text = 'null' + if 'HKMTR' not in active_api: + new_knowledge_text = "null" + knowledge = defaultdict(dict) + do_api_call = predictions[-1][0].strip() if do_api_call == 'yes': @@ -293,11 +297,12 @@ def generate_with_seq2seq_model_for_dialogue( api_name = active_api if api_name in dialogue_state: constraints, new_knowledge_text = dataset.make_api_call( - dialogue_state, api_name, numericalizer._tokenizer.src_lang, dial_id, turn_id + dialogue_state, knowledge, api_name, numericalizer._tokenizer.src_lang, dial_id, turn_id ) #### save latest api constraints bitod_preds[dial_id]["API"][dataset.domain2api_name(api_name)] = copy.deepcopy(constraints) #### + elif do_api_call == 
'no': # do nothing pass From 10a924c7ef757c6926a6caf744486452a7d49401 Mon Sep 17 00:00:00 2001 From: mehrad Date: Thu, 24 Feb 2022 11:27:23 -0800 Subject: [PATCH 086/102] Add option to evaluate on prediction file instead of a model --- genienlp/predict.py | 109 ++++++++++++++++++++++++++++++-------------- genienlp/util.py | 2 +- 2 files changed, 77 insertions(+), 34 deletions(-) diff --git a/genienlp/predict.py b/genienlp/predict.py index c185f864..009a77cf 100644 --- a/genienlp/predict.py +++ b/genienlp/predict.py @@ -43,6 +43,8 @@ except RuntimeError: pass +import sys + import torch from . import models @@ -61,17 +63,17 @@ set_seed, split_folder_on_disk, ) -from .validate import generate_with_model +from .validate import GenerationOutput, generate_with_model logger = logging.getLogger(__name__) def parse_argv(parser): - parser.add_argument('--path', type=str, required=True, help='Folder to load the model from') + parser.add_argument('--path', type=str, required='--pred_file' not in sys.argv, help='Folder to load the model from') parser.add_argument( '--evaluate', type=str, - required=True, + required='--pred_file' not in sys.argv, choices=['train', 'valid', 'test'], help='Which dataset to do predictions for (train, dev or test)', ) @@ -104,6 +106,12 @@ def parse_argv(parser): parser.add_argument('--cache', default='.cache', type=str, help='where to save cached files') parser.add_argument('--subsample', default=20000000, type=int, help='subsample the eval/test datasets') + parser.add_argument( + '--pred_file', + type=str, + help='If provided, we just compute evaluation metrics on it and bypass model prediction. File should be in tsv format with id, pred, target columns', + ) + parser.add_argument( '--pred_languages', type=str, @@ -428,8 +436,6 @@ def run(args, device): log_model_size(logger, model, args.model) model.to(device) - decaScore = [] - task_scores = defaultdict(list) model.eval() eval_dir = os.path.join(args.eval_dir, args.evaluate) @@ -512,37 +518,13 @@ def run(args, device): prediction_file.write(line + '\n') if len(generation_output.answers) > 0: - metrics_to_compute = task.metrics - metrics_to_compute += args.extra_metrics - metrics_to_compute = [metric for metric in task.metrics if metric not in ['loss']] - if args.main_metric_only: - metrics_to_compute = [metrics_to_compute[0]] - metrics = calculate_and_reduce_metrics( - generation_output, - metrics_to_compute, - args, - tgt_lang, - ) - - with open(results_file_name, 'w' + ('' if args.overwrite else '+')) as results_file: - results_file.write(json.dumps(metrics) + '\n') - - if not args.silent: - for i, (c, p, a) in enumerate( - zip(generation_output.contexts, generation_output.predictions, generation_output.answers) - ): - log_string = f'\nContext {i+1}: {c}\nPrediction {i + 1} ({len(p)} outputs): {p}\nAnswer {i + 1}: {a}\n' - if args.calibrator_paths is not None: - log_string += f'Confidence {i+1} : ' - for score in generation_output.confidence_scores: - log_string += f'{score[i]:.3f}, ' - log_string += '\n' - logger.info(log_string) + task_scores = compute_metrics_on_file(prediction_file_name, results_file_name, task, args, tgt_lang) - logger.info(metrics) + log_final_results(args, task_scores) - task_scores[task].append((len(generation_output.answers), metrics[task.metrics[0]])) +def log_final_results(args, task_scores): + decaScore = [] for task in task_scores.keys(): decaScore.append( sum([length * score for length, score in task_scores[task]]) / sum([length for length, score in task_scores[task]]) @@ -556,6 +538,56 
@@ def run(args, device): logger.info(f'\nSummary: | {sum(decaScore)} | {" | ".join([str(x) for x in decaScore])} |\n') +def compute_metrics_on_file(task_scores, pred_file, results_file_name, task, args, tgt_lang): + generation_output = GenerationOutput() + ids, contexts, preds, targets, confidence_scores = [], [], [], [], [] + with open(pred_file) as fin: + for line in fin: + id_, pred, target, context, *conf_scores = line.strip('\n').split('\t') + ids.append(id_) + contexts.append(context) + preds.append(pred) + targets.append(target) + confidence_scores.append(conf_scores) + + generation_output.example_ids = ids + generation_output.contexts = contexts + generation_output.predictions = preds + generation_output.answers = targets + generation_output.confidence_scores = confidence_scores + + metrics_to_compute = task.metrics + metrics_to_compute += args.extra_metrics + metrics_to_compute = [metric for metric in task.metrics if metric not in ['loss']] + if args.main_metric_only: + metrics_to_compute = [metrics_to_compute[0]] + metrics = calculate_and_reduce_metrics( + generation_output, + metrics_to_compute, + args, + tgt_lang, + ) + + with open(results_file_name, 'w' + ('' if args.overwrite else '+')) as results_file: + results_file.write(json.dumps(metrics) + '\n') + + if not args.silent: + for i, (c, p, a) in enumerate( + zip(generation_output.contexts, generation_output.predictions, generation_output.answers) + ): + log_string = f'\nContext {i + 1}: {c}\nPrediction {i + 1} ({len(p)} outputs): {p}\nAnswer {i + 1}: {a}\n' + if args.calibrator_paths is not None: + log_string += f'Confidence {i + 1} : ' + for score in generation_output.confidence_scores: + log_string += f'{score[i]:.3f}, ' + log_string += '\n' + logger.info(log_string) + + logger.info(metrics) + + task_scores[task].append((len(generation_output.answers), metrics[task.metrics[0]])) + + def main(args): load_config_json(args) check_and_update_generation_args(args) @@ -585,6 +617,17 @@ def main(args): task.metrics = new_metrics + if args.pred_file and os.path.exists(args.pred_file): + task_scores = defaultdict(list) + eval_dir = os.path.join(args.eval_dir, args.evaluate) + os.makedirs(eval_dir, exist_ok=True) + tgt_lang = args.pred_tgt_languages[0] + for task in args.tasks: + results_file_name = os.path.join(eval_dir, task.name + '.results.json') + compute_metrics_on_file(task_scores, args.pred_file, results_file_name, task, args, tgt_lang) + log_final_results(args, task_scores) + return + if len(devices) > 1: logger.info(f'Independent multi-GPU generation on following devices: {devices}') all_processes = [] diff --git a/genienlp/util.py b/genienlp/util.py index 0bca6fbe..9f967b1f 100644 --- a/genienlp/util.py +++ b/genienlp/util.py @@ -232,7 +232,7 @@ def __repr__(self) -> str: ) -class GenerationOutput: +class GenerationOutput(object): """ Contains all the information that the generation function may need to output """ From ebf46f17c8fe98fe5b1d7d40edd734f8f82ad5da Mon Sep 17 00:00:00 2001 From: mehrad Date: Thu, 24 Feb 2022 11:48:33 -0800 Subject: [PATCH 087/102] Fix metrics --- genienlp/metrics.py | 19 ++++++++++++------- genienlp/predict.py | 24 ++++++++++++++++-------- 2 files changed, 28 insertions(+), 15 deletions(-) diff --git a/genienlp/metrics.py b/genienlp/metrics.py index 78e863d7..a7199574 100644 --- a/genienlp/metrics.py +++ b/genienlp/metrics.py @@ -516,28 +516,28 @@ def computeDialogue(greedy, answer): def computeBITOD(greedy, answer, tgt_lang, args, example_ids): num_examples = len(answer) - 
subtask_metrics_dict = defaultdict(tuple) + subtask_metrics_dict = OrderedDict() results = OrderedDict({'bitod_score': 0.0, 'JGA': 0.0, 'API_em': 0.0, 'DA_em': 0.0, 'BLEU': 0.0}) subtask2result_key = OrderedDict({'dst': 'JGA', 'api': 'API_em', 'da': 'DA_em', 'rg': 'BLEU'}) for k, subtask in enumerate(args.bitod_valid_subtasks): - preds, golds = [], [] + ids, preds, golds = [], [], [] for i in range(num_examples): id_ = example_ids[i] if id_.endswith(f'/{subtask}'): + ids.append(id_) preds.append(greedy[i]) golds.append(answer[i]) if golds: - metrics_to_compute = args.bitod_valid_submetrics[subtask] - sub_metrics, _ = compute_metrics(preds, golds, [metrics_to_compute], tgt_lang, args, example_ids) - subtask_metrics_dict[subtask] = (sub_metrics[metrics_to_compute], len(golds)) + metrics_to_compute = args.bitod_valid_submetrics[k] + sub_metrics, _ = compute_metrics(preds, golds, [metrics_to_compute], tgt_lang, args, ids) + subtask_metrics_dict[subtask] = (sub_metrics[metrics_to_compute], len(golds), args.bitod_valid_subweights[k]) # TODO how should we aggregate? weighted_num_examples = 0 - for subtask, (sub_result, num_ex) in subtask_metrics_dict.items(): - weight = args.bitod_valid_subweights[subtask] + for subtask, (sub_result, num_ex, weight) in subtask_metrics_dict.items(): result_key = subtask2result_key[subtask] results[result_key] += sub_result @@ -550,9 +550,11 @@ def computeBITOD(greedy, answer, tgt_lang, args, example_ids): def computeJGA(greedy, answer, example_ids): + file = open('out_dst', 'w') dataset = Bitod() hit = 0 cur_dial_id = None + assert len(example_ids) == len(greedy) == len(answer) for id_, g, a in zip(example_ids, greedy, answer): dial_id = id_.split('/')[1] if dial_id != cur_dial_id: @@ -572,6 +574,9 @@ def computeJGA(greedy, answer, example_ids): if answer_state == greedy_state: hit += 1 + file.write('TRUE' + '\n') + else: + file.write('FALSE' + '\n') return hit / len(greedy) * 100 diff --git a/genienlp/predict.py b/genienlp/predict.py index 009a77cf..f8767fcb 100644 --- a/genienlp/predict.py +++ b/genienlp/predict.py @@ -437,6 +437,7 @@ def run(args, device): model.to(device) model.eval() + task_scores = defaultdict(list) eval_dir = os.path.join(args.eval_dir, args.evaluate) os.makedirs(eval_dir, exist_ok=True) @@ -518,7 +519,15 @@ def run(args, device): prediction_file.write(line + '\n') if len(generation_output.answers) > 0: - task_scores = compute_metrics_on_file(prediction_file_name, results_file_name, task, args, tgt_lang) + compute_metrics_on_file( + task_scores, + prediction_file_name, + results_file_name, + task, + args, + tgt_lang, + confidence_scores=generation_output.confidence_scores, + ) log_final_results(args, task_scores) @@ -538,17 +547,16 @@ def log_final_results(args, task_scores): logger.info(f'\nSummary: | {sum(decaScore)} | {" | ".join([str(x) for x in decaScore])} |\n') -def compute_metrics_on_file(task_scores, pred_file, results_file_name, task, args, tgt_lang): +def compute_metrics_on_file(task_scores, pred_file, results_file_name, task, args, tgt_lang, confidence_scores=None): generation_output = GenerationOutput() - ids, contexts, preds, targets, confidence_scores = [], [], [], [], [] + ids, contexts, preds, targets = [], [], [], [] with open(pred_file) as fin: for line in fin: - id_, pred, target, context, *conf_scores = line.strip('\n').split('\t') + id_, *pred, target, context = line.strip('\n').split('\t') ids.append(id_) contexts.append(context) preds.append(pred) targets.append(target) - confidence_scores.append(conf_scores) 
generation_output.example_ids = ids generation_output.contexts = contexts @@ -598,9 +606,6 @@ def main(args): args.tasks = list(get_tasks(args.task_names, args).values()) logger.info(f'Arguments:\n{pformat(vars(args))}') - logger.info(f'Loading from {args.best_checkpoint}') - - devices = get_devices(args.devices) if args.override_valid_metrics: assert len(args.override_valid_metrics) == len(args.tasks) @@ -628,6 +633,9 @@ def main(args): log_final_results(args, task_scores) return + logger.info(f'Loading from {args.best_checkpoint}') + devices = get_devices(args.devices) + if len(devices) > 1: logger.info(f'Independent multi-GPU generation on following devices: {devices}') all_processes = [] From e43476a73f5b4003f3e8dd42f8bcf965221cfe57 Mon Sep 17 00:00:00 2001 From: mehrad Date: Thu, 24 Feb 2022 14:11:36 -0800 Subject: [PATCH 088/102] Remove unused code --- genienlp/arguments.py | 2 +- setup.py | 3 --- 2 files changed, 1 insertion(+), 4 deletions(-) diff --git a/genienlp/arguments.py b/genienlp/arguments.py index 41a9caee..a882d1bd 100644 --- a/genienlp/arguments.py +++ b/genienlp/arguments.py @@ -326,7 +326,7 @@ def parse_argv(parser): '--lr_schedule', type=str, default='transformer', - choices=['transformer', 'constant', 'linear', 'sgd', 'cosine', 'polynomial', 'multi_linear_cosine'], + choices=['transformer', 'constant', 'linear', 'sgd', 'cosine', 'polynomial'], help='The learning rate strategy. All of them can be used with or without warmup.', ) parser.add_argument( diff --git a/setup.py b/setup.py index be88a458..09888a97 100644 --- a/setup.py +++ b/setup.py @@ -54,14 +54,11 @@ 'requests~=2.22', 'num2words>=0.5.10', 'dateparser>=1.0.0', - 'dictdiffer~=0.9', - 'termcolor~=1.1', 'datasets==1.18.2', 'seqeval==1.2.2', 'transformers==4.15.0', 'sentence-transformers==2.1.0', 'pathos==0.2.8', - 'pytorch-lightning-spells==0.0.3', # for kf: 'kfserving>=0.5.0', 'ray==1.6.0', From 51244a3c40dd96d5f6feee941840999839609665 Mon Sep 17 00:00:00 2001 From: mehrad Date: Thu, 24 Feb 2022 14:33:17 -0800 Subject: [PATCH 089/102] .pre-commit-config.yaml: version update --- .pre-commit-config.yaml | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 09326ab1..227e5571 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,6 +1,6 @@ repos: - repo: https://github.com/pre-commit/pre-commit-hooks - rev: v4.0.1 + rev: v4.1.0 hooks: - id: check-yaml - id: end-of-file-fixer @@ -8,16 +8,16 @@ repos: - id: trailing-whitespace exclude: ^(tests/dataset/|tests/database/|tests/expected_results/) - repo: https://github.com/hadialqattan/pycln - rev: v1.0.3 + rev: v1.2.1 hooks: - id: pycln args: [--config=pyproject.toml] - repo: https://github.com/PyCQA/isort - rev: 5.9.3 + rev: 5.10.1 hooks: - id: isort - repo: https://github.com/psf/black - rev: 21.9b0 + rev: 22.1.0 hooks: - id: black - repo: https://gitlab.com/pycqa/flake8 From 44d3943ba635af6b9c51410607d24e0064ecdff7 Mon Sep 17 00:00:00 2001 From: mehrad Date: Thu, 24 Feb 2022 15:54:54 -0800 Subject: [PATCH 090/102] metrics: clean debug code --- genienlp/metrics.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/genienlp/metrics.py b/genienlp/metrics.py index a7199574..c1b8dd2f 100644 --- a/genienlp/metrics.py +++ b/genienlp/metrics.py @@ -550,7 +550,6 @@ def computeBITOD(greedy, answer, tgt_lang, args, example_ids): def computeJGA(greedy, answer, example_ids): - file = open('out_dst', 'w') dataset = Bitod() hit = 0 cur_dial_id = None @@ -574,9 +573,6 @@ def 
computeJGA(greedy, answer, example_ids): if answer_state == greedy_state: hit += 1 - file.write('TRUE' + '\n') - else: - file.write('FALSE' + '\n') return hit / len(greedy) * 100 From 0df4642f9dcad794dce9694271de855d98b17fb0 Mon Sep 17 00:00:00 2001 From: mehrad Date: Thu, 24 Feb 2022 15:57:26 -0800 Subject: [PATCH 091/102] bitod --> e2e_dialogue --- genienlp/arguments.py | 28 ++++++++++++++++------------ genienlp/metrics.py | 18 +++++++++++------- genienlp/predict.py | 20 ++++++++++++-------- genienlp/tasks/generic_dataset.py | 2 +- genienlp/tasks/generic_task.py | 12 ++++++------ genienlp/util.py | 18 +++++++++--------- genienlp/validate.py | 20 ++++++++++---------- 7 files changed, 65 insertions(+), 53 deletions(-) diff --git a/genienlp/arguments.py b/genienlp/arguments.py index a882d1bd..6d855a9f 100644 --- a/genienlp/arguments.py +++ b/genienlp/arguments.py @@ -543,27 +543,31 @@ def parse_argv(parser): help='Debugging flag for hf datasets where validation will be performed on train set', ) - parser.add_argument('--bitod_e2e_evaluation', action='store_true', help='Evaluate model end-to-end') parser.add_argument( - '--bitod_valid_subtasks', + '--e2e_dialogue_evaluation', + action='store_true', + help='Evaluate model on a dialogue dataset end-to-end; i.e. model predictions are used as input instead of gold', + ) + parser.add_argument( + '--e2e_dialogue_valid_subtasks', nargs='+', type=str, default=['dst', 'api', 'da'], - help='Evaluate only on these subtasks when calculating bitod_score; rg is not included by default', + help='Evaluate only on these subtasks when calculating e2e_dialogue_score; rg is not included by default', ) parser.add_argument( - '--bitod_valid_submetrics', + '--e2e_dialogue_valid_submetrics', nargs='+', type=str, default=['jga', 'em', 'em'], - help='Specify metrics to use for each of subtasks in bitod_valid_subtasks.', + help='Specify metrics to use for each of subtasks in e2e_dialogue_valid_subtasks.', ) parser.add_argument( - '--bitod_valid_subweights', + '--e2e_dialogue_valid_subweights', nargs='+', type=float, default=[1.0, 1.0, 1.0], - help='Specify weights to use for each of subtasks in bitod_valid_subtasks.', + help='Specify weights to use for each of subtasks in e2e_dialogue_valid_subtasks.', ) @@ -654,18 +658,18 @@ def post_parse_general(args): def post_parse_train_specific(args): - if args.bitod_e2e_evaluation and args.val_batch_size[0] != 1: + if args.e2e_dialogue_evaluation and args.val_batch_size[0] != 1: logger.warning('When evaluating bitod end2end val_batch_size should be 1 so we load data turn by turn') args.val_batch_size = [1] - if len(args.bitod_valid_subtasks) != len(args.bitod_valid_submetrics): + if len(args.e2e_dialogue_valid_subtasks) != len(args.e2e_dialogue_valid_submetrics): raise ValueError( - 'Length of bitod_valid_subtasks and bitod_valid_submetrics arguments should be equal (i.e. one metric per subtask)' + 'Length of e2e_dialogue_valid_subtasks and e2e_dialogue_valid_submetrics arguments should be equal (i.e. one metric per subtask)' ) - if len(args.bitod_valid_subtasks) != len(args.bitod_valid_subweights): + if len(args.e2e_dialogue_valid_subtasks) != len(args.e2e_dialogue_valid_subweights): raise ValueError( - 'Length of bitod_valid_subtasks and bitod_valid_subweights arguments should be equal (i.e. one weight per subtask)' + 'Length of e2e_dialogue_valid_subtasks and e2e_dialogue_valid_subweights arguments should be equal (i.e. 
one weight per subtask)' ) if len(args.val_batch_size) < len(args.val_task_names): diff --git a/genienlp/metrics.py b/genienlp/metrics.py index c1b8dd2f..e8cb2eeb 100644 --- a/genienlp/metrics.py +++ b/genienlp/metrics.py @@ -518,10 +518,10 @@ def computeBITOD(greedy, answer, tgt_lang, args, example_ids): num_examples = len(answer) subtask_metrics_dict = OrderedDict() - results = OrderedDict({'bitod_score': 0.0, 'JGA': 0.0, 'API_em': 0.0, 'DA_em': 0.0, 'BLEU': 0.0}) + results = OrderedDict({'e2e_dialogue_score': 0.0, 'JGA': 0.0, 'API_em': 0.0, 'DA_em': 0.0, 'BLEU': 0.0}) subtask2result_key = OrderedDict({'dst': 'JGA', 'api': 'API_em', 'da': 'DA_em', 'rg': 'BLEU'}) - for k, subtask in enumerate(args.bitod_valid_subtasks): + for k, subtask in enumerate(args.e2e_dialogue_valid_subtasks): ids, preds, golds = [], [], [] for i in range(num_examples): id_ = example_ids[i] @@ -531,9 +531,13 @@ def computeBITOD(greedy, answer, tgt_lang, args, example_ids): golds.append(answer[i]) if golds: - metrics_to_compute = args.bitod_valid_submetrics[k] + metrics_to_compute = args.e2e_dialogue_valid_submetrics[k] sub_metrics, _ = compute_metrics(preds, golds, [metrics_to_compute], tgt_lang, args, ids) - subtask_metrics_dict[subtask] = (sub_metrics[metrics_to_compute], len(golds), args.bitod_valid_subweights[k]) + subtask_metrics_dict[subtask] = ( + sub_metrics[metrics_to_compute], + len(golds), + args.e2e_dialogue_valid_subweights[k], + ) # TODO how should we aggregate? weighted_num_examples = 0 @@ -541,10 +545,10 @@ def computeBITOD(greedy, answer, tgt_lang, args, example_ids): result_key = subtask2result_key[subtask] results[result_key] += sub_result - results['bitod_score'] += weight * (sub_result * num_ex) + results['e2e_dialogue_score'] += weight * (sub_result * num_ex) weighted_num_examples += weight * num_ex - results['bitod_score'] /= weighted_num_examples + results['e2e_dialogue_score'] /= weighted_num_examples return results @@ -596,7 +600,7 @@ def compute_metrics(greedy, answer, requested_metrics, lang, args, example_ids=N metric_values = [] if not isinstance(answer[0], list): answer = [[a] for a in answer] - if 'bitod_score' in requested_metrics: + if 'e2e_dialogue_score' in requested_metrics: requested_metrics += ['JGA', 'API_em', 'DA_em', 'BLEU'] results = computeBITOD(greedy, answer, lang, args, example_ids) metric_keys += results.keys() diff --git a/genienlp/predict.py b/genienlp/predict.py index f8767fcb..15302c4e 100644 --- a/genienlp/predict.py +++ b/genienlp/predict.py @@ -254,24 +254,28 @@ def parse_argv(parser): help='do not preserve quotation marks in the output. Useful if using alignment for semantic parsing or NLG', ) - parser.add_argument('--bitod_e2e_evaluation', action='store_true', help='Evaluate model end-to-end') parser.add_argument( - '--bitod_valid_subtasks', + '--e2e_dialogue_evaluation', + action='store_true', + help='Evaluate model on a dialogue dataset end-to-end; i.e. 
model predictions are used as input instead of gold', + ) + parser.add_argument( + '--e2e_dialogue_valid_subtasks', nargs='+', type=str, - help='Evaluate only on these subtasks when calculating bitod_score; rg is not included by default', + help='Evaluate only on these subtasks when calculating e2e_dialogue_score; rg is not included by default', ) parser.add_argument( - '--bitod_valid_submetrics', + '--e2e_dialogue_valid_submetrics', nargs='+', type=str, - help='Specify metrics to use for each of subtasks in bitod_valid_subtasks.', + help='Specify metrics to use for each of subtasks in e2e_dialogue_valid_subtasks.', ) parser.add_argument( - '--bitod_valid_subweights', + '--e2e_dialogue_valid_subweights', nargs='+', type=float, - help='Specify weights to use for each of subtasks in bitod_valid_subtasks.', + help='Specify weights to use for each of subtasks in e2e_dialogue_valid_subtasks.', ) @@ -282,7 +286,7 @@ def set_default_values(args): if args.confidence_feature_path is None: args.confidence_feature_path = os.path.join(args.path, 'confidence_features.pkl') - if args.bitod_e2e_evaluation and args.val_batch_size[0] != 1: + if args.e2e_dialogue_evaluation and args.val_batch_size[0] != 1: logger.warning('When evaluating bitod end2end val_batch_size should be 1 so we load data turn by turn') args.val_batch_size = [1] diff --git a/genienlp/tasks/generic_dataset.py b/genienlp/tasks/generic_dataset.py index 0adb3784..de40eeda 100644 --- a/genienlp/tasks/generic_dataset.py +++ b/genienlp/tasks/generic_dataset.py @@ -1988,7 +1988,7 @@ def __init__(self, path, *, make_example, **kwargs): super().__init__(examples, **kwargs) - # do not sort eval/ test set so we can compute individual scores for each subtask (bitod_score) + # do not sort eval/ test set so we can compute individual scores for each subtask (e2e_dialogue_score) self.eval_sort_key_fn = None # in e2e evaluation use 1 batch at a time diff --git a/genienlp/tasks/generic_task.py b/genienlp/tasks/generic_task.py index 386a7496..73a9a916 100644 --- a/genienlp/tasks/generic_task.py +++ b/genienlp/tasks/generic_task.py @@ -438,7 +438,7 @@ def __init__(self, name, args): | special_tokens_v13 | special_tokens_v2_10 ) - self._metrics = ['bitod_score'] + self._metrics = ['e2e_dialogue_score'] def utterance_field(self): return 'context' @@ -462,11 +462,11 @@ def _make_example(self, turn, **kwargs): ) def get_splits(self, root, **kwargs): - kwargs['e2e_evaluation'] = self.args.bitod_e2e_evaluation + kwargs['e2e_evaluation'] = self.args.e2e_dialogue_evaluation return BiTODDataset.return_splits(path=root, make_example=self._make_example, **kwargs) -@register_task('bitod_nlg') +@register_task('e2e_dialogue_nlg') class BiTODNLG(BiTOD): def __init__(self, name, args): super().__init__(name, args) @@ -474,11 +474,11 @@ def __init__(self, name, args): def get_splits(self, root, **kwargs): kwargs['train_target'] = 'rg' - kwargs['e2e_evaluation'] = self.args.bitod_e2e_evaluation + kwargs['e2e_evaluation'] = self.args.e2e_dialogue_evaluation return BiTODDataset.return_splits(path=root, make_example=self._make_example, **kwargs) -@register_task('bitod_dst') +@register_task('e2e_dialogue_dst') class BiTODDST(BiTOD): def __init__(self, name, args): super().__init__(name, args) @@ -486,5 +486,5 @@ def __init__(self, name, args): def get_splits(self, root, **kwargs): kwargs['train_target'] = 'dst' - kwargs['e2e_evaluation'] = self.args.bitod_e2e_evaluation + kwargs['e2e_evaluation'] = self.args.e2e_dialogue_evaluation return BiTODDataset.return_splits(path=root, 
make_example=self._make_example, **kwargs) diff --git a/genienlp/util.py b/genienlp/util.py index 9f967b1f..bf22b893 100644 --- a/genienlp/util.py +++ b/genienlp/util.py @@ -885,8 +885,8 @@ def load_config_json(args): 'min_output_length', 'reduce_metrics', 'database_dir', - 'bitod_valid_subtasks', - 'bitod_valid_submetrics' 'bitod_valid_subweights', + 'e2e_dialogue_valid_subtasks', + 'e2e_dialogue_valid_submetrics' 'e2e_dialogue_valid_subweights', ] for o in overwrite: if o not in args or getattr(args, o) is None: @@ -897,7 +897,7 @@ def load_config_json(args): 'do_alignment', 'align_preserve_input_quotation', 'align_remove_output_quotation', - 'bitod_e2e_evaluation', + 'e2e_dialogue_evaluation', ] for o in overwrite_actions: # if argument is True in predict overwrite train; if False retrieve from train @@ -963,12 +963,12 @@ def load_config_json(args): # use default value setattr(args, r, None) - if args.bitod_valid_subtasks is None: - setattr(args, 'bitod_valid_subtasks', ['dst', 'api', 'da']) - if args.bitod_valid_submetrics is None: - setattr(args, 'bitod_valid_submetrics', ['jga', 'em', 'em']) - if args.bitod_valid_subweights is None: - setattr(args, 'bitod_valid_subweights', [1.0, 1.0, 1.0]) + if args.e2e_dialogue_valid_subtasks is None: + setattr(args, 'e2e_dialogue_valid_subtasks', ['dst', 'api', 'da']) + if args.e2e_dialogue_valid_submetrics is None: + setattr(args, 'e2e_dialogue_valid_submetrics', ['jga', 'em', 'em']) + if args.e2e_dialogue_valid_subweights is None: + setattr(args, 'e2e_dialogue_valid_subweights', [1.0, 1.0, 1.0]) # backward compatibility for models trained with genienlp before NED Refactoring (2) if args.max_features_size is None: diff --git a/genienlp/validate.py b/genienlp/validate.py index 55a1a9ad..8a5feb9c 100644 --- a/genienlp/validate.py +++ b/genienlp/validate.py @@ -62,7 +62,7 @@ def generate_with_model( disable_progbar=True, eval_dir=None, ): - if args.bitod_e2e_evaluation: + if args.e2e_dialogue_evaluation: return generate_with_seq2seq_model_for_dialogue( model, data_iterator, @@ -129,7 +129,7 @@ def generate_with_seq2seq_model_for_dialogue( """ dataset = Bitod() - bitod_preds = dict() + e2e_dialogue_preds = dict() predictions = [] example_ids = [] @@ -166,7 +166,7 @@ def generate_with_seq2seq_model_for_dialogue( new_knowledge_text = 'null' new_actions_text = 'null' active_api = None - bitod_preds[dial_id] = {"turns": defaultdict(dict), "API": defaultdict(dict)} + e2e_dialogue_preds[dial_id] = {"turns": defaultdict(dict), "API": defaultdict(dict)} batch_context = [] batch_tokens = numericalizer.convert_ids_to_tokens(turn.context.value.data, skip_special_tokens=False) @@ -282,7 +282,7 @@ def generate_with_seq2seq_model_for_dialogue( #### save latest state state_to_record = copy.deepcopy(dialogue_state) state_to_record = {dataset.domain2api_name(k): v for k, v in state_to_record.items()} - bitod_preds[dial_id]["turns"][str(turn_id)]["state"] = state_to_record + e2e_dialogue_preds[dial_id]["turns"][str(turn_id)]["state"] = state_to_record #### elif train_target == 'api': @@ -300,7 +300,7 @@ def generate_with_seq2seq_model_for_dialogue( dialogue_state, knowledge, api_name, numericalizer._tokenizer.src_lang, dial_id, turn_id ) #### save latest api constraints - bitod_preds[dial_id]["API"][dataset.domain2api_name(api_name)] = copy.deepcopy(constraints) + e2e_dialogue_preds[dial_id]["API"][dataset.domain2api_name(api_name)] = copy.deepcopy(constraints) #### elif do_api_call == 'no': @@ -312,22 +312,22 @@ def generate_with_seq2seq_model_for_dialogue( ) #### 
save latest api results - bitod_preds[dial_id]["turns"][str(turn_id)]["api"] = new_knowledge_text + e2e_dialogue_preds[dial_id]["turns"][str(turn_id)]["api"] = new_knowledge_text #### elif train_target == 'da': new_actions_text = predictions[-1][0] #### save latest actions - bitod_preds[dial_id]["turns"][str(turn_id)]["actions"] = predictions[-1][0] + e2e_dialogue_preds[dial_id]["turns"][str(turn_id)]["actions"] = predictions[-1][0] #### elif train_target == 'rg': #### save latest response - bitod_preds[dial_id]["turns"][str(turn_id)]["response"] = predictions[-1] + e2e_dialogue_preds[dial_id]["turns"][str(turn_id)]["response"] = predictions[-1] #### - with open(os.path.join(eval_dir, 'bitod_preds.json'), 'w') as fout: - ujson.dump(bitod_preds, fout, indent=2, ensure_ascii=False) + with open(os.path.join(eval_dir, 'e2e_dialogue_preds.json'), 'w') as fout: + ujson.dump(e2e_dialogue_preds, fout, indent=2, ensure_ascii=False) if original_order is not None: # sort back to the original order From ac0c1e7495960f24995180ef75d1ba75e9583725 Mon Sep 17 00:00:00 2001 From: mehrad Date: Thu, 24 Feb 2022 17:03:27 -0800 Subject: [PATCH 092/102] Abstract away bitod specific code --- genienlp/validate.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/genienlp/validate.py b/genienlp/validate.py index 8a5feb9c..d8b75650 100644 --- a/genienlp/validate.py +++ b/genienlp/validate.py @@ -286,7 +286,7 @@ def generate_with_seq2seq_model_for_dialogue( #### elif train_target == 'api': - if 'HKMTR' not in active_api: + if dataset.do_knowledge_reset(active_api): new_knowledge_text = "null" knowledge = defaultdict(dict) From a2cd97b065cde96602efe686413a5929beea02f6 Mon Sep 17 00:00:00 2001 From: mehrad Date: Thu, 24 Feb 2022 18:08:41 -0800 Subject: [PATCH 093/102] util: fix syntax --- genienlp/util.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/genienlp/util.py b/genienlp/util.py index bf22b893..c4609a74 100644 --- a/genienlp/util.py +++ b/genienlp/util.py @@ -886,7 +886,8 @@ def load_config_json(args): 'reduce_metrics', 'database_dir', 'e2e_dialogue_valid_subtasks', - 'e2e_dialogue_valid_submetrics' 'e2e_dialogue_valid_subweights', + 'e2e_dialogue_valid_submetrics', + 'e2e_dialogue_valid_subweights', ] for o in overwrite: if o not in args or getattr(args, o) is None: From 93de0965e22fe74257aadad561a7eae2a8bbf26e Mon Sep 17 00:00:00 2001 From: mehrad Date: Thu, 24 Feb 2022 18:09:22 -0800 Subject: [PATCH 094/102] transform_dataset: shift one more... 
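
Since predict.py now appends the input context as the last column of prediction
files (see "predict: output contexts too"), the gold answer is the second-to-last
column rather than the last one, so the query slice must drop one more trailing
column. A minimal, hypothetical sketch of the column layout this change assumes
(the id and values below are illustrative, not real data):

    # columns: example id, prediction(s)..., gold answer, input context
    line = 'dialogue/abc/1/dst\tpred_a\tpred_b\tgold answer\tinput context'
    queries = line.split('\t')[1:-2]  # 0 is example id, -1 is input, -2 is gold answer
    assert queries == ['pred_a', 'pred_b']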
--- genienlp/paraphrase/scripts/transform_dataset.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/genienlp/paraphrase/scripts/transform_dataset.py b/genienlp/paraphrase/scripts/transform_dataset.py index 9b2df661..2dd53ef6 100644 --- a/genienlp/paraphrase/scripts/transform_dataset.py +++ b/genienlp/paraphrase/scripts/transform_dataset.py @@ -144,7 +144,7 @@ def main(args): new_queries = [] # list of lists query_file = open(args.query_file, 'r') for line in query_file: - queries = line.split('\t')[1:-1] # 0 is example id, -1 is gold answer + queries = line.split('\t')[1:-2] # 0 is example id, -1 is input, -2 is gold answer new_queries.append([lower_case(tokenize(q.strip())) for q in queries]) if args.transformation in ['remove_wrong_thingtalk', 'get_wrong_thingtalk']: gold_thingtalks = [] From eaca15923e0e368f7cfb0ff613efdba99a49fd2e Mon Sep 17 00:00:00 2001 From: mehrad Date: Fri, 25 Feb 2022 09:20:03 -0800 Subject: [PATCH 095/102] do not sort dataset if sort_key_fn is None --- genienlp/data_utils/iterator.py | 2 +- genienlp/util.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/genienlp/data_utils/iterator.py b/genienlp/data_utils/iterator.py index 566d94c6..444e9e72 100644 --- a/genienlp/data_utils/iterator.py +++ b/genienlp/data_utils/iterator.py @@ -57,7 +57,7 @@ def __init__(self, data_source, batch_size, sort, shuffle_and_repeat, sort_key_f self.batch_size_fn = batch_size_fn self.groups = groups - if sort and self.sort_key: + if sort: # sort while keeping track of the original order data_with_original_order = list( zip(data_source, range(len(data_source))) diff --git a/genienlp/util.py b/genienlp/util.py index c4609a74..bb789738 100644 --- a/genienlp/util.py +++ b/genienlp/util.py @@ -635,7 +635,7 @@ def make_data_loader(dataset, numericalizer, batch_size, device=None, train=Fals sampler = LengthSortedIterator( all_features, batch_size=batch_size, - sort=True, + sort=bool(sort_key_fn), shuffle_and_repeat=train, sort_key_fn=sort_key_fn, batch_size_fn=batch_size_fn, From 49d336c0f367f7a8a0b8abbe320d2e5e8d0074c1 Mon Sep 17 00:00:00 2001 From: mehrad Date: Fri, 25 Feb 2022 09:40:40 -0800 Subject: [PATCH 096/102] Fixed bitod_nlg and bitod_dst task names --- genienlp/tasks/generic_task.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/genienlp/tasks/generic_task.py b/genienlp/tasks/generic_task.py index 73a9a916..3418d6a0 100644 --- a/genienlp/tasks/generic_task.py +++ b/genienlp/tasks/generic_task.py @@ -466,7 +466,7 @@ def get_splits(self, root, **kwargs): return BiTODDataset.return_splits(path=root, make_example=self._make_example, **kwargs) -@register_task('e2e_dialogue_nlg') +@register_task('bitod_nlg') class BiTODNLG(BiTOD): def __init__(self, name, args): super().__init__(name, args) @@ -478,7 +478,7 @@ def get_splits(self, root, **kwargs): return BiTODDataset.return_splits(path=root, make_example=self._make_example, **kwargs) -@register_task('e2e_dialogue_dst') +@register_task('bitod_dst') class BiTODDST(BiTOD): def __init__(self, name, args): super().__init__(name, args) From e11cd8e340a38e6c62fb81ef6e234d4c847f4217 Mon Sep 17 00:00:00 2001 From: mehrad Date: Fri, 25 Feb 2022 09:41:29 -0800 Subject: [PATCH 097/102] predict: improve logging message --- genienlp/predict.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/genienlp/predict.py b/genienlp/predict.py index 15302c4e..6d46dae2 100644 --- a/genienlp/predict.py +++ b/genienlp/predict.py @@ -109,7 +109,7 @@ def parse_argv(parser): 
parser.add_argument( '--pred_file', type=str, - help='If provided, we just compute evaluation metrics on it and bypass model prediction. File should be in tsv format with id, pred, target columns', + help='If provided, we just compute evaluation metrics on this file and bypass model prediction. File should be in tsv format with id, pred, answer columns', ) parser.add_argument( @@ -287,7 +287,7 @@ def set_default_values(args): args.confidence_feature_path = os.path.join(args.path, 'confidence_features.pkl') if args.e2e_dialogue_evaluation and args.val_batch_size[0] != 1: - logger.warning('When evaluating bitod end2end val_batch_size should be 1 so we load data turn by turn') + logger.warning('When evaluating bitod end-to-end, val_batch_size should be 1 so we load the data turn by turn') args.val_batch_size = [1] From e3bdd24a602eff3894cf714f78e301c9fdf0d8c8 Mon Sep 17 00:00:00 2001 From: mehrad Date: Fri, 25 Feb 2022 09:46:49 -0800 Subject: [PATCH 098/102] metrics: renaming --- genienlp/metrics.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/genienlp/metrics.py b/genienlp/metrics.py index e8cb2eeb..a197ed48 100644 --- a/genienlp/metrics.py +++ b/genienlp/metrics.py @@ -514,7 +514,7 @@ def computeDialogue(greedy, answer): return joint_goal_em, turn_request_em, turn_goal_em, answer -def computeBITOD(greedy, answer, tgt_lang, args, example_ids): +def compute_e2e_dialogue_score(greedy, answer, tgt_lang, args, example_ids): num_examples = len(answer) subtask_metrics_dict = OrderedDict() @@ -602,13 +602,13 @@ def compute_metrics(greedy, answer, requested_metrics, lang, args, example_ids=N answer = [[a] for a in answer] if 'e2e_dialogue_score' in requested_metrics: requested_metrics += ['JGA', 'API_em', 'DA_em', 'BLEU'] - results = computeBITOD(greedy, answer, lang, args, example_ids) + results = compute_e2e_dialogue_score(greedy, answer, lang, args, example_ids) metric_keys += results.keys() metric_values += results.values() if 'jga' in requested_metrics: - JGA = computeJGA(greedy, answer, example_ids) + jga = computeJGA(greedy, answer, example_ids) metric_keys += ['jga'] - metric_values += [JGA] + metric_values += [jga] if 'lfem' in requested_metrics: lfem, answer = computeLFEM(greedy, answer) metric_keys += ['lfem'] From 86044a8d9ba0dd433bafb2a559907be4e3635751 Mon Sep 17 00:00:00 2001 From: mehrad Date: Fri, 25 Feb 2022 10:52:48 -0800 Subject: [PATCH 099/102] Add tests for bitod tasks --- .travis.yml | 5 + tests/dataset/bitod/test.json | 256 +++++++++++++++++++++++++ tests/dataset/bitod/train.json | 256 +++++++++++++++++++++++++ tests/dataset/bitod/valid.json | 256 +++++++++++++++++++++++++ tests/expected_results/bitod/bitod.tsv | 28 +++ tests/test_e2e_dialogues.sh | 46 +++++ tests/test_main_almond.sh | 2 +- 7 files changed, 848 insertions(+), 1 deletion(-) create mode 100644 tests/dataset/bitod/test.json create mode 100644 tests/dataset/bitod/train.json create mode 100644 tests/dataset/bitod/valid.json create mode 100644 tests/expected_results/bitod/bitod.tsv create mode 100755 tests/test_e2e_dialogues.sh diff --git a/.travis.yml b/.travis.yml index cbe7031e..9a6d55f7 100644 --- a/.travis.yml +++ b/.travis.yml @@ -40,6 +40,11 @@ jobs: stage: test script: - bash ./tests/test_translation.sh + - + name: "E2E Dialogues tests" + stage: test + script: + - bash ./tests/test_e2e_dialogues.sh - name: "NED tests" stage: test diff --git a/tests/dataset/bitod/test.json b/tests/dataset/bitod/test.json new file mode 100644 index 00000000..ad9227b2 --- /dev/null +++ 
b/tests/dataset/bitod/test.json @@ -0,0 +1,256 @@ +{ + "data": [ + { + "dial_id": "ec909032-d5fd-4dee-8dd2-53a895a98824", + "task": "hotels search", + "turn_id": 1, + "dialog_history": "USER: Hey, can you help me with hotel booking?", + "input_text": "DST: null USER: Hey, can you help me with hotel booking? ", + "output_text": "( hotels search )", + "train_target": "dst" + }, + { + "dial_id": "ec909032-d5fd-4dee-8dd2-53a895a98824", + "task": "hotels search", + "turn_id": 1, + "dialog_history": "USER: Hey, can you help me with hotel booking?", + "input_text": "API: null ( hotels search ) USER: Hey, can you help me with hotel booking? ", + "output_text": "no", + "train_target": "api" + }, + { + "dial_id": "ec909032-d5fd-4dee-8dd2-53a895a98824", + "task": "hotels search", + "turn_id": 1, + "dialog_history": "USER: Hey, can you help me with hotel booking?", + "input_text": "DA: null ( hotels search ) USER: Hey, can you help me with hotel booking? ", + "output_text": "( hotels search ) request location , request rating", + "train_target": "da" + }, + { + "dial_id": "ec909032-d5fd-4dee-8dd2-53a895a98824", + "task": "hotels search", + "turn_id": 1, + "dialog_history": "USER: Hey, can you help me with hotel booking?", + "input_text": "RG: ( hotels search ) request location , request rating USER: Hey, can you help me with hotel booking? ", + "output_text": "Sure! Do you have a destination in mind, and did you have a certain hotel rating in mind?", + "train_target": "rg" + }, + { + "dial_id": "ec909032-d5fd-4dee-8dd2-53a895a98824", + "task": "hotels search", + "turn_id": 2, + "dialog_history": "AGENT_ACTS: ( hotels search ) request location , request rating USER: No specific destination, but I do want a hotel that's at least 4 stars.", + "input_text": "DST: ( hotels search ) AGENT_ACTS: ( hotels search ) request location , request rating USER: No specific destination, but I do want a hotel that's at least 4 stars. ", + "output_text": "( hotels search ) location equal_to \" don't care \" , rating at_least \" 4 \"", + "train_target": "dst" + }, + { + "dial_id": "ec909032-d5fd-4dee-8dd2-53a895a98824", + "task": "hotels search", + "turn_id": 2, + "dialog_history": "AGENT_ACTS: ( hotels search ) request location , request rating USER: No specific destination, but I do want a hotel that's at least 4 stars.", + "input_text": "API: null ( hotels search ) location equal_to \" don't care \" , rating at_least \" 4 \" AGENT_ACTS: ( hotels search ) request location , request rating USER: No specific destination, but I do want a hotel that's at least 4 stars. ", + "output_text": "no", + "train_target": "api" + }, + { + "dial_id": "ec909032-d5fd-4dee-8dd2-53a895a98824", + "task": "hotels search", + "turn_id": 2, + "dialog_history": "AGENT_ACTS: ( hotels search ) request location , request rating USER: No specific destination, but I do want a hotel that's at least 4 stars.", + "input_text": "DA: null ( hotels search ) location equal_to \" don't care \" , rating at_least \" 4 \" AGENT_DA: ( hotels search ) request location , request rating USER: No specific destination, but I do want a hotel that's at least 4 stars. 
", + "output_text": "( hotels search ) request stars", + "train_target": "da" + }, + { + "dial_id": "ec909032-d5fd-4dee-8dd2-53a895a98824", + "task": "hotels search", + "turn_id": 2, + "dialog_history": "AGENT_ACTS: ( hotels search ) request location , request rating USER: No specific destination, but I do want a hotel that's at least 4 stars.", + "input_text": "RG: ( hotels search ) request stars AGENT_ACTS: ( hotels search ) request location , request rating USER: No specific destination, but I do want a hotel that's at least 4 stars. ", + "output_text": "How many stars would be your ideal?", + "train_target": "rg" + }, + { + "dial_id": "ec909032-d5fd-4dee-8dd2-53a895a98824", + "task": "hotels search", + "turn_id": 3, + "dialog_history": "AGENT_ACTS_PREV: ( hotels search ) request location , request rating AGENT_ACTS: ( hotels search ) request stars USER: 5 stars.", + "input_text": "DST: ( hotels search ) location equal_to \" don't care \" , rating at_least \" 4 \" AGENT_ACTS_PREV: ( hotels search ) request location , request rating AGENT_ACTS: ( hotels search ) request stars USER: 5 stars. ", + "output_text": "( hotels search ) stars equal_to \" 5 \"", + "train_target": "dst" + }, + { + "dial_id": "ec909032-d5fd-4dee-8dd2-53a895a98824", + "task": "hotels search", + "turn_id": 3, + "dialog_history": "AGENT_ACTS_PREV: ( hotels search ) request location , request rating AGENT_ACTS: ( hotels search ) request stars USER: 5 stars.", + "input_text": "API: null ( hotels search ) location equal_to \" don't care \" , rating at_least \" 4 \" , stars equal_to \" 5 \" AGENT_ACTS_PREV: ( hotels search ) request location , request rating AGENT_ACTS: ( hotels search ) request stars USER: 5 stars. ", + "output_text": "no", + "train_target": "api" + }, + { + "dial_id": "ec909032-d5fd-4dee-8dd2-53a895a98824", + "task": "hotels search", + "turn_id": 3, + "dialog_history": "AGENT_ACTS_PREV: ( hotels search ) request location , request rating AGENT_ACTS: ( hotels search ) request stars USER: 5 stars.", + "input_text": "DA: null ( hotels search ) location equal_to \" don't care \" , rating at_least \" 4 \" , stars equal_to \" 5 \" AGENT_ACTS_PREV: ( hotels search ) request location , request rating AGENT_DA: ( hotels search ) request stars USER: 5 stars. ", + "output_text": "( hotels search ) request price_level", + "train_target": "da" + }, + { + "dial_id": "ec909032-d5fd-4dee-8dd2-53a895a98824", + "task": "hotels search", + "turn_id": 3, + "dialog_history": "AGENT_ACTS_PREV: ( hotels search ) request location , request rating AGENT_ACTS: ( hotels search ) request stars USER: 5 stars.", + "input_text": "RG: ( hotels search ) request price_level AGENT_ACTS_PREV: ( hotels search ) request location , request rating AGENT_ACTS: ( hotels search ) request stars USER: 5 stars. ", + "output_text": "What's your budget?", + "train_target": "rg" + }, + { + "dial_id": "ec909032-d5fd-4dee-8dd2-53a895a98824", + "task": "hotels search", + "turn_id": 4, + "dialog_history": "AGENT_ACTS_PREV: ( hotels search ) request stars AGENT_ACTS: ( hotels search ) request price_level USER: I'm looking for a moderate room price.", + "input_text": "DST: ( hotels search ) location equal_to \" don't care \" , rating at_least \" 4 \" , stars equal_to \" 5 \" AGENT_ACTS_PREV: ( hotels search ) request stars AGENT_ACTS: ( hotels search ) request price_level USER: I'm looking for a moderate room price. 
", + "output_text": "( hotels search ) price_level equal_to \" moderate \"", + "train_target": "dst" + }, + { + "dial_id": "ec909032-d5fd-4dee-8dd2-53a895a98824", + "task": "hotels search", + "turn_id": 4, + "dialog_history": "AGENT_ACTS_PREV: ( hotels search ) request stars AGENT_ACTS: ( hotels search ) request price_level USER: I'm looking for a moderate room price.", + "input_text": "API: null ( hotels search ) location equal_to \" don't care \" , price_level equal_to \" moderate \" , rating at_least \" 4 \" , stars equal_to \" 5 \" AGENT_ACTS_PREV: ( hotels search ) request stars AGENT_ACTS: ( hotels search ) request price_level USER: I'm looking for a moderate room price. ", + "output_text": "yes", + "train_target": "api" + }, + { + "dial_id": "ec909032-d5fd-4dee-8dd2-53a895a98824", + "task": "hotels search", + "turn_id": 4, + "dialog_history": "AGENT_ACTS_PREV: ( hotels search ) request stars AGENT_ACTS: ( hotels search ) request price_level USER: I'm looking for a moderate room price.", + "input_text": "DA: ( hotels search ) available_options \" 26 \" , location \" Mong Kok | Kowloon | Yau Tsim Mong District \" , name \" Cordis, Hong Kong \" , price_level \" moderate \" , price_per_night \" 839 HKD \" , rating \" 10 \" , stars \" 5 \" ( hotels search ) location equal_to \" don't care \" , price_level equal_to \" moderate \" , rating at_least \" 4 \" , stars equal_to \" 5 \" AGENT_ACTS_PREV: ( hotels search ) request stars AGENT_DA: ( hotels search ) request price_level USER: I'm looking for a moderate room price. ", + "output_text": "( hotels search ) offer available_options equal_to \" 26 \" , offer name equal_to \" Cordis, Hong Kong \" , offer rating equal_to \" 10 \"", + "train_target": "da" + }, + { + "dial_id": "ec909032-d5fd-4dee-8dd2-53a895a98824", + "task": "hotels search", + "turn_id": 4, + "dialog_history": "AGENT_ACTS_PREV: ( hotels search ) request stars AGENT_ACTS: ( hotels search ) request price_level USER: I'm looking for a moderate room price.", + "input_text": "RG: ( hotels search ) offer available_options equal_to \" 26 \" , offer name equal_to \" Cordis, Hong Kong \" , offer rating equal_to \" 10 \" AGENT_ACTS_PREV: ( hotels search ) request stars AGENT_ACTS: ( hotels search ) request price_level USER: I'm looking for a moderate room price. ", + "output_text": "We have 26 hotels with availability here, and my recommendation would be Cordis, Hong Kong which has a 10 star rating.", + "train_target": "rg" + }, + { + "dial_id": "ec909032-d5fd-4dee-8dd2-53a895a98824", + "task": "hotels booking", + "turn_id": 5, + "dialog_history": "AGENT_ACTS_PREV: ( hotels search ) request price_level AGENT_ACTS: ( hotels search ) offer available_options equal_to \" 26 \" , offer name equal_to \" Cordis, Hong Kong \" , offer rating equal_to \" 10 \" USER: I'll take it. I need to check in on November 23rd.", + "input_text": "DST: ( hotels search ) location equal_to \" don't care \" , price_level equal_to \" moderate \" , rating at_least \" 4 \" , stars equal_to \" 5 \" AGENT_ACTS_PREV: ( hotels search ) request price_level AGENT_ACTS: ( hotels search ) offer available_options equal_to \" 26 \" , offer name equal_to \" Cordis, Hong Kong \" , offer rating equal_to \" 10 \" USER: I'll take it. I need to check in on November 23rd. 
", + "output_text": "( hotels booking ) name equal_to \" Cordis, Hong Kong \" , start_day equal_to \" 23 \" , start_month equal_to \" 11 \"", + "train_target": "dst" + }, + { + "dial_id": "ec909032-d5fd-4dee-8dd2-53a895a98824", + "task": "hotels booking", + "turn_id": 5, + "dialog_history": "AGENT_ACTS_PREV: ( hotels search ) request price_level AGENT_ACTS: ( hotels search ) offer available_options equal_to \" 26 \" , offer name equal_to \" Cordis, Hong Kong \" , offer rating equal_to \" 10 \" USER: I'll take it. I need to check in on November 23rd.", + "input_text": "API: null ( hotels booking ) name equal_to \" Cordis, Hong Kong \" , number_of_nights #unknown , number_of_rooms #unknown , start_day equal_to \" 23 \" , start_month equal_to \" 11 \" , user_name #unknown , ( hotels search ) location equal_to \" don't care \" , price_level equal_to \" moderate \" , rating at_least \" 4 \" , stars equal_to \" 5 \" AGENT_ACTS_PREV: ( hotels search ) request price_level AGENT_ACTS: ( hotels search ) offer available_options equal_to \" 26 \" , offer name equal_to \" Cordis, Hong Kong \" , offer rating equal_to \" 10 \" USER: I'll take it. I need to check in on November 23rd. ", + "output_text": "no", + "train_target": "api" + }, + { + "dial_id": "ec909032-d5fd-4dee-8dd2-53a895a98824", + "task": "hotels booking", + "turn_id": 5, + "dialog_history": "AGENT_ACTS_PREV: ( hotels search ) request price_level AGENT_ACTS: ( hotels search ) offer available_options equal_to \" 26 \" , offer name equal_to \" Cordis, Hong Kong \" , offer rating equal_to \" 10 \" USER: I'll take it. I need to check in on November 23rd.", + "input_text": "DA: null ( hotels booking ) name equal_to \" Cordis, Hong Kong \" , number_of_nights #unknown , number_of_rooms #unknown , start_day equal_to \" 23 \" , start_month equal_to \" 11 \" , user_name #unknown , ( hotels search ) location equal_to \" don't care \" , price_level equal_to \" moderate \" , rating at_least \" 4 \" , stars equal_to \" 5 \" AGENT_ACTS_PREV: ( hotels search ) request price_level AGENT_DA: ( hotels search ) offer available_options equal_to \" 26 \" , offer name equal_to \" Cordis, Hong Kong \" , offer rating equal_to \" 10 \" USER: I'll take it. I need to check in on November 23rd. ", + "output_text": "( hotels booking ) request number_of_nights", + "train_target": "da" + }, + { + "dial_id": "ec909032-d5fd-4dee-8dd2-53a895a98824", + "task": "hotels booking", + "turn_id": 5, + "dialog_history": "AGENT_ACTS_PREV: ( hotels search ) request price_level AGENT_ACTS: ( hotels search ) offer available_options equal_to \" 26 \" , offer name equal_to \" Cordis, Hong Kong \" , offer rating equal_to \" 10 \" USER: I'll take it. I need to check in on November 23rd.", + "input_text": "RG: ( hotels booking ) request number_of_nights AGENT_ACTS_PREV: ( hotels search ) request price_level AGENT_ACTS: ( hotels search ) offer available_options equal_to \" 26 \" , offer name equal_to \" Cordis, Hong Kong \" , offer rating equal_to \" 10 \" USER: I'll take it. I need to check in on November 23rd. ", + "output_text": "How many nights do you plan to stay?", + "train_target": "rg" + }, + { + "dial_id": "ec909032-d5fd-4dee-8dd2-53a895a98824", + "task": "hotels booking", + "turn_id": 6, + "dialog_history": "AGENT_ACTS_PREV: ( hotels search ) offer available_options equal_to \" 26 \" , offer name equal_to \" Cordis, Hong Kong \" , offer rating equal_to \" 10 \" AGENT_ACTS: ( hotels booking ) request number_of_nights USER: I need 9 nights booked. 
My name is David.", + "input_text": "DST: ( hotels booking ) name equal_to \" Cordis, Hong Kong \" , number_of_nights #unknown , number_of_rooms #unknown , start_day equal_to \" 23 \" , start_month equal_to \" 11 \" , user_name #unknown , ( hotels search ) location equal_to \" don't care \" , price_level equal_to \" moderate \" , rating at_least \" 4 \" , stars equal_to \" 5 \" AGENT_ACTS_PREV: ( hotels search ) offer available_options equal_to \" 26 \" , offer name equal_to \" Cordis, Hong Kong \" , offer rating equal_to \" 10 \" AGENT_ACTS: ( hotels booking ) request number_of_nights USER: I need 9 nights booked. My name is David. ", + "output_text": "( hotels booking ) number_of_nights equal_to \" 9 \" , user_name equal_to \" David \"", + "train_target": "dst" + }, + { + "dial_id": "ec909032-d5fd-4dee-8dd2-53a895a98824", + "task": "hotels booking", + "turn_id": 6, + "dialog_history": "AGENT_ACTS_PREV: ( hotels search ) offer available_options equal_to \" 26 \" , offer name equal_to \" Cordis, Hong Kong \" , offer rating equal_to \" 10 \" AGENT_ACTS: ( hotels booking ) request number_of_nights USER: I need 9 nights booked. My name is David.", + "input_text": "API: null ( hotels booking ) name equal_to \" Cordis, Hong Kong \" , number_of_nights equal_to \" 9 \" , number_of_rooms #unknown , start_day equal_to \" 23 \" , start_month equal_to \" 11 \" , user_name equal_to \" David \" , ( hotels search ) location equal_to \" don't care \" , price_level equal_to \" moderate \" , rating at_least \" 4 \" , stars equal_to \" 5 \" AGENT_ACTS_PREV: ( hotels search ) offer available_options equal_to \" 26 \" , offer name equal_to \" Cordis, Hong Kong \" , offer rating equal_to \" 10 \" AGENT_ACTS: ( hotels booking ) request number_of_nights USER: I need 9 nights booked. My name is David. ", + "output_text": "no", + "train_target": "api" + }, + { + "dial_id": "ec909032-d5fd-4dee-8dd2-53a895a98824", + "task": "hotels booking", + "turn_id": 6, + "dialog_history": "AGENT_ACTS_PREV: ( hotels search ) offer available_options equal_to \" 26 \" , offer name equal_to \" Cordis, Hong Kong \" , offer rating equal_to \" 10 \" AGENT_ACTS: ( hotels booking ) request number_of_nights USER: I need 9 nights booked. My name is David.", + "input_text": "DA: null ( hotels booking ) name equal_to \" Cordis, Hong Kong \" , number_of_nights equal_to \" 9 \" , number_of_rooms #unknown , start_day equal_to \" 23 \" , start_month equal_to \" 11 \" , user_name equal_to \" David \" , ( hotels search ) location equal_to \" don't care \" , price_level equal_to \" moderate \" , rating at_least \" 4 \" , stars equal_to \" 5 \" AGENT_ACTS_PREV: ( hotels search ) offer available_options equal_to \" 26 \" , offer name equal_to \" Cordis, Hong Kong \" , offer rating equal_to \" 10 \" AGENT_DA: ( hotels booking ) request number_of_nights USER: I need 9 nights booked. My name is David. ", + "output_text": "( hotels booking ) request number_of_rooms", + "train_target": "da" + }, + { + "dial_id": "ec909032-d5fd-4dee-8dd2-53a895a98824", + "task": "hotels booking", + "turn_id": 6, + "dialog_history": "AGENT_ACTS_PREV: ( hotels search ) offer available_options equal_to \" 26 \" , offer name equal_to \" Cordis, Hong Kong \" , offer rating equal_to \" 10 \" AGENT_ACTS: ( hotels booking ) request number_of_nights USER: I need 9 nights booked. 
My name is David.", + "input_text": "RG: ( hotels booking ) request number_of_rooms AGENT_ACTS_PREV: ( hotels search ) offer available_options equal_to \" 26 \" , offer name equal_to \" Cordis, Hong Kong \" , offer rating equal_to \" 10 \" AGENT_ACTS: ( hotels booking ) request number_of_nights USER: I need 9 nights booked. My name is David. ", + "output_text": "Of course, David. How many rooms?", + "train_target": "rg" + }, + { + "dial_id": "ec909032-d5fd-4dee-8dd2-53a895a98824", + "task": "hotels booking", + "turn_id": 7, + "dialog_history": "AGENT_ACTS_PREV: ( hotels booking ) request number_of_nights AGENT_ACTS: ( hotels booking ) request number_of_rooms USER: eight rooms please.", + "input_text": "DST: ( hotels booking ) name equal_to \" Cordis, Hong Kong \" , number_of_nights equal_to \" 9 \" , number_of_rooms #unknown , start_day equal_to \" 23 \" , start_month equal_to \" 11 \" , user_name equal_to \" David \" , ( hotels search ) location equal_to \" don't care \" , price_level equal_to \" moderate \" , rating at_least \" 4 \" , stars equal_to \" 5 \" AGENT_ACTS_PREV: ( hotels booking ) request number_of_nights AGENT_ACTS: ( hotels booking ) request number_of_rooms USER: eight rooms please. ", + "output_text": "( hotels booking ) number_of_rooms equal_to \" eight \"", + "train_target": "dst" + }, + { + "dial_id": "ec909032-d5fd-4dee-8dd2-53a895a98824", + "task": "hotels booking", + "turn_id": 7, + "dialog_history": "AGENT_ACTS_PREV: ( hotels booking ) request number_of_nights AGENT_ACTS: ( hotels booking ) request number_of_rooms USER: eight rooms please.", + "input_text": "API: null ( hotels booking ) name equal_to \" Cordis, Hong Kong \" , number_of_nights equal_to \" 9 \" , number_of_rooms equal_to \" eight \" , start_day equal_to \" 23 \" , start_month equal_to \" 11 \" , user_name equal_to \" David \" , ( hotels search ) location equal_to \" don't care \" , price_level equal_to \" moderate \" , rating at_least \" 4 \" , stars equal_to \" 5 \" AGENT_ACTS_PREV: ( hotels booking ) request number_of_nights AGENT_ACTS: ( hotels booking ) request number_of_rooms USER: eight rooms please. ", + "output_text": "no", + "train_target": "api" + }, + { + "dial_id": "ec909032-d5fd-4dee-8dd2-53a895a98824", + "task": "hotels booking", + "turn_id": 7, + "dialog_history": "AGENT_ACTS_PREV: ( hotels booking ) request number_of_nights AGENT_ACTS: ( hotels booking ) request number_of_rooms USER: eight rooms please.", + "input_text": "DA: null ( hotels booking ) name equal_to \" Cordis, Hong Kong \" , number_of_nights equal_to \" 9 \" , number_of_rooms equal_to \" eight \" , start_day equal_to \" 23 \" , start_month equal_to \" 11 \" , user_name equal_to \" David \" , ( hotels search ) location equal_to \" don't care \" , price_level equal_to \" moderate \" , rating at_least \" 4 \" , stars equal_to \" 5 \" AGENT_ACTS_PREV: ( hotels booking ) request number_of_nights AGENT_DA: ( hotels booking ) request number_of_rooms USER: eight rooms please. 
", + "output_text": "( hotels booking ) confirm name equal_to \" Cordis, Hong Kong \" , confirm number_of_nights equal_to \" 9 \" , confirm number_of_rooms equal_to \" eight \" , confirm start_date equal_to \" November 23 \"", + "train_target": "da" + }, + { + "dial_id": "ec909032-d5fd-4dee-8dd2-53a895a98824", + "task": "hotels booking", + "turn_id": 7, + "dialog_history": "AGENT_ACTS_PREV: ( hotels booking ) request number_of_nights AGENT_ACTS: ( hotels booking ) request number_of_rooms USER: eight rooms please.", + "input_text": "RG: ( hotels booking ) confirm name equal_to \" Cordis, Hong Kong \" , confirm number_of_nights equal_to \" 9 \" , confirm number_of_rooms equal_to \" eight \" , confirm start_date equal_to \" November 23 \" AGENT_ACTS_PREV: ( hotels booking ) request number_of_nights AGENT_ACTS: ( hotels booking ) request number_of_rooms USER: eight rooms please. ", + "output_text": "Okay, just to confirm: You would like to book eight rooms at the Cordis, Hong Kong hotel for 9 nights with a check-in date of November 23rd?", + "train_target": "rg" + } + ] +} \ No newline at end of file diff --git a/tests/dataset/bitod/train.json b/tests/dataset/bitod/train.json new file mode 100644 index 00000000..ad9227b2 --- /dev/null +++ b/tests/dataset/bitod/train.json @@ -0,0 +1,256 @@ +{ + "data": [ + { + "dial_id": "ec909032-d5fd-4dee-8dd2-53a895a98824", + "task": "hotels search", + "turn_id": 1, + "dialog_history": "USER: Hey, can you help me with hotel booking?", + "input_text": "DST: null USER: Hey, can you help me with hotel booking? ", + "output_text": "( hotels search )", + "train_target": "dst" + }, + { + "dial_id": "ec909032-d5fd-4dee-8dd2-53a895a98824", + "task": "hotels search", + "turn_id": 1, + "dialog_history": "USER: Hey, can you help me with hotel booking?", + "input_text": "API: null ( hotels search ) USER: Hey, can you help me with hotel booking? ", + "output_text": "no", + "train_target": "api" + }, + { + "dial_id": "ec909032-d5fd-4dee-8dd2-53a895a98824", + "task": "hotels search", + "turn_id": 1, + "dialog_history": "USER: Hey, can you help me with hotel booking?", + "input_text": "DA: null ( hotels search ) USER: Hey, can you help me with hotel booking? ", + "output_text": "( hotels search ) request location , request rating", + "train_target": "da" + }, + { + "dial_id": "ec909032-d5fd-4dee-8dd2-53a895a98824", + "task": "hotels search", + "turn_id": 1, + "dialog_history": "USER: Hey, can you help me with hotel booking?", + "input_text": "RG: ( hotels search ) request location , request rating USER: Hey, can you help me with hotel booking? ", + "output_text": "Sure! Do you have a destination in mind, and did you have a certain hotel rating in mind?", + "train_target": "rg" + }, + { + "dial_id": "ec909032-d5fd-4dee-8dd2-53a895a98824", + "task": "hotels search", + "turn_id": 2, + "dialog_history": "AGENT_ACTS: ( hotels search ) request location , request rating USER: No specific destination, but I do want a hotel that's at least 4 stars.", + "input_text": "DST: ( hotels search ) AGENT_ACTS: ( hotels search ) request location , request rating USER: No specific destination, but I do want a hotel that's at least 4 stars. 
", + "output_text": "( hotels search ) location equal_to \" don't care \" , rating at_least \" 4 \"", + "train_target": "dst" + }, + { + "dial_id": "ec909032-d5fd-4dee-8dd2-53a895a98824", + "task": "hotels search", + "turn_id": 2, + "dialog_history": "AGENT_ACTS: ( hotels search ) request location , request rating USER: No specific destination, but I do want a hotel that's at least 4 stars.", + "input_text": "API: null ( hotels search ) location equal_to \" don't care \" , rating at_least \" 4 \" AGENT_ACTS: ( hotels search ) request location , request rating USER: No specific destination, but I do want a hotel that's at least 4 stars. ", + "output_text": "no", + "train_target": "api" + }, + { + "dial_id": "ec909032-d5fd-4dee-8dd2-53a895a98824", + "task": "hotels search", + "turn_id": 2, + "dialog_history": "AGENT_ACTS: ( hotels search ) request location , request rating USER: No specific destination, but I do want a hotel that's at least 4 stars.", + "input_text": "DA: null ( hotels search ) location equal_to \" don't care \" , rating at_least \" 4 \" AGENT_DA: ( hotels search ) request location , request rating USER: No specific destination, but I do want a hotel that's at least 4 stars. ", + "output_text": "( hotels search ) request stars", + "train_target": "da" + }, + { + "dial_id": "ec909032-d5fd-4dee-8dd2-53a895a98824", + "task": "hotels search", + "turn_id": 2, + "dialog_history": "AGENT_ACTS: ( hotels search ) request location , request rating USER: No specific destination, but I do want a hotel that's at least 4 stars.", + "input_text": "RG: ( hotels search ) request stars AGENT_ACTS: ( hotels search ) request location , request rating USER: No specific destination, but I do want a hotel that's at least 4 stars. ", + "output_text": "How many stars would be your ideal?", + "train_target": "rg" + }, + { + "dial_id": "ec909032-d5fd-4dee-8dd2-53a895a98824", + "task": "hotels search", + "turn_id": 3, + "dialog_history": "AGENT_ACTS_PREV: ( hotels search ) request location , request rating AGENT_ACTS: ( hotels search ) request stars USER: 5 stars.", + "input_text": "DST: ( hotels search ) location equal_to \" don't care \" , rating at_least \" 4 \" AGENT_ACTS_PREV: ( hotels search ) request location , request rating AGENT_ACTS: ( hotels search ) request stars USER: 5 stars. ", + "output_text": "( hotels search ) stars equal_to \" 5 \"", + "train_target": "dst" + }, + { + "dial_id": "ec909032-d5fd-4dee-8dd2-53a895a98824", + "task": "hotels search", + "turn_id": 3, + "dialog_history": "AGENT_ACTS_PREV: ( hotels search ) request location , request rating AGENT_ACTS: ( hotels search ) request stars USER: 5 stars.", + "input_text": "API: null ( hotels search ) location equal_to \" don't care \" , rating at_least \" 4 \" , stars equal_to \" 5 \" AGENT_ACTS_PREV: ( hotels search ) request location , request rating AGENT_ACTS: ( hotels search ) request stars USER: 5 stars. ", + "output_text": "no", + "train_target": "api" + }, + { + "dial_id": "ec909032-d5fd-4dee-8dd2-53a895a98824", + "task": "hotels search", + "turn_id": 3, + "dialog_history": "AGENT_ACTS_PREV: ( hotels search ) request location , request rating AGENT_ACTS: ( hotels search ) request stars USER: 5 stars.", + "input_text": "DA: null ( hotels search ) location equal_to \" don't care \" , rating at_least \" 4 \" , stars equal_to \" 5 \" AGENT_ACTS_PREV: ( hotels search ) request location , request rating AGENT_DA: ( hotels search ) request stars USER: 5 stars. 
", + "output_text": "( hotels search ) request price_level", + "train_target": "da" + }, + { + "dial_id": "ec909032-d5fd-4dee-8dd2-53a895a98824", + "task": "hotels search", + "turn_id": 3, + "dialog_history": "AGENT_ACTS_PREV: ( hotels search ) request location , request rating AGENT_ACTS: ( hotels search ) request stars USER: 5 stars.", + "input_text": "RG: ( hotels search ) request price_level AGENT_ACTS_PREV: ( hotels search ) request location , request rating AGENT_ACTS: ( hotels search ) request stars USER: 5 stars. ", + "output_text": "What's your budget?", + "train_target": "rg" + }, + { + "dial_id": "ec909032-d5fd-4dee-8dd2-53a895a98824", + "task": "hotels search", + "turn_id": 4, + "dialog_history": "AGENT_ACTS_PREV: ( hotels search ) request stars AGENT_ACTS: ( hotels search ) request price_level USER: I'm looking for a moderate room price.", + "input_text": "DST: ( hotels search ) location equal_to \" don't care \" , rating at_least \" 4 \" , stars equal_to \" 5 \" AGENT_ACTS_PREV: ( hotels search ) request stars AGENT_ACTS: ( hotels search ) request price_level USER: I'm looking for a moderate room price. ", + "output_text": "( hotels search ) price_level equal_to \" moderate \"", + "train_target": "dst" + }, + { + "dial_id": "ec909032-d5fd-4dee-8dd2-53a895a98824", + "task": "hotels search", + "turn_id": 4, + "dialog_history": "AGENT_ACTS_PREV: ( hotels search ) request stars AGENT_ACTS: ( hotels search ) request price_level USER: I'm looking for a moderate room price.", + "input_text": "API: null ( hotels search ) location equal_to \" don't care \" , price_level equal_to \" moderate \" , rating at_least \" 4 \" , stars equal_to \" 5 \" AGENT_ACTS_PREV: ( hotels search ) request stars AGENT_ACTS: ( hotels search ) request price_level USER: I'm looking for a moderate room price. ", + "output_text": "yes", + "train_target": "api" + }, + { + "dial_id": "ec909032-d5fd-4dee-8dd2-53a895a98824", + "task": "hotels search", + "turn_id": 4, + "dialog_history": "AGENT_ACTS_PREV: ( hotels search ) request stars AGENT_ACTS: ( hotels search ) request price_level USER: I'm looking for a moderate room price.", + "input_text": "DA: ( hotels search ) available_options \" 26 \" , location \" Mong Kok | Kowloon | Yau Tsim Mong District \" , name \" Cordis, Hong Kong \" , price_level \" moderate \" , price_per_night \" 839 HKD \" , rating \" 10 \" , stars \" 5 \" ( hotels search ) location equal_to \" don't care \" , price_level equal_to \" moderate \" , rating at_least \" 4 \" , stars equal_to \" 5 \" AGENT_ACTS_PREV: ( hotels search ) request stars AGENT_DA: ( hotels search ) request price_level USER: I'm looking for a moderate room price. ", + "output_text": "( hotels search ) offer available_options equal_to \" 26 \" , offer name equal_to \" Cordis, Hong Kong \" , offer rating equal_to \" 10 \"", + "train_target": "da" + }, + { + "dial_id": "ec909032-d5fd-4dee-8dd2-53a895a98824", + "task": "hotels search", + "turn_id": 4, + "dialog_history": "AGENT_ACTS_PREV: ( hotels search ) request stars AGENT_ACTS: ( hotels search ) request price_level USER: I'm looking for a moderate room price.", + "input_text": "RG: ( hotels search ) offer available_options equal_to \" 26 \" , offer name equal_to \" Cordis, Hong Kong \" , offer rating equal_to \" 10 \" AGENT_ACTS_PREV: ( hotels search ) request stars AGENT_ACTS: ( hotels search ) request price_level USER: I'm looking for a moderate room price. 
", + "output_text": "We have 26 hotels with availability here, and my recommendation would be Cordis, Hong Kong which has a 10 star rating.", + "train_target": "rg" + }, + { + "dial_id": "ec909032-d5fd-4dee-8dd2-53a895a98824", + "task": "hotels booking", + "turn_id": 5, + "dialog_history": "AGENT_ACTS_PREV: ( hotels search ) request price_level AGENT_ACTS: ( hotels search ) offer available_options equal_to \" 26 \" , offer name equal_to \" Cordis, Hong Kong \" , offer rating equal_to \" 10 \" USER: I'll take it. I need to check in on November 23rd.", + "input_text": "DST: ( hotels search ) location equal_to \" don't care \" , price_level equal_to \" moderate \" , rating at_least \" 4 \" , stars equal_to \" 5 \" AGENT_ACTS_PREV: ( hotels search ) request price_level AGENT_ACTS: ( hotels search ) offer available_options equal_to \" 26 \" , offer name equal_to \" Cordis, Hong Kong \" , offer rating equal_to \" 10 \" USER: I'll take it. I need to check in on November 23rd. ", + "output_text": "( hotels booking ) name equal_to \" Cordis, Hong Kong \" , start_day equal_to \" 23 \" , start_month equal_to \" 11 \"", + "train_target": "dst" + }, + { + "dial_id": "ec909032-d5fd-4dee-8dd2-53a895a98824", + "task": "hotels booking", + "turn_id": 5, + "dialog_history": "AGENT_ACTS_PREV: ( hotels search ) request price_level AGENT_ACTS: ( hotels search ) offer available_options equal_to \" 26 \" , offer name equal_to \" Cordis, Hong Kong \" , offer rating equal_to \" 10 \" USER: I'll take it. I need to check in on November 23rd.", + "input_text": "API: null ( hotels booking ) name equal_to \" Cordis, Hong Kong \" , number_of_nights #unknown , number_of_rooms #unknown , start_day equal_to \" 23 \" , start_month equal_to \" 11 \" , user_name #unknown , ( hotels search ) location equal_to \" don't care \" , price_level equal_to \" moderate \" , rating at_least \" 4 \" , stars equal_to \" 5 \" AGENT_ACTS_PREV: ( hotels search ) request price_level AGENT_ACTS: ( hotels search ) offer available_options equal_to \" 26 \" , offer name equal_to \" Cordis, Hong Kong \" , offer rating equal_to \" 10 \" USER: I'll take it. I need to check in on November 23rd. ", + "output_text": "no", + "train_target": "api" + }, + { + "dial_id": "ec909032-d5fd-4dee-8dd2-53a895a98824", + "task": "hotels booking", + "turn_id": 5, + "dialog_history": "AGENT_ACTS_PREV: ( hotels search ) request price_level AGENT_ACTS: ( hotels search ) offer available_options equal_to \" 26 \" , offer name equal_to \" Cordis, Hong Kong \" , offer rating equal_to \" 10 \" USER: I'll take it. I need to check in on November 23rd.", + "input_text": "DA: null ( hotels booking ) name equal_to \" Cordis, Hong Kong \" , number_of_nights #unknown , number_of_rooms #unknown , start_day equal_to \" 23 \" , start_month equal_to \" 11 \" , user_name #unknown , ( hotels search ) location equal_to \" don't care \" , price_level equal_to \" moderate \" , rating at_least \" 4 \" , stars equal_to \" 5 \" AGENT_ACTS_PREV: ( hotels search ) request price_level AGENT_DA: ( hotels search ) offer available_options equal_to \" 26 \" , offer name equal_to \" Cordis, Hong Kong \" , offer rating equal_to \" 10 \" USER: I'll take it. I need to check in on November 23rd. 
", + "output_text": "( hotels booking ) request number_of_nights", + "train_target": "da" + }, + { + "dial_id": "ec909032-d5fd-4dee-8dd2-53a895a98824", + "task": "hotels booking", + "turn_id": 5, + "dialog_history": "AGENT_ACTS_PREV: ( hotels search ) request price_level AGENT_ACTS: ( hotels search ) offer available_options equal_to \" 26 \" , offer name equal_to \" Cordis, Hong Kong \" , offer rating equal_to \" 10 \" USER: I'll take it. I need to check in on November 23rd.", + "input_text": "RG: ( hotels booking ) request number_of_nights AGENT_ACTS_PREV: ( hotels search ) request price_level AGENT_ACTS: ( hotels search ) offer available_options equal_to \" 26 \" , offer name equal_to \" Cordis, Hong Kong \" , offer rating equal_to \" 10 \" USER: I'll take it. I need to check in on November 23rd. ", + "output_text": "How many nights do you plan to stay?", + "train_target": "rg" + }, + { + "dial_id": "ec909032-d5fd-4dee-8dd2-53a895a98824", + "task": "hotels booking", + "turn_id": 6, + "dialog_history": "AGENT_ACTS_PREV: ( hotels search ) offer available_options equal_to \" 26 \" , offer name equal_to \" Cordis, Hong Kong \" , offer rating equal_to \" 10 \" AGENT_ACTS: ( hotels booking ) request number_of_nights USER: I need 9 nights booked. My name is David.", + "input_text": "DST: ( hotels booking ) name equal_to \" Cordis, Hong Kong \" , number_of_nights #unknown , number_of_rooms #unknown , start_day equal_to \" 23 \" , start_month equal_to \" 11 \" , user_name #unknown , ( hotels search ) location equal_to \" don't care \" , price_level equal_to \" moderate \" , rating at_least \" 4 \" , stars equal_to \" 5 \" AGENT_ACTS_PREV: ( hotels search ) offer available_options equal_to \" 26 \" , offer name equal_to \" Cordis, Hong Kong \" , offer rating equal_to \" 10 \" AGENT_ACTS: ( hotels booking ) request number_of_nights USER: I need 9 nights booked. My name is David. ", + "output_text": "( hotels booking ) number_of_nights equal_to \" 9 \" , user_name equal_to \" David \"", + "train_target": "dst" + }, + { + "dial_id": "ec909032-d5fd-4dee-8dd2-53a895a98824", + "task": "hotels booking", + "turn_id": 6, + "dialog_history": "AGENT_ACTS_PREV: ( hotels search ) offer available_options equal_to \" 26 \" , offer name equal_to \" Cordis, Hong Kong \" , offer rating equal_to \" 10 \" AGENT_ACTS: ( hotels booking ) request number_of_nights USER: I need 9 nights booked. My name is David.", + "input_text": "API: null ( hotels booking ) name equal_to \" Cordis, Hong Kong \" , number_of_nights equal_to \" 9 \" , number_of_rooms #unknown , start_day equal_to \" 23 \" , start_month equal_to \" 11 \" , user_name equal_to \" David \" , ( hotels search ) location equal_to \" don't care \" , price_level equal_to \" moderate \" , rating at_least \" 4 \" , stars equal_to \" 5 \" AGENT_ACTS_PREV: ( hotels search ) offer available_options equal_to \" 26 \" , offer name equal_to \" Cordis, Hong Kong \" , offer rating equal_to \" 10 \" AGENT_ACTS: ( hotels booking ) request number_of_nights USER: I need 9 nights booked. My name is David. ", + "output_text": "no", + "train_target": "api" + }, + { + "dial_id": "ec909032-d5fd-4dee-8dd2-53a895a98824", + "task": "hotels booking", + "turn_id": 6, + "dialog_history": "AGENT_ACTS_PREV: ( hotels search ) offer available_options equal_to \" 26 \" , offer name equal_to \" Cordis, Hong Kong \" , offer rating equal_to \" 10 \" AGENT_ACTS: ( hotels booking ) request number_of_nights USER: I need 9 nights booked. 
My name is David.", + "input_text": "DA: null ( hotels booking ) name equal_to \" Cordis, Hong Kong \" , number_of_nights equal_to \" 9 \" , number_of_rooms #unknown , start_day equal_to \" 23 \" , start_month equal_to \" 11 \" , user_name equal_to \" David \" , ( hotels search ) location equal_to \" don't care \" , price_level equal_to \" moderate \" , rating at_least \" 4 \" , stars equal_to \" 5 \" AGENT_ACTS_PREV: ( hotels search ) offer available_options equal_to \" 26 \" , offer name equal_to \" Cordis, Hong Kong \" , offer rating equal_to \" 10 \" AGENT_DA: ( hotels booking ) request number_of_nights USER: I need 9 nights booked. My name is David. ", + "output_text": "( hotels booking ) request number_of_rooms", + "train_target": "da" + }, + { + "dial_id": "ec909032-d5fd-4dee-8dd2-53a895a98824", + "task": "hotels booking", + "turn_id": 6, + "dialog_history": "AGENT_ACTS_PREV: ( hotels search ) offer available_options equal_to \" 26 \" , offer name equal_to \" Cordis, Hong Kong \" , offer rating equal_to \" 10 \" AGENT_ACTS: ( hotels booking ) request number_of_nights USER: I need 9 nights booked. My name is David.", + "input_text": "RG: ( hotels booking ) request number_of_rooms AGENT_ACTS_PREV: ( hotels search ) offer available_options equal_to \" 26 \" , offer name equal_to \" Cordis, Hong Kong \" , offer rating equal_to \" 10 \" AGENT_ACTS: ( hotels booking ) request number_of_nights USER: I need 9 nights booked. My name is David. ", + "output_text": "Of course, David. How many rooms?", + "train_target": "rg" + }, + { + "dial_id": "ec909032-d5fd-4dee-8dd2-53a895a98824", + "task": "hotels booking", + "turn_id": 7, + "dialog_history": "AGENT_ACTS_PREV: ( hotels booking ) request number_of_nights AGENT_ACTS: ( hotels booking ) request number_of_rooms USER: eight rooms please.", + "input_text": "DST: ( hotels booking ) name equal_to \" Cordis, Hong Kong \" , number_of_nights equal_to \" 9 \" , number_of_rooms #unknown , start_day equal_to \" 23 \" , start_month equal_to \" 11 \" , user_name equal_to \" David \" , ( hotels search ) location equal_to \" don't care \" , price_level equal_to \" moderate \" , rating at_least \" 4 \" , stars equal_to \" 5 \" AGENT_ACTS_PREV: ( hotels booking ) request number_of_nights AGENT_ACTS: ( hotels booking ) request number_of_rooms USER: eight rooms please. ", + "output_text": "( hotels booking ) number_of_rooms equal_to \" eight \"", + "train_target": "dst" + }, + { + "dial_id": "ec909032-d5fd-4dee-8dd2-53a895a98824", + "task": "hotels booking", + "turn_id": 7, + "dialog_history": "AGENT_ACTS_PREV: ( hotels booking ) request number_of_nights AGENT_ACTS: ( hotels booking ) request number_of_rooms USER: eight rooms please.", + "input_text": "API: null ( hotels booking ) name equal_to \" Cordis, Hong Kong \" , number_of_nights equal_to \" 9 \" , number_of_rooms equal_to \" eight \" , start_day equal_to \" 23 \" , start_month equal_to \" 11 \" , user_name equal_to \" David \" , ( hotels search ) location equal_to \" don't care \" , price_level equal_to \" moderate \" , rating at_least \" 4 \" , stars equal_to \" 5 \" AGENT_ACTS_PREV: ( hotels booking ) request number_of_nights AGENT_ACTS: ( hotels booking ) request number_of_rooms USER: eight rooms please. 
", + "output_text": "no", + "train_target": "api" + }, + { + "dial_id": "ec909032-d5fd-4dee-8dd2-53a895a98824", + "task": "hotels booking", + "turn_id": 7, + "dialog_history": "AGENT_ACTS_PREV: ( hotels booking ) request number_of_nights AGENT_ACTS: ( hotels booking ) request number_of_rooms USER: eight rooms please.", + "input_text": "DA: null ( hotels booking ) name equal_to \" Cordis, Hong Kong \" , number_of_nights equal_to \" 9 \" , number_of_rooms equal_to \" eight \" , start_day equal_to \" 23 \" , start_month equal_to \" 11 \" , user_name equal_to \" David \" , ( hotels search ) location equal_to \" don't care \" , price_level equal_to \" moderate \" , rating at_least \" 4 \" , stars equal_to \" 5 \" AGENT_ACTS_PREV: ( hotels booking ) request number_of_nights AGENT_DA: ( hotels booking ) request number_of_rooms USER: eight rooms please. ", + "output_text": "( hotels booking ) confirm name equal_to \" Cordis, Hong Kong \" , confirm number_of_nights equal_to \" 9 \" , confirm number_of_rooms equal_to \" eight \" , confirm start_date equal_to \" November 23 \"", + "train_target": "da" + }, + { + "dial_id": "ec909032-d5fd-4dee-8dd2-53a895a98824", + "task": "hotels booking", + "turn_id": 7, + "dialog_history": "AGENT_ACTS_PREV: ( hotels booking ) request number_of_nights AGENT_ACTS: ( hotels booking ) request number_of_rooms USER: eight rooms please.", + "input_text": "RG: ( hotels booking ) confirm name equal_to \" Cordis, Hong Kong \" , confirm number_of_nights equal_to \" 9 \" , confirm number_of_rooms equal_to \" eight \" , confirm start_date equal_to \" November 23 \" AGENT_ACTS_PREV: ( hotels booking ) request number_of_nights AGENT_ACTS: ( hotels booking ) request number_of_rooms USER: eight rooms please. ", + "output_text": "Okay, just to confirm: You would like to book eight rooms at the Cordis, Hong Kong hotel for 9 nights with a check-in date of November 23rd?", + "train_target": "rg" + } + ] +} \ No newline at end of file diff --git a/tests/dataset/bitod/valid.json b/tests/dataset/bitod/valid.json new file mode 100644 index 00000000..ad9227b2 --- /dev/null +++ b/tests/dataset/bitod/valid.json @@ -0,0 +1,256 @@ +{ + "data": [ + { + "dial_id": "ec909032-d5fd-4dee-8dd2-53a895a98824", + "task": "hotels search", + "turn_id": 1, + "dialog_history": "USER: Hey, can you help me with hotel booking?", + "input_text": "DST: null USER: Hey, can you help me with hotel booking? ", + "output_text": "( hotels search )", + "train_target": "dst" + }, + { + "dial_id": "ec909032-d5fd-4dee-8dd2-53a895a98824", + "task": "hotels search", + "turn_id": 1, + "dialog_history": "USER: Hey, can you help me with hotel booking?", + "input_text": "API: null ( hotels search ) USER: Hey, can you help me with hotel booking? ", + "output_text": "no", + "train_target": "api" + }, + { + "dial_id": "ec909032-d5fd-4dee-8dd2-53a895a98824", + "task": "hotels search", + "turn_id": 1, + "dialog_history": "USER: Hey, can you help me with hotel booking?", + "input_text": "DA: null ( hotels search ) USER: Hey, can you help me with hotel booking? ", + "output_text": "( hotels search ) request location , request rating", + "train_target": "da" + }, + { + "dial_id": "ec909032-d5fd-4dee-8dd2-53a895a98824", + "task": "hotels search", + "turn_id": 1, + "dialog_history": "USER: Hey, can you help me with hotel booking?", + "input_text": "RG: ( hotels search ) request location , request rating USER: Hey, can you help me with hotel booking? ", + "output_text": "Sure! 
Do you have a destination in mind, and did you have a certain hotel rating in mind?", + "train_target": "rg" + }, + { + "dial_id": "ec909032-d5fd-4dee-8dd2-53a895a98824", + "task": "hotels search", + "turn_id": 2, + "dialog_history": "AGENT_ACTS: ( hotels search ) request location , request rating USER: No specific destination, but I do want a hotel that's at least 4 stars.", + "input_text": "DST: ( hotels search ) AGENT_ACTS: ( hotels search ) request location , request rating USER: No specific destination, but I do want a hotel that's at least 4 stars. ", + "output_text": "( hotels search ) location equal_to \" don't care \" , rating at_least \" 4 \"", + "train_target": "dst" + }, + { + "dial_id": "ec909032-d5fd-4dee-8dd2-53a895a98824", + "task": "hotels search", + "turn_id": 2, + "dialog_history": "AGENT_ACTS: ( hotels search ) request location , request rating USER: No specific destination, but I do want a hotel that's at least 4 stars.", + "input_text": "API: null ( hotels search ) location equal_to \" don't care \" , rating at_least \" 4 \" AGENT_ACTS: ( hotels search ) request location , request rating USER: No specific destination, but I do want a hotel that's at least 4 stars. ", + "output_text": "no", + "train_target": "api" + }, + { + "dial_id": "ec909032-d5fd-4dee-8dd2-53a895a98824", + "task": "hotels search", + "turn_id": 2, + "dialog_history": "AGENT_ACTS: ( hotels search ) request location , request rating USER: No specific destination, but I do want a hotel that's at least 4 stars.", + "input_text": "DA: null ( hotels search ) location equal_to \" don't care \" , rating at_least \" 4 \" AGENT_DA: ( hotels search ) request location , request rating USER: No specific destination, but I do want a hotel that's at least 4 stars. ", + "output_text": "( hotels search ) request stars", + "train_target": "da" + }, + { + "dial_id": "ec909032-d5fd-4dee-8dd2-53a895a98824", + "task": "hotels search", + "turn_id": 2, + "dialog_history": "AGENT_ACTS: ( hotels search ) request location , request rating USER: No specific destination, but I do want a hotel that's at least 4 stars.", + "input_text": "RG: ( hotels search ) request stars AGENT_ACTS: ( hotels search ) request location , request rating USER: No specific destination, but I do want a hotel that's at least 4 stars. ", + "output_text": "How many stars would be your ideal?", + "train_target": "rg" + }, + { + "dial_id": "ec909032-d5fd-4dee-8dd2-53a895a98824", + "task": "hotels search", + "turn_id": 3, + "dialog_history": "AGENT_ACTS_PREV: ( hotels search ) request location , request rating AGENT_ACTS: ( hotels search ) request stars USER: 5 stars.", + "input_text": "DST: ( hotels search ) location equal_to \" don't care \" , rating at_least \" 4 \" AGENT_ACTS_PREV: ( hotels search ) request location , request rating AGENT_ACTS: ( hotels search ) request stars USER: 5 stars. ", + "output_text": "( hotels search ) stars equal_to \" 5 \"", + "train_target": "dst" + }, + { + "dial_id": "ec909032-d5fd-4dee-8dd2-53a895a98824", + "task": "hotels search", + "turn_id": 3, + "dialog_history": "AGENT_ACTS_PREV: ( hotels search ) request location , request rating AGENT_ACTS: ( hotels search ) request stars USER: 5 stars.", + "input_text": "API: null ( hotels search ) location equal_to \" don't care \" , rating at_least \" 4 \" , stars equal_to \" 5 \" AGENT_ACTS_PREV: ( hotels search ) request location , request rating AGENT_ACTS: ( hotels search ) request stars USER: 5 stars. 
", + "output_text": "no", + "train_target": "api" + }, + { + "dial_id": "ec909032-d5fd-4dee-8dd2-53a895a98824", + "task": "hotels search", + "turn_id": 3, + "dialog_history": "AGENT_ACTS_PREV: ( hotels search ) request location , request rating AGENT_ACTS: ( hotels search ) request stars USER: 5 stars.", + "input_text": "DA: null ( hotels search ) location equal_to \" don't care \" , rating at_least \" 4 \" , stars equal_to \" 5 \" AGENT_ACTS_PREV: ( hotels search ) request location , request rating AGENT_DA: ( hotels search ) request stars USER: 5 stars. ", + "output_text": "( hotels search ) request price_level", + "train_target": "da" + }, + { + "dial_id": "ec909032-d5fd-4dee-8dd2-53a895a98824", + "task": "hotels search", + "turn_id": 3, + "dialog_history": "AGENT_ACTS_PREV: ( hotels search ) request location , request rating AGENT_ACTS: ( hotels search ) request stars USER: 5 stars.", + "input_text": "RG: ( hotels search ) request price_level AGENT_ACTS_PREV: ( hotels search ) request location , request rating AGENT_ACTS: ( hotels search ) request stars USER: 5 stars. ", + "output_text": "What's your budget?", + "train_target": "rg" + }, + { + "dial_id": "ec909032-d5fd-4dee-8dd2-53a895a98824", + "task": "hotels search", + "turn_id": 4, + "dialog_history": "AGENT_ACTS_PREV: ( hotels search ) request stars AGENT_ACTS: ( hotels search ) request price_level USER: I'm looking for a moderate room price.", + "input_text": "DST: ( hotels search ) location equal_to \" don't care \" , rating at_least \" 4 \" , stars equal_to \" 5 \" AGENT_ACTS_PREV: ( hotels search ) request stars AGENT_ACTS: ( hotels search ) request price_level USER: I'm looking for a moderate room price. ", + "output_text": "( hotels search ) price_level equal_to \" moderate \"", + "train_target": "dst" + }, + { + "dial_id": "ec909032-d5fd-4dee-8dd2-53a895a98824", + "task": "hotels search", + "turn_id": 4, + "dialog_history": "AGENT_ACTS_PREV: ( hotels search ) request stars AGENT_ACTS: ( hotels search ) request price_level USER: I'm looking for a moderate room price.", + "input_text": "API: null ( hotels search ) location equal_to \" don't care \" , price_level equal_to \" moderate \" , rating at_least \" 4 \" , stars equal_to \" 5 \" AGENT_ACTS_PREV: ( hotels search ) request stars AGENT_ACTS: ( hotels search ) request price_level USER: I'm looking for a moderate room price. ", + "output_text": "yes", + "train_target": "api" + }, + { + "dial_id": "ec909032-d5fd-4dee-8dd2-53a895a98824", + "task": "hotels search", + "turn_id": 4, + "dialog_history": "AGENT_ACTS_PREV: ( hotels search ) request stars AGENT_ACTS: ( hotels search ) request price_level USER: I'm looking for a moderate room price.", + "input_text": "DA: ( hotels search ) available_options \" 26 \" , location \" Mong Kok | Kowloon | Yau Tsim Mong District \" , name \" Cordis, Hong Kong \" , price_level \" moderate \" , price_per_night \" 839 HKD \" , rating \" 10 \" , stars \" 5 \" ( hotels search ) location equal_to \" don't care \" , price_level equal_to \" moderate \" , rating at_least \" 4 \" , stars equal_to \" 5 \" AGENT_ACTS_PREV: ( hotels search ) request stars AGENT_DA: ( hotels search ) request price_level USER: I'm looking for a moderate room price. 
", + "output_text": "( hotels search ) offer available_options equal_to \" 26 \" , offer name equal_to \" Cordis, Hong Kong \" , offer rating equal_to \" 10 \"", + "train_target": "da" + }, + { + "dial_id": "ec909032-d5fd-4dee-8dd2-53a895a98824", + "task": "hotels search", + "turn_id": 4, + "dialog_history": "AGENT_ACTS_PREV: ( hotels search ) request stars AGENT_ACTS: ( hotels search ) request price_level USER: I'm looking for a moderate room price.", + "input_text": "RG: ( hotels search ) offer available_options equal_to \" 26 \" , offer name equal_to \" Cordis, Hong Kong \" , offer rating equal_to \" 10 \" AGENT_ACTS_PREV: ( hotels search ) request stars AGENT_ACTS: ( hotels search ) request price_level USER: I'm looking for a moderate room price. ", + "output_text": "We have 26 hotels with availability here, and my recommendation would be Cordis, Hong Kong which has a 10 star rating.", + "train_target": "rg" + }, + { + "dial_id": "ec909032-d5fd-4dee-8dd2-53a895a98824", + "task": "hotels booking", + "turn_id": 5, + "dialog_history": "AGENT_ACTS_PREV: ( hotels search ) request price_level AGENT_ACTS: ( hotels search ) offer available_options equal_to \" 26 \" , offer name equal_to \" Cordis, Hong Kong \" , offer rating equal_to \" 10 \" USER: I'll take it. I need to check in on November 23rd.", + "input_text": "DST: ( hotels search ) location equal_to \" don't care \" , price_level equal_to \" moderate \" , rating at_least \" 4 \" , stars equal_to \" 5 \" AGENT_ACTS_PREV: ( hotels search ) request price_level AGENT_ACTS: ( hotels search ) offer available_options equal_to \" 26 \" , offer name equal_to \" Cordis, Hong Kong \" , offer rating equal_to \" 10 \" USER: I'll take it. I need to check in on November 23rd. ", + "output_text": "( hotels booking ) name equal_to \" Cordis, Hong Kong \" , start_day equal_to \" 23 \" , start_month equal_to \" 11 \"", + "train_target": "dst" + }, + { + "dial_id": "ec909032-d5fd-4dee-8dd2-53a895a98824", + "task": "hotels booking", + "turn_id": 5, + "dialog_history": "AGENT_ACTS_PREV: ( hotels search ) request price_level AGENT_ACTS: ( hotels search ) offer available_options equal_to \" 26 \" , offer name equal_to \" Cordis, Hong Kong \" , offer rating equal_to \" 10 \" USER: I'll take it. I need to check in on November 23rd.", + "input_text": "API: null ( hotels booking ) name equal_to \" Cordis, Hong Kong \" , number_of_nights #unknown , number_of_rooms #unknown , start_day equal_to \" 23 \" , start_month equal_to \" 11 \" , user_name #unknown , ( hotels search ) location equal_to \" don't care \" , price_level equal_to \" moderate \" , rating at_least \" 4 \" , stars equal_to \" 5 \" AGENT_ACTS_PREV: ( hotels search ) request price_level AGENT_ACTS: ( hotels search ) offer available_options equal_to \" 26 \" , offer name equal_to \" Cordis, Hong Kong \" , offer rating equal_to \" 10 \" USER: I'll take it. I need to check in on November 23rd. ", + "output_text": "no", + "train_target": "api" + }, + { + "dial_id": "ec909032-d5fd-4dee-8dd2-53a895a98824", + "task": "hotels booking", + "turn_id": 5, + "dialog_history": "AGENT_ACTS_PREV: ( hotels search ) request price_level AGENT_ACTS: ( hotels search ) offer available_options equal_to \" 26 \" , offer name equal_to \" Cordis, Hong Kong \" , offer rating equal_to \" 10 \" USER: I'll take it. 
I need to check in on November 23rd.", + "input_text": "DA: null ( hotels booking ) name equal_to \" Cordis, Hong Kong \" , number_of_nights #unknown , number_of_rooms #unknown , start_day equal_to \" 23 \" , start_month equal_to \" 11 \" , user_name #unknown , ( hotels search ) location equal_to \" don't care \" , price_level equal_to \" moderate \" , rating at_least \" 4 \" , stars equal_to \" 5 \" AGENT_ACTS_PREV: ( hotels search ) request price_level AGENT_DA: ( hotels search ) offer available_options equal_to \" 26 \" , offer name equal_to \" Cordis, Hong Kong \" , offer rating equal_to \" 10 \" USER: I'll take it. I need to check in on November 23rd. ", + "output_text": "( hotels booking ) request number_of_nights", + "train_target": "da" + }, + { + "dial_id": "ec909032-d5fd-4dee-8dd2-53a895a98824", + "task": "hotels booking", + "turn_id": 5, + "dialog_history": "AGENT_ACTS_PREV: ( hotels search ) request price_level AGENT_ACTS: ( hotels search ) offer available_options equal_to \" 26 \" , offer name equal_to \" Cordis, Hong Kong \" , offer rating equal_to \" 10 \" USER: I'll take it. I need to check in on November 23rd.", + "input_text": "RG: ( hotels booking ) request number_of_nights AGENT_ACTS_PREV: ( hotels search ) request price_level AGENT_ACTS: ( hotels search ) offer available_options equal_to \" 26 \" , offer name equal_to \" Cordis, Hong Kong \" , offer rating equal_to \" 10 \" USER: I'll take it. I need to check in on November 23rd. ", + "output_text": "How many nights do you plan to stay?", + "train_target": "rg" + }, + { + "dial_id": "ec909032-d5fd-4dee-8dd2-53a895a98824", + "task": "hotels booking", + "turn_id": 6, + "dialog_history": "AGENT_ACTS_PREV: ( hotels search ) offer available_options equal_to \" 26 \" , offer name equal_to \" Cordis, Hong Kong \" , offer rating equal_to \" 10 \" AGENT_ACTS: ( hotels booking ) request number_of_nights USER: I need 9 nights booked. My name is David.", + "input_text": "DST: ( hotels booking ) name equal_to \" Cordis, Hong Kong \" , number_of_nights #unknown , number_of_rooms #unknown , start_day equal_to \" 23 \" , start_month equal_to \" 11 \" , user_name #unknown , ( hotels search ) location equal_to \" don't care \" , price_level equal_to \" moderate \" , rating at_least \" 4 \" , stars equal_to \" 5 \" AGENT_ACTS_PREV: ( hotels search ) offer available_options equal_to \" 26 \" , offer name equal_to \" Cordis, Hong Kong \" , offer rating equal_to \" 10 \" AGENT_ACTS: ( hotels booking ) request number_of_nights USER: I need 9 nights booked. My name is David. ", + "output_text": "( hotels booking ) number_of_nights equal_to \" 9 \" , user_name equal_to \" David \"", + "train_target": "dst" + }, + { + "dial_id": "ec909032-d5fd-4dee-8dd2-53a895a98824", + "task": "hotels booking", + "turn_id": 6, + "dialog_history": "AGENT_ACTS_PREV: ( hotels search ) offer available_options equal_to \" 26 \" , offer name equal_to \" Cordis, Hong Kong \" , offer rating equal_to \" 10 \" AGENT_ACTS: ( hotels booking ) request number_of_nights USER: I need 9 nights booked. 
My name is David.", + "input_text": "API: null ( hotels booking ) name equal_to \" Cordis, Hong Kong \" , number_of_nights equal_to \" 9 \" , number_of_rooms #unknown , start_day equal_to \" 23 \" , start_month equal_to \" 11 \" , user_name equal_to \" David \" , ( hotels search ) location equal_to \" don't care \" , price_level equal_to \" moderate \" , rating at_least \" 4 \" , stars equal_to \" 5 \" AGENT_ACTS_PREV: ( hotels search ) offer available_options equal_to \" 26 \" , offer name equal_to \" Cordis, Hong Kong \" , offer rating equal_to \" 10 \" AGENT_ACTS: ( hotels booking ) request number_of_nights USER: I need 9 nights booked. My name is David. ", + "output_text": "no", + "train_target": "api" + }, + { + "dial_id": "ec909032-d5fd-4dee-8dd2-53a895a98824", + "task": "hotels booking", + "turn_id": 6, + "dialog_history": "AGENT_ACTS_PREV: ( hotels search ) offer available_options equal_to \" 26 \" , offer name equal_to \" Cordis, Hong Kong \" , offer rating equal_to \" 10 \" AGENT_ACTS: ( hotels booking ) request number_of_nights USER: I need 9 nights booked. My name is David.", + "input_text": "DA: null ( hotels booking ) name equal_to \" Cordis, Hong Kong \" , number_of_nights equal_to \" 9 \" , number_of_rooms #unknown , start_day equal_to \" 23 \" , start_month equal_to \" 11 \" , user_name equal_to \" David \" , ( hotels search ) location equal_to \" don't care \" , price_level equal_to \" moderate \" , rating at_least \" 4 \" , stars equal_to \" 5 \" AGENT_ACTS_PREV: ( hotels search ) offer available_options equal_to \" 26 \" , offer name equal_to \" Cordis, Hong Kong \" , offer rating equal_to \" 10 \" AGENT_DA: ( hotels booking ) request number_of_nights USER: I need 9 nights booked. My name is David. ", + "output_text": "( hotels booking ) request number_of_rooms", + "train_target": "da" + }, + { + "dial_id": "ec909032-d5fd-4dee-8dd2-53a895a98824", + "task": "hotels booking", + "turn_id": 6, + "dialog_history": "AGENT_ACTS_PREV: ( hotels search ) offer available_options equal_to \" 26 \" , offer name equal_to \" Cordis, Hong Kong \" , offer rating equal_to \" 10 \" AGENT_ACTS: ( hotels booking ) request number_of_nights USER: I need 9 nights booked. My name is David.", + "input_text": "RG: ( hotels booking ) request number_of_rooms AGENT_ACTS_PREV: ( hotels search ) offer available_options equal_to \" 26 \" , offer name equal_to \" Cordis, Hong Kong \" , offer rating equal_to \" 10 \" AGENT_ACTS: ( hotels booking ) request number_of_nights USER: I need 9 nights booked. My name is David. ", + "output_text": "Of course, David. How many rooms?", + "train_target": "rg" + }, + { + "dial_id": "ec909032-d5fd-4dee-8dd2-53a895a98824", + "task": "hotels booking", + "turn_id": 7, + "dialog_history": "AGENT_ACTS_PREV: ( hotels booking ) request number_of_nights AGENT_ACTS: ( hotels booking ) request number_of_rooms USER: eight rooms please.", + "input_text": "DST: ( hotels booking ) name equal_to \" Cordis, Hong Kong \" , number_of_nights equal_to \" 9 \" , number_of_rooms #unknown , start_day equal_to \" 23 \" , start_month equal_to \" 11 \" , user_name equal_to \" David \" , ( hotels search ) location equal_to \" don't care \" , price_level equal_to \" moderate \" , rating at_least \" 4 \" , stars equal_to \" 5 \" AGENT_ACTS_PREV: ( hotels booking ) request number_of_nights AGENT_ACTS: ( hotels booking ) request number_of_rooms USER: eight rooms please. 
", + "output_text": "( hotels booking ) number_of_rooms equal_to \" eight \"", + "train_target": "dst" + }, + { + "dial_id": "ec909032-d5fd-4dee-8dd2-53a895a98824", + "task": "hotels booking", + "turn_id": 7, + "dialog_history": "AGENT_ACTS_PREV: ( hotels booking ) request number_of_nights AGENT_ACTS: ( hotels booking ) request number_of_rooms USER: eight rooms please.", + "input_text": "API: null ( hotels booking ) name equal_to \" Cordis, Hong Kong \" , number_of_nights equal_to \" 9 \" , number_of_rooms equal_to \" eight \" , start_day equal_to \" 23 \" , start_month equal_to \" 11 \" , user_name equal_to \" David \" , ( hotels search ) location equal_to \" don't care \" , price_level equal_to \" moderate \" , rating at_least \" 4 \" , stars equal_to \" 5 \" AGENT_ACTS_PREV: ( hotels booking ) request number_of_nights AGENT_ACTS: ( hotels booking ) request number_of_rooms USER: eight rooms please. ", + "output_text": "no", + "train_target": "api" + }, + { + "dial_id": "ec909032-d5fd-4dee-8dd2-53a895a98824", + "task": "hotels booking", + "turn_id": 7, + "dialog_history": "AGENT_ACTS_PREV: ( hotels booking ) request number_of_nights AGENT_ACTS: ( hotels booking ) request number_of_rooms USER: eight rooms please.", + "input_text": "DA: null ( hotels booking ) name equal_to \" Cordis, Hong Kong \" , number_of_nights equal_to \" 9 \" , number_of_rooms equal_to \" eight \" , start_day equal_to \" 23 \" , start_month equal_to \" 11 \" , user_name equal_to \" David \" , ( hotels search ) location equal_to \" don't care \" , price_level equal_to \" moderate \" , rating at_least \" 4 \" , stars equal_to \" 5 \" AGENT_ACTS_PREV: ( hotels booking ) request number_of_nights AGENT_DA: ( hotels booking ) request number_of_rooms USER: eight rooms please. ", + "output_text": "( hotels booking ) confirm name equal_to \" Cordis, Hong Kong \" , confirm number_of_nights equal_to \" 9 \" , confirm number_of_rooms equal_to \" eight \" , confirm start_date equal_to \" November 23 \"", + "train_target": "da" + }, + { + "dial_id": "ec909032-d5fd-4dee-8dd2-53a895a98824", + "task": "hotels booking", + "turn_id": 7, + "dialog_history": "AGENT_ACTS_PREV: ( hotels booking ) request number_of_nights AGENT_ACTS: ( hotels booking ) request number_of_rooms USER: eight rooms please.", + "input_text": "RG: ( hotels booking ) confirm name equal_to \" Cordis, Hong Kong \" , confirm number_of_nights equal_to \" 9 \" , confirm number_of_rooms equal_to \" eight \" , confirm start_date equal_to \" November 23 \" AGENT_ACTS_PREV: ( hotels booking ) request number_of_nights AGENT_ACTS: ( hotels booking ) request number_of_rooms USER: eight rooms please. 
", + "output_text": "Okay, just to confirm: You would like to book eight rooms at the Cordis, Hong Kong hotel for 9 nights with a check-in date of November 23rd?", + "train_target": "rg" + } + ] +} \ No newline at end of file diff --git a/tests/expected_results/bitod/bitod.tsv b/tests/expected_results/bitod/bitod.tsv new file mode 100644 index 00000000..641d478a --- /dev/null +++ b/tests/expected_results/bitod/bitod.tsv @@ -0,0 +1,28 @@ +bitod/ec909032-d5fd-4dee-8dd2-53a895a98824/1/dst up Trend Trend Trend Trend Trend Trend Trend Trend Trend sponsorship sponsorship sponsorship sponsorship sponsorship sponsorship sponsorship sponsorship sponsorship sponsorship sponsorship sponsorship sponsorship sponsorship metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metroEarnEarnEarn ( hotels search ) DST: null USER: Hey, can you help me with hotel booking? +bitod/ec909032-d5fd-4dee-8dd2-53a895a98824/1/api up Trend Trend Trend Trend Trend Trend Trend Trend Trend sponsorship sponsorship sponsorship sponsorship sponsorship sponsorship sponsorship sponsorship sponsorship sponsorship sponsorship sponsorship sponsorship sponsorship metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metroEarnEarnEarn no API: null ( hotels search ) USER: Hey, can you help me with hotel booking? 
+bitod/ec909032-d5fd-4dee-8dd2-53a895a98824/1/da up Trend Trend Trend Trend Trend Trend Trend Trend Trend sponsorship sponsorship sponsorship sponsorship sponsorship sponsorship sponsorship sponsorship sponsorship sponsorship sponsorship sponsorship sponsorship sponsorship metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metroEarnEarnEarn ( hotels search ) request location , request rating DA: null ( hotels search ) USER: Hey, can you help me with hotel booking? +bitod/ec909032-d5fd-4dee-8dd2-53a895a98824/1/rg up Trend Trend Trend Trend Trend Trend Trend Trend Trend sponsorship sponsorship sponsorship sponsorship sponsorship sponsorship sponsorship sponsorship sponsorship sponsorship sponsorship sponsorship sponsorship sponsorship metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metroEarnEarnEarn Sure! Do you have a destination in mind, and did you have a certain hotel rating in mind? RG: ( hotels search ) request location , request rating USER: Hey, can you help me with hotel booking? 
+bitod/ec909032-d5fd-4dee-8dd2-53a895a98824/2/dst up Trend Trend Trend Trend Trend Trend Trend Trend Trend sponsorship sponsorship sponsorship sponsorship sponsorship sponsorship sponsorship sponsorship sponsorship sponsorship sponsorship sponsorship sponsorship sponsorship metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metroEarnEarnEarn ( hotels search ) location equal_to " don't care " , rating at_least " 4 " DST: ( hotels search ) AGENT_ACTS: ( hotels search ) request location , request rating USER: No specific destination, but I do want a hotel that's at least 4 stars. +bitod/ec909032-d5fd-4dee-8dd2-53a895a98824/2/api up Trend Trend Trend Trend Trend Trend Trend Trend Trend sponsorship sponsorship sponsorship sponsorship sponsorship sponsorship sponsorship sponsorship sponsorship sponsorship sponsorship sponsorship sponsorship sponsorship metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metroEarnEarnEarn no API: null ( hotels search ) location equal_to " don't care " , rating at_least " 4 " AGENT_ACTS: ( hotels search ) request location , request rating USER: No specific destination, but I do want a hotel that's at least 4 stars. 
+bitod/ec909032-d5fd-4dee-8dd2-53a895a98824/2/da up Trend Trend Trend Trend Trend Trend Trend Trend Trend sponsorship sponsorship sponsorship sponsorship sponsorship sponsorship sponsorship sponsorship sponsorship sponsorship sponsorship sponsorship sponsorship sponsorship metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metroEarnEarnEarn ( hotels search ) request stars DA: null ( hotels search ) location equal_to " don't care " , rating at_least " 4 " AGENT_DA: ( hotels search ) request location , request rating USER: No specific destination, but I do want a hotel that's at least 4 stars. +bitod/ec909032-d5fd-4dee-8dd2-53a895a98824/2/rg up Trend Trend Trend Trend Trend Trend Trend Trend Trend sponsorship sponsorship sponsorship sponsorship sponsorship sponsorship sponsorship sponsorship sponsorship sponsorship sponsorship sponsorship sponsorship sponsorship metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metroEarnEarnEarn How many stars would be your ideal? RG: ( hotels search ) request stars AGENT_ACTS: ( hotels search ) request location , request rating USER: No specific destination, but I do want a hotel that's at least 4 stars. 
+bitod/ec909032-d5fd-4dee-8dd2-53a895a98824/3/dst up Trend Trend Trend Trend Trend Trend Trend Trend Trend sponsorship sponsorship sponsorship sponsorship sponsorship sponsorship sponsorship sponsorship sponsorship sponsorship sponsorship sponsorship sponsorship sponsorship metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metroEarnEarnEarn ( hotels search ) stars equal_to " 5 " DST: ( hotels search ) location equal_to " don't care " , rating at_least " 4 " AGENT_ACTS_PREV: ( hotels search ) request location , request rating AGENT_ACTS: ( hotels search ) request stars USER: 5 stars. +bitod/ec909032-d5fd-4dee-8dd2-53a895a98824/3/api up Trend Trend Trend Trend Trend Trend Trend Trend Trend sponsorship sponsorship sponsorship sponsorship sponsorship sponsorship sponsorship sponsorship sponsorship sponsorship sponsorship sponsorship sponsorship sponsorship metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metroEarnEarnEarn no API: null ( hotels search ) location equal_to " don't care " , rating at_least " 4 " , stars equal_to " 5 " AGENT_ACTS_PREV: ( hotels search ) request location , request rating AGENT_ACTS: ( hotels search ) request stars USER: 5 stars. 
+bitod/ec909032-d5fd-4dee-8dd2-53a895a98824/3/da up Trend Trend Trend Trend Trend Trend Trend Trend Trend sponsorship sponsorship sponsorship sponsorship sponsorship sponsorship sponsorship sponsorship sponsorship sponsorship sponsorship sponsorship sponsorship sponsorship metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metroEarnEarnEarn ( hotels search ) request price_level DA: null ( hotels search ) location equal_to " don't care " , rating at_least " 4 " , stars equal_to " 5 " AGENT_ACTS_PREV: ( hotels search ) request location , request rating AGENT_DA: ( hotels search ) request stars USER: 5 stars. +bitod/ec909032-d5fd-4dee-8dd2-53a895a98824/3/rg up Trend Trend Trend Trend Trend Trend Trend Trend Trend sponsorship sponsorship sponsorship sponsorship sponsorship sponsorship sponsorship sponsorship sponsorship sponsorship sponsorship sponsorship sponsorship sponsorship metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metroEarnEarnEarn What's your budget? RG: ( hotels search ) request price_level AGENT_ACTS_PREV: ( hotels search ) request location , request rating AGENT_ACTS: ( hotels search ) request stars USER: 5 stars. 
+bitod/ec909032-d5fd-4dee-8dd2-53a895a98824/4/dst up Trend Trend Trend Trend Trend Trend Trend Trend Trend sponsorship sponsorship sponsorship sponsorship sponsorship sponsorship sponsorship sponsorship sponsorship sponsorship sponsorship sponsorship sponsorship sponsorship metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metroEarnEarnEarn ( hotels search ) price_level equal_to " moderate " DST: ( hotels search ) location equal_to " don't care " , rating at_least " 4 " , stars equal_to " 5 " AGENT_ACTS_PREV: ( hotels search ) request stars AGENT_ACTS: ( hotels search ) request price_level USER: I'm looking for a moderate room price. +bitod/ec909032-d5fd-4dee-8dd2-53a895a98824/4/api up Trend Trend Trend Trend Trend Trend Trend Trend Trend sponsorship sponsorship sponsorship sponsorship sponsorship sponsorship sponsorship sponsorship sponsorship sponsorship sponsorship sponsorship sponsorship sponsorship metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metroEarnEarnEarn yes API: null ( hotels search ) location equal_to " don't care " , price_level equal_to " moderate " , rating at_least " 4 " , stars equal_to " 5 " AGENT_ACTS_PREV: ( hotels search ) request stars AGENT_ACTS: ( hotels search ) request price_level USER: I'm looking for a moderate room price. 
+bitod/ec909032-d5fd-4dee-8dd2-53a895a98824/4/da up Trend Trend Trend Trend Trend Trend Trend Trend Trend sponsorship sponsorship sponsorship sponsorship sponsorship sponsorship sponsorship sponsorship sponsorship sponsorship sponsorship sponsorship sponsorship sponsorship metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metroEarnEarnEarn ( hotels search ) offer available_options equal_to " 26 " , offer name equal_to " Cordis, Hong Kong " , offer rating equal_to " 10 " DA: ( hotels search ) available_options " 26 " , location " Mong Kok | Kowloon | Yau Tsim Mong District " , name " Cordis, Hong Kong " , price_level " moderate " , price_per_night " 839 HKD " , rating " 10 " , stars " 5 " ( hotels search ) location equal_to " don't care " , price_level equal_to " moderate " , rating at_least " 4 " , stars equal_to " 5 " AGENT_ACTS_PREV: ( hotels search ) request stars AGENT_DA: ( hotels search ) request price_level USER: I'm looking for a moderate room price. +bitod/ec909032-d5fd-4dee-8dd2-53a895a98824/4/rg up Trend Trend Trend Trend Trend Trend Trend Trend Trend sponsorship sponsorship sponsorship sponsorship sponsorship sponsorship sponsorship sponsorship sponsorship sponsorship sponsorship sponsorship sponsorship sponsorship metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metroEarnEarnEarn We have 26 hotels with availability here, and my recommendation would be Cordis, Hong Kong which has a 10 star rating. RG: ( hotels search ) offer available_options equal_to " 26 " , offer name equal_to " Cordis, Hong Kong " , offer rating equal_to " 10 " AGENT_ACTS_PREV: ( hotels search ) request stars AGENT_ACTS: ( hotels search ) request price_level USER: I'm looking for a moderate room price. 
+bitod/ec909032-d5fd-4dee-8dd2-53a895a98824/5/dst up Trend Trend Trend Trend Trend Trend Trend Trend Trend sponsorship sponsorship sponsorship sponsorship sponsorship sponsorship sponsorship sponsorship sponsorship sponsorship sponsorship sponsorship sponsorship sponsorship metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metroEarnEarnEarn ( hotels booking ) name equal_to " Cordis, Hong Kong " , start_day equal_to " 23 " , start_month equal_to " 11 " DST: ( hotels search ) location equal_to " don't care " , price_level equal_to " moderate " , rating at_least " 4 " , stars equal_to " 5 " AGENT_ACTS_PREV: ( hotels search ) request price_level AGENT_ACTS: ( hotels search ) offer available_options equal_to " 26 " , offer name equal_to " Cordis, Hong Kong " , offer rating equal_to " 10 " USER: I'll take it. I need to check in on November 23rd. +bitod/ec909032-d5fd-4dee-8dd2-53a895a98824/5/api up Trend Trend Trend Trend Trend Trend Trend Trend Trend sponsorship sponsorship sponsorship sponsorship sponsorship sponsorship sponsorship sponsorship sponsorship sponsorship sponsorship sponsorship sponsorship sponsorship metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metroEarnEarnEarn no API: null ( hotels booking ) name equal_to " Cordis, Hong Kong " , number_of_nights #unknown , number_of_rooms #unknown , start_day equal_to " 23 " , start_month equal_to " 11 " , user_name #unknown , ( hotels search ) location equal_to " don't care " , price_level equal_to " moderate " , rating at_least " 4 " , stars equal_to " 5 " AGENT_ACTS_PREV: ( hotels search ) request price_level AGENT_ACTS: ( hotels search ) offer available_options equal_to " 26 " , offer name equal_to " Cordis, Hong Kong " , offer rating equal_to " 10 " USER: I'll take it. I need to check in on November 23rd. 
+bitod/ec909032-d5fd-4dee-8dd2-53a895a98824/5/da up Trend Trend Trend Trend Trend Trend Trend Trend Trend sponsorship sponsorship sponsorship sponsorship sponsorship sponsorship sponsorship sponsorship sponsorship sponsorship sponsorship sponsorship sponsorship sponsorship metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metroEarnEarnEarn ( hotels booking ) request number_of_nights DA: null ( hotels booking ) name equal_to " Cordis, Hong Kong " , number_of_nights #unknown , number_of_rooms #unknown , start_day equal_to " 23 " , start_month equal_to " 11 " , user_name #unknown , ( hotels search ) location equal_to " don't care " , price_level equal_to " moderate " , rating at_least " 4 " , stars equal_to " 5 " AGENT_ACTS_PREV: ( hotels search ) request price_level AGENT_DA: ( hotels search ) offer available_options equal_to " 26 " , offer name equal_to " Cordis, Hong Kong " , offer rating equal_to " 10 " USER: I'll take it. I need to check in on November 23rd. +bitod/ec909032-d5fd-4dee-8dd2-53a895a98824/5/rg up Trend Trend Trend Trend Trend Trend Trend Trend Trend sponsorship sponsorship sponsorship sponsorship sponsorship sponsorship sponsorship sponsorship sponsorship sponsorship sponsorship sponsorship sponsorship sponsorship metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metroEarnEarnEarn How many nights do you plan to stay? RG: ( hotels booking ) request number_of_nights AGENT_ACTS_PREV: ( hotels search ) request price_level AGENT_ACTS: ( hotels search ) offer available_options equal_to " 26 " , offer name equal_to " Cordis, Hong Kong " , offer rating equal_to " 10 " USER: I'll take it. I need to check in on November 23rd. 
+bitod/ec909032-d5fd-4dee-8dd2-53a895a98824/6/dst up Trend Trend Trend Trend Trend Trend Trend Trend Trend sponsorship sponsorship sponsorship sponsorship sponsorship sponsorship sponsorship sponsorship sponsorship sponsorship sponsorship sponsorship sponsorship sponsorship metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metroEarnEarnEarn ( hotels booking ) number_of_nights equal_to " 9 " , user_name equal_to " David " DST: ( hotels booking ) name equal_to " Cordis, Hong Kong " , number_of_nights #unknown , number_of_rooms #unknown , start_day equal_to " 23 " , start_month equal_to " 11 " , user_name #unknown , ( hotels search ) location equal_to " don't care " , price_level equal_to " moderate " , rating at_least " 4 " , stars equal_to " 5 " AGENT_ACTS_PREV: ( hotels search ) offer available_options equal_to " 26 " , offer name equal_to " Cordis, Hong Kong " , offer rating equal_to " 10 " AGENT_ACTS: ( hotels booking ) request number_of_nights USER: I need 9 nights booked. My name is David. +bitod/ec909032-d5fd-4dee-8dd2-53a895a98824/6/api up Trend Trend Trend Trend Trend Trend Trend Trend Trend sponsorship sponsorship sponsorship sponsorship sponsorship sponsorship sponsorship sponsorship sponsorship sponsorship sponsorship sponsorship sponsorship sponsorship metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metroEarnEarnEarn no API: null ( hotels booking ) name equal_to " Cordis, Hong Kong " , number_of_nights equal_to " 9 " , number_of_rooms #unknown , start_day equal_to " 23 " , start_month equal_to " 11 " , user_name equal_to " David " , ( hotels search ) location equal_to " don't care " , price_level equal_to " moderate " , rating at_least " 4 " , stars equal_to " 5 " AGENT_ACTS_PREV: ( hotels search ) offer available_options equal_to " 26 " , offer name equal_to " Cordis, Hong Kong " , offer rating equal_to " 10 " AGENT_ACTS: ( hotels booking ) request number_of_nights USER: I need 9 nights booked. My name is David. 
+bitod/ec909032-d5fd-4dee-8dd2-53a895a98824/6/da up Trend Trend Trend Trend Trend Trend Trend Trend Trend sponsorship sponsorship sponsorship sponsorship sponsorship sponsorship sponsorship sponsorship sponsorship sponsorship sponsorship sponsorship sponsorship sponsorship metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metroEarnEarnEarn ( hotels booking ) request number_of_rooms DA: null ( hotels booking ) name equal_to " Cordis, Hong Kong " , number_of_nights equal_to " 9 " , number_of_rooms #unknown , start_day equal_to " 23 " , start_month equal_to " 11 " , user_name equal_to " David " , ( hotels search ) location equal_to " don't care " , price_level equal_to " moderate " , rating at_least " 4 " , stars equal_to " 5 " AGENT_ACTS_PREV: ( hotels search ) offer available_options equal_to " 26 " , offer name equal_to " Cordis, Hong Kong " , offer rating equal_to " 10 " AGENT_DA: ( hotels booking ) request number_of_nights USER: I need 9 nights booked. My name is David. +bitod/ec909032-d5fd-4dee-8dd2-53a895a98824/6/rg up Trend Trend Trend Trend Trend Trend Trend Trend Trend sponsorship sponsorship sponsorship sponsorship sponsorship sponsorship sponsorship sponsorship sponsorship sponsorship sponsorship sponsorship sponsorship sponsorship metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metroEarnEarnEarn Of course, David. How many rooms? RG: ( hotels booking ) request number_of_rooms AGENT_ACTS_PREV: ( hotels search ) offer available_options equal_to " 26 " , offer name equal_to " Cordis, Hong Kong " , offer rating equal_to " 10 " AGENT_ACTS: ( hotels booking ) request number_of_nights USER: I need 9 nights booked. My name is David. 
+bitod/ec909032-d5fd-4dee-8dd2-53a895a98824/7/dst up Trend Trend Trend Trend Trend Trend Trend Trend Trend sponsorship sponsorship sponsorship sponsorship sponsorship sponsorship sponsorship sponsorship sponsorship sponsorship sponsorship sponsorship sponsorship sponsorship metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metroEarnEarnEarn ( hotels booking ) number_of_rooms equal_to " eight " DST: ( hotels booking ) name equal_to " Cordis, Hong Kong " , number_of_nights equal_to " 9 " , number_of_rooms #unknown , start_day equal_to " 23 " , start_month equal_to " 11 " , user_name equal_to " David " , ( hotels search ) location equal_to " don't care " , price_level equal_to " moderate " , rating at_least " 4 " , stars equal_to " 5 " AGENT_ACTS_PREV: ( hotels booking ) request number_of_nights AGENT_ACTS: ( hotels booking ) request number_of_rooms USER: eight rooms please. +bitod/ec909032-d5fd-4dee-8dd2-53a895a98824/7/api up Trend Trend Trend Trend Trend Trend Trend Trend Trend sponsorship sponsorship sponsorship sponsorship sponsorship sponsorship sponsorship sponsorship sponsorship sponsorship sponsorship sponsorship sponsorship sponsorship metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metroEarnEarnEarn no API: null ( hotels booking ) name equal_to " Cordis, Hong Kong " , number_of_nights equal_to " 9 " , number_of_rooms equal_to " eight " , start_day equal_to " 23 " , start_month equal_to " 11 " , user_name equal_to " David " , ( hotels search ) location equal_to " don't care " , price_level equal_to " moderate " , rating at_least " 4 " , stars equal_to " 5 " AGENT_ACTS_PREV: ( hotels booking ) request number_of_nights AGENT_ACTS: ( hotels booking ) request number_of_rooms USER: eight rooms please. 
+bitod/ec909032-d5fd-4dee-8dd2-53a895a98824/7/da up Trend Trend Trend Trend Trend Trend Trend Trend Trend sponsorship sponsorship sponsorship sponsorship sponsorship sponsorship sponsorship sponsorship sponsorship sponsorship sponsorship sponsorship sponsorship sponsorship metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metroEarnEarnEarn ( hotels booking ) confirm name equal_to " Cordis, Hong Kong " , confirm number_of_nights equal_to " 9 " , confirm number_of_rooms equal_to " eight " , confirm start_date equal_to " November 23 " DA: null ( hotels booking ) name equal_to " Cordis, Hong Kong " , number_of_nights equal_to " 9 " , number_of_rooms equal_to " eight " , start_day equal_to " 23 " , start_month equal_to " 11 " , user_name equal_to " David " , ( hotels search ) location equal_to " don't care " , price_level equal_to " moderate " , rating at_least " 4 " , stars equal_to " 5 " AGENT_ACTS_PREV: ( hotels booking ) request number_of_nights AGENT_DA: ( hotels booking ) request number_of_rooms USER: eight rooms please. +bitod/ec909032-d5fd-4dee-8dd2-53a895a98824/7/rg up Trend Trend Trend Trend Trend Trend Trend Trend Trend sponsorship sponsorship sponsorship sponsorship sponsorship sponsorship sponsorship sponsorship sponsorship sponsorship sponsorship sponsorship sponsorship sponsorship metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metro metroEarnEarnEarn Okay, just to confirm: You would like to book eight rooms at the Cordis, Hong Kong hotel for 9 nights with a check-in date of November 23rd? RG: ( hotels booking ) confirm name equal_to " Cordis, Hong Kong " , confirm number_of_nights equal_to " 9 " , confirm number_of_rooms equal_to " eight " , confirm start_date equal_to " November 23 " AGENT_ACTS_PREV: ( hotels booking ) request number_of_nights AGENT_ACTS: ( hotels booking ) request number_of_rooms USER: eight rooms please. diff --git a/tests/test_e2e_dialogues.sh b/tests/test_e2e_dialogues.sh new file mode 100755 index 00000000..13ad3e90 --- /dev/null +++ b/tests/test_e2e_dialogues.sh @@ -0,0 +1,46 @@ +#!/usr/bin/env bash + +. 
./tests/lib.sh + +# test e2e dialogue tasks + +hparams=( + "--pretrained_model sshleifer/bart-tiny-random" + "--pretrained_model sshleifer/bart-tiny-random" + ) +tasks=( + bitod + bitod_dst + ) + +for i in ${!hparams[*]}; +do + # train + genienlp train --train_tasks ${tasks[i]} --train_batch_tokens 100 --val_batch_size 300 --train_iterations 4 --preserve_case --save_every 2 --log_every 2 --val_every 2 --save $workdir/model_$i --data $SRCDIR/dataset/bitod --exist_ok --skip_cache --embeddings $EMBEDDING_DIR --no_commit ${hparams[i]} + + # greedy prediction + genienlp predict --tasks ${tasks[i]} --evaluate test --path $workdir/model_$i --overwrite --eval_dir $workdir/model_$i/eval_results/ --data $SRCDIR/dataset/bitod --embeddings $EMBEDDING_DIR --skip_cache --extra_metrics e2e_dialogue_score + + # check if result file exists + if test ! -f $workdir/model_$i/eval_results/test/${tasks[i]}.tsv ; then + echo "File not found!" + exit 1 + fi + + # check export and server mode + if [ $i == 0 ] ; then + echo "Testing export" + genienlp export --path $workdir/model_$i --output $workdir/model_"$i"_exported + + echo "Testing the server mode" + echo '{"id": "dummy_example_1", "context": "show me .", "question": "translate to thingtalk", "answer": "now => () => notify"}' | genienlp server --path $workdir/model_$i --stdin + fi + + if [ $i == 0 ] ; then + # check if predictions matches expected_results + diff -u $SRCDIR/expected_results/bitod/bitod.tsv $workdir/model_$i/eval_results/test/bitod.tsv + fi + + rm -rf $workdir/model_$i $workdir/model_"$i"_exported + +done diff --git a/tests/test_main_almond.sh b/tests/test_main_almond.sh index a96aa8cf..6b4cf31d 100755 --- a/tests/test_main_almond.sh +++ b/tests/test_main_almond.sh @@ -28,7 +28,7 @@ do # check TransformerSeq2Seq and TransformerLSTM if [ $i == 0 ] || [ $i == 2 ] ; then echo "Testing export" - genienlp export --path $workdir/model_$i --output $workdir/model_$i_exported + genienlp export --path $workdir/model_$i --output $workdir/model_"$i"_exported echo "Testing the server mode" echo '{"id": "dummy_example_1", "context": "show me .", "question": "translate to thingtalk", "answer": "now => () => notify"}' | genienlp server --path $workdir/model_$i --stdin From 065f9784da1b3e34f822d34d6b86879952d8cfa6 Mon Sep 17 00:00:00 2001 From: mehrad Date: Fri, 25 Feb 2022 13:43:24 -0800 Subject: [PATCH 100/102] metrics: small fixes --- genienlp/metrics.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/genienlp/metrics.py b/genienlp/metrics.py index b47459b7..aec0463e 100644 --- a/genienlp/metrics.py +++ b/genienlp/metrics.py @@ -53,7 +53,7 @@ # metrics that are calculated over a corpus (i.e. a list of predictions and gold answers, not single ones). # These metrics cannot be calculated on individual examples and then averaged. 
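# Illustration of the comment above, not part of this file: a minimal
# sketch, assuming the sacrebleu package and toy strings. Corpus BLEU
# pools n-gram counts across all sentences before computing precision,
# so the mean of per-sentence scores generally differs from the corpus
# score — which is why these metrics cannot be averaged per example.
# The diff below adds 'jga' (joint goal accuracy) to this set, so it
# is treated the same way.
import sacrebleu

hypotheses = ['the cat sat on the mat', 'hello there']
references = ['the cat is on the mat', 'hello world']

# Corpus-level: n-gram statistics are pooled over all pairs first.
corpus_score = sacrebleu.corpus_bleu(hypotheses, [references]).score

# Averaging sentence-level scores gives a different number in general.
mean_sentence_score = sum(
    sacrebleu.sentence_bleu(h, [r]).score for h, r in zip(hypotheses, references)
) / len(hypotheses)

print(corpus_score, mean_sentence_score)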
-corpus_level_metrics = {'bleu', 'casedbleu', 'ter', 't5_bleu', 'nmt_bleu', 'corpus_f1'} +corpus_level_metrics = {'bleu', 'casedbleu', 'ter', 't5_bleu', 'nmt_bleu', 'corpus_f1', 'jga'} def to_lf(s, table): @@ -764,15 +764,13 @@ def calculate_and_reduce_metrics(generation_output, metrics_to_compute, args, la for m in metrics_to_compute: if m in corpus_level_metrics: logging.warning( - 'You are using the corpus-level metric %s with `--reduce_metrics top_k`, which can lead to incorrect results.', - m, + f'You are using the corpus-level metric {m} with `--reduce_metrics top_k`, which can lead to incorrect results.', ) - for i in range(len(predictions)): # for each input example_metrics = OrderedDict() # keep track of metrics for one input and all of its outputs for j in range(len(predictions[i])): # for each output (in case of multiple outputs) partial_metrics = compute_metrics( - [predictions[i][j]], [answers[i]], metrics_to_compute, lang + [predictions[i][j]], [answers[i]], metrics_to_compute, lang, args, example_ids ) # calculate the metric on the j-th output of the i-th input for k, v in partial_metrics.items(): example_metrics[k] = max(example_metrics.get(k, 0), v) From e07c8c52167fed461faadec2bf8ae9bd0873ab77 Mon Sep 17 00:00:00 2001 From: mehrad Date: Fri, 25 Feb 2022 14:01:41 -0800 Subject: [PATCH 101/102] Update test results after adding input column to predictions --- .../almond/bert_base_cased_beam.tsv | 20 +++++++++---------- .../sequence_classification/ood_task.tsv | 20 +++++++++---------- .../token_classification/conll2003_2.tsv | 10 +++++----- .../token_classification/cross_ner_news_0.tsv | 10 +++++----- .../token_classification/cross_ner_news_1.tsv | 10 +++++----- 5 files changed, 35 insertions(+), 35 deletions(-) diff --git a/tests/expected_results/almond/bert_base_cased_beam.tsv b/tests/expected_results/almond/bert_base_cased_beam.tsv index 5c4fe7cd..d1d972f9 100644 --- a/tests/expected_results/almond/bert_base_cased_beam.tsv +++ b/tests/expected_results/almond/bert_base_cased_beam.tsv @@ -1,10 +1,10 @@ -almond/R1633686-0 i engligli @com.twitter . home_timeline ( ) => @com.gmail . send_email ( message = text ) ; -almond/R1624112-0 onglitalgli room monitor ( @thermostat . get_humidity ( ) ) filter value <= NUMBER_0 => @org.thingpedia.bluetooth.speaker.a2dp . play_music ( ) ; -almond/R1633355-0 ##tter igligli hope please youwi monitor ( @security-camera . current_event ( ) ) => @com.twitter . post_picture ( caption = " person and i hope you have a great day " , picture_url = picture_url ) ; -almond/R1633055-0 once monitor ( @org.thingpedia.weather . current ( location = LOCATION_0 ) ) filter ! ( status == enum snowy ) => @com.fitbit . getsteps ( ) ; -almond/R1621376-0 i igliglisetglital aost the thingtal get aost the thingtal get thingtal get thingtal get thingtal get thingtal get thingtal get thingtal get thingtal get thingtal get thingtal get thingtal get thingtal get thingtal get thingtal get thingtal get thingtal get thingtal get thingtal get thingtal get thingtal get thingtal get thingtal get thingtal get thingtal get thingtal get thingtal get thingtal get thingtal get thingtal get thingtal get thingtal get thingtal get thingtal get thingtal get thingtal get thingtal get thingtal get thingtal get thingtal get thingtal get thingtal get thingtal get thingtal get thingtal get monitor ( @com.twitter . home_timeline ( ) ) => @thermostat . set_target_temperature ( ) ; -almond/R1628268-0 englitag monitor ( @com.washingtonpost . 
get_article ( section = enum world ) ) => @com.instagram . get_pictures ( ) filter filter == null ^^com.instagram:filter ( " brannan " ) ; -almond/R1632097-0 monitor ( @org.thingpedia.weather . current ( location = $location . work ) ) filter ! ( status == enum snowy ) => @org.thingpedia.builtin.thingengine.phone . get_gps ( ) ; -almond/R1618372-0 ##eriaglieriagligligligligligligligligligligligligli megligligligligligligligligligli megligligligligligligligligligli megligligligligligligligligligli megligligligligligligligligligli megligligligligligligligligligli megligligligligligligligligligli megligligligligligligligligligli megligligligligligligligligligli megligligligligligligligligligli megligligligligligligligligligli megligligligligligligligligligli megligligligligligligligligligli me monitor ( @com.tumblr . new_photo ( blog_name = " zackeriah " ^^tt:username ) ) => @org.thingpedia.icalendar . list_events ( ) ; -almond/R1624288-0 i igligliglital to to thing to NUMBER_0 NUMBER_0 to NUMBER_0 to NUMBER_0 NUMBER_0 igliglital to to thing to NUMBER_0 NUMBER_0 to NUMBER_0 to NUMBER_0 NUMBER_0 igliglital to to thing to NUMBER_0 NUMBER_0 to NUMBER_0 to NUMBER_0 NUMBER_0 igliglital to to thing to NUMBER_0 NUMBER_0 to NUMBER_0 to NUMBER_0 NUMBER_0 igliglital to to thing to NUMBER_0 NUMBER_0 to NUMBER_0 to NUMBER_0 NUMBER_0 igliglital to to thing to NUMBER_0 NUMBER_0 to NUMBER_0 to NUMBER_0 NUMBER_0 igliglital to to thing to NUMBER_0 NUMBER_0 to NUMBER_0 to NUMBER_0 NUMBER_0 igliglital to to thing to NUMBER_0 NUMBER_0 to NUMBER_0 to NUMBER_0 NUMBER_0 igliglital to to thing to NUMBER_0 NUMBER_0 to NUMBER_0 to NUMBER_0 NUMBER_0 igli monitor ( @com.twitter . home_timeline ( ) ) => @thermostat . set_target_temperature ( value = NUMBER_0 F ) ; -almond/R1629029-0 i igli tell thing to thingtalwitter thing totter thing to totter thing to thing totter thing to totter thing to thing totter thing to totter thing to totter thing to thing totter thing to totter thing to totter thing to thing totter thing to totter thing to totter thing to thing totter thing to totter thing to totter thing to thing totter thing to totter thing to totter thing to thing totter thing to totter thing to totter thing to thing totter thing to totter thing to totter thing to thing totter thing to totter thing to totter thing to thing totter thing to totter thing to totter thing to thing totter thing monitor ( @org.thingpedia.icalendar . list_events ( ) filter organizer == " jouanna mills " ) filter status == enum started => @com.twitter . direct_messages ( ) ; +almond/R1633686-0 i engligli @com.twitter . home_timeline ( ) => @com.gmail . send_email ( message = text ) ; deliver the text of tweets from anyone i follow via email translate from english to thingtalk +almond/R1624112-0 onglitalgli room monitor ( @thermostat . get_humidity ( ) ) filter value <= NUMBER_0 => @org.thingpedia.bluetooth.speaker.a2dp . play_music ( ) ; when the room 's humidity falls below NUMBER_0 in the room play something on my speaker translate from english to thingtalk +almond/R1633355-0 ##tter igligli hope please youwi monitor ( @security-camera . current_event ( ) ) => @com.twitter . post_picture ( caption = " person and i hope you have a great day " , picture_url = picture_url ) ; please post the live feed photo of my new security camera on twitter and caption it person and i hope you have a great day translate from english to thingtalk +almond/R1633055-0 once monitor ( @org.thingpedia.weather . current ( location = LOCATION_0 ) ) filter ! 
( status == enum snowy ) => @com.fitbit . getsteps ( ) ; once it stops snowing in LOCATION_0 , get the steps from fitbit translate from english to thingtalk +almond/R1621376-0 i igliglisetglital aost the thingtal get aost the thingtal get thingtal get thingtal get thingtal get thingtal get thingtal get thingtal get thingtal get thingtal get thingtal get thingtal get thingtal get thingtal get thingtal get thingtal get thingtal get thingtal get thingtal get thingtal get thingtal get thingtal get thingtal get thingtal get thingtal get thingtal get thingtal get thingtal get thingtal get thingtal get thingtal get thingtal get thingtal get thingtal get thingtal get thingtal get thingtal get thingtal get thingtal get thingtal get thingtal get thingtal get thingtal get thingtal get thingtal get monitor ( @com.twitter . home_timeline ( ) ) => @thermostat . set_target_temperature ( ) ; reset my thermostat reading if i get a tweet translate from english to thingtalk +almond/R1628268-0 englitag monitor ( @com.washingtonpost . get_article ( section = enum world ) ) => @com.instagram . get_pictures ( ) filter filter == null ^^com.instagram:filter ( " brannan " ) ; notify me of the instagram pictures with filter brannan when there is a new article in the world section of the washington post translate from english to thingtalk +almond/R1632097-0 monitor ( @org.thingpedia.weather . current ( location = $location . work ) ) filter ! ( status == enum snowy ) => @org.thingpedia.builtin.thingengine.phone . get_gps ( ) ; retrieve my current location when it stops snowing in work translate from english to thingtalk +almond/R1618372-0 ##eriaglieriagligligligligligligligligligligligligli megligligligligligligligligligli megligligligligligligligligligli megligligligligligligligligligli megligligligligligligligligligli megligligligligligligligligligli megligligligligligligligligligli megligligligligligligligligligli megligligligligligligligligligli megligligligligligligligligligli megligligligligligligligligligli megligligligligligligligligligli megligligligligligligligligligli me monitor ( @com.tumblr . new_photo ( blog_name = " zackeriah " ^^tt:username ) ) => @org.thingpedia.icalendar . list_events ( ) ; show me my icalendar events when zackeriah uploads a new picture on tumblr translate from english to thingtalk +almond/R1624288-0 i igligliglital to to thing to NUMBER_0 NUMBER_0 to NUMBER_0 to NUMBER_0 NUMBER_0 igliglital to to thing to NUMBER_0 NUMBER_0 to NUMBER_0 to NUMBER_0 NUMBER_0 igliglital to to thing to NUMBER_0 NUMBER_0 to NUMBER_0 to NUMBER_0 NUMBER_0 igliglital to to thing to NUMBER_0 NUMBER_0 to NUMBER_0 to NUMBER_0 NUMBER_0 igliglital to to thing to NUMBER_0 NUMBER_0 to NUMBER_0 to NUMBER_0 NUMBER_0 igliglital to to thing to NUMBER_0 NUMBER_0 to NUMBER_0 to NUMBER_0 NUMBER_0 igliglital to to thing to NUMBER_0 NUMBER_0 to NUMBER_0 to NUMBER_0 NUMBER_0 igliglital to to thing to NUMBER_0 NUMBER_0 to NUMBER_0 to NUMBER_0 NUMBER_0 igliglital to to thing to NUMBER_0 NUMBER_0 to NUMBER_0 to NUMBER_0 NUMBER_0 igli monitor ( @com.twitter . home_timeline ( ) ) => @thermostat . 
set_target_temperature ( value = NUMBER_0 F ) ; change my thermostat to NUMBER_0 degrees when someone i follow makes a tweet translate from english to thingtalk +almond/R1629029-0 i igli tell thing to thingtalwitter thing totter thing to totter thing to thing totter thing to totter thing to thing totter thing to totter thing to totter thing to thing totter thing to totter thing to totter thing to thing totter thing to totter thing to totter thing to thing totter thing to totter thing to totter thing to thing totter thing to totter thing to totter thing to thing totter thing to totter thing to totter thing to thing totter thing to totter thing to totter thing to thing totter thing to totter thing to totter thing to thing totter thing to totter thing to totter thing to thing totter thing monitor ( @org.thingpedia.icalendar . list_events ( ) filter organizer == " jouanna mills " ) filter status == enum started => @com.twitter . direct_messages ( ) ; tell me the direct message i got on twitter when an event organized by jouanna mills from my calendar is about to start translate from english to thingtalk diff --git a/tests/expected_results/sequence_classification/ood_task.tsv b/tests/expected_results/sequence_classification/ood_task.tsv index e9e5224d..e66b7760 100644 --- a/tests/expected_results/sequence_classification/ood_task.tsv +++ b/tests/expected_results/sequence_classification/ood_task.tsv @@ -1,10 +1,10 @@ -ood/0 1 1 -ood/1 1 0 -ood/2 1 0 -ood/3 1 0 -ood/4 0 0 -ood/5 1 0 -ood/6 1 1 -ood/7 1 0 -ood/8 1 0 -ood/9 1 1 +ood/0 1 1 get curiosity pictures is this sentence in - domain or out - domain? +ood/1 1 0 play skinny love by bon iver is this sentence in - domain or out - domain? +ood/2 1 0 search for all songs with the word dance in the title. is this sentence in - domain or out - domain? +ood/3 1 0 play some psychedelic rock. is this sentence in - domain or out - domain? +ood/4 0 0 play the track grow old with me by artist chloe rose lattanzi is this sentence in - domain or out - domain? +ood/5 1 0 acdc is this sentence in - domain or out - domain? +ood/6 1 1 easy walk, accountable lady gaga. is this sentence in - domain or out - domain? +ood/7 1 0 play that would be something by eden ahbez is this sentence in - domain or out - domain? +ood/8 1 0 show me the weather in berceni, romania is this sentence in - domain or out - domain? +ood/9 1 1 what is number _ 0 + 2 is this sentence in - domain or out - domain? 
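The updated rows above show the prediction-file layout this patch introduces: each tab-separated row still begins with the example id and the model output(s), and now ends with the full input, task prompt included (e.g. "is this sentence in - domain or out - domain?"). A minimal reading sketch, assuming only that the appended input is the last tab-separated field; the helper name is illustrative and not part of genienlp:

def split_prediction_row(line):
    # id first, appended input last; everything in between is model output
    # (for the classification rows, that middle span also carries the label column)
    fields = line.rstrip('\n').split('\t')
    return fields[0], fields[1:-1], fields[-1]

example_id, middle, input_text = split_prediction_row(
    'ood/0\t1\t1\tget curiosity pictures is this sentence in - domain or out - domain?'
)

Treating everything between the id and the input as opaque keeps the reader valid for all three task types updated in this patch, whose middle columns differ in number and meaning.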
diff --git a/tests/expected_results/token_classification/conll2003_2.tsv b/tests/expected_results/token_classification/conll2003_2.tsv index 93b2051d..c2f9618a 100644 --- a/tests/expected_results/token_classification/conll2003_2.tsv +++ b/tests/expected_results/token_classification/conll2003_2.tsv @@ -1,5 +1,5 @@ -conll2003/0 I-PER I-LOC O B-PER B-PER B-ORG B-ORG O B-PER B-PER B-PER B-ORG B-LOC B-LOC B-LOC B-PER B-LOC O B-ORG B-LOC B-LOC B-LOC B-LOC B-PER B-ORG B-LOC B-ORG B-ORG I-ORG B-PER O B-ORG B-PER B-PER B-PER B-LOC B-LOC B-LOC B-LOC B-PER B-PER B-PER O O O O O B-ORG B-ORG B-ORG B-ORG B-ORG B-ORG B-ORG B-ORG B-ORG O O O O O O O O O O O O O O O O O O O O O O O O O O O O -conll2003/1 B-ORG O I-ORG O B-PER O B-PER O B-PER B-PER B-PER B-ORG B-PER B-ORG B-PER B-ORG B-PER B-PER B-PER B-PER B-PER B-PER B-PER B-PER B-PER B-PER B-PER B-ORG B-LOC B-ORG B-ORG B-ORG B-PER B-PER B-PER B-PER B-PER B-PER B-PER B-PER B-PER B-PER B-LOC B-LOC B-LOC B-LOC O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O -conll2003/2 I-PER B-PER O I-PER O O I-PER B-ORG B-LOC B-MISC I-PER B-LOC O B-LOC I-PER B-PER B-LOC B-LOC O O I-PER I-ORG B-PER B-MISC I-PER O B-LOC I-LOC O B-LOC O O O O I-LOC B-PER B-PER O I-LOC I-LOC I-LOC I-LOC B-MISC I-MISC O O O O B-PER I-PER O O O O O O O B-ORG O B-ORG O O O O O O O O O O O O O O O O O O O O O O O O -conll2003/3 I-PER B-LOC I-LOC B-LOC B-LOC B-ORG B-LOC B-LOC O B-ORG O B-LOC O B-PER O I-LOC B-LOC B-LOC B-LOC B-LOC B-LOC B-LOC I-LOC I-LOC B-LOC B-LOC B-LOC B-ORG B-LOC B-PER B-LOC I-PER O B-PER B-LOC O B-ORG B-PER O B-LOC O I-LOC O O O O O O O O O O O O O O O B-ORG O B-ORG O B-ORG O O O O O O B-ORG O O O O O O O O O O O O B-ORG O O -conll2003/4 I-PER O I-LOC I-PER I-PER O B-LOC B-LOC B-MISC B-LOC B-MISC I-LOC B-MISC B-LOC B-PER B-ORG B-LOC B-LOC I-PER I-PER B-PER B-LOC B-LOC I-LOC O I-PER I-PER O B-MISC O B-LOC B-ORG O I-PER O B-ORG B-LOC B-MISC O B-MISC B-MISC O O O B-ORG O O O O O O O O B-LOC I-LOC O B-ORG O O O O O O O O O O O O O O O B-LOC O O B-PER I-PER I-PER I-PER O O O O O +conll2003/0 I-PER I-LOC O B-PER B-PER B-ORG B-ORG O B-PER B-PER B-PER B-ORG B-LOC B-LOC B-LOC B-PER B-LOC O B-ORG B-LOC B-LOC B-LOC B-LOC B-PER B-ORG B-LOC B-ORG B-ORG I-ORG B-PER O B-ORG B-PER B-PER B-PER B-LOC B-LOC B-LOC B-LOC B-PER B-PER B-PER O O O O O B-ORG B-ORG B-ORG B-ORG B-ORG B-ORG B-ORG B-ORG B-ORG O O O O O O O O O O O O O O O O O O O O O O O O O O O O CRICKET - LEICESTERSHIRE TAKE OVER AT TOP AFTER INNINGS VICTORY. +conll2003/1 B-ORG O I-ORG O B-PER O B-PER O B-PER B-PER B-PER B-ORG B-PER B-ORG B-PER B-ORG B-PER B-PER B-PER B-PER B-PER B-PER B-PER B-PER B-PER B-PER B-PER B-ORG B-LOC B-ORG B-ORG B-ORG B-PER B-PER B-PER B-PER B-PER B-PER B-PER B-PER B-PER B-PER B-LOC B-LOC B-LOC B-LOC O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O LONDON 1996 - 08 - 30 +conll2003/2 I-PER B-PER O I-PER O O I-PER B-ORG B-LOC B-MISC I-PER B-LOC O B-LOC I-PER B-PER B-LOC B-LOC O O I-PER I-ORG B-PER B-MISC I-PER O B-LOC I-LOC O B-LOC O O O O I-LOC B-PER B-PER O I-LOC I-LOC I-LOC I-LOC B-MISC I-MISC O O O O B-PER I-PER O O O O O O O B-ORG O B-ORG O O O O O O O O O O O O O O O O O O O O O O O O West Indian all - rounder Phil Simmons took four for 38 on Friday as Leicestershire beat Somerset by an innings and 39 runs in two days to take over at the head of the county championship. 
+conll2003/3 I-PER B-LOC I-LOC B-LOC B-LOC B-ORG B-LOC B-LOC O B-ORG O B-LOC O B-PER O I-LOC B-LOC B-LOC B-LOC B-LOC B-LOC B-LOC I-LOC I-LOC B-LOC B-LOC B-LOC B-ORG B-LOC B-PER B-LOC I-PER O B-PER B-LOC O B-ORG B-PER O B-LOC O I-LOC O O O O O O O O O O O O O O O B-ORG O B-ORG O B-ORG O O O O O O B-ORG O O O O O O O O O O O O B-ORG O O Their stay on top, though, may be short - lived as title rivals Essex, Derbyshire and Surrey all closed in on victory while Kent made up for lost time in their rain - affected match against Nottinghamshire. +conll2003/4 I-PER O I-LOC I-PER I-PER O B-LOC B-LOC B-MISC B-LOC B-MISC I-LOC B-MISC B-LOC B-PER B-ORG B-LOC B-LOC I-PER I-PER B-PER B-LOC B-LOC I-LOC O I-PER I-PER O B-MISC O B-LOC B-ORG O I-PER O B-ORG B-LOC B-MISC O B-MISC B-MISC O O O B-ORG O O O O O O O O B-LOC I-LOC O B-ORG O O O O O O O O O O O O O O O B-LOC O O B-PER I-PER I-PER I-PER O O O O O After bowling Somerset out for 83 on the opening morning at Grace Road, Leicestershire extended their first innings by 94 runs before being bowled out for 296 with England discard Andy Caddick taking three for 83. diff --git a/tests/expected_results/token_classification/cross_ner_news_0.tsv b/tests/expected_results/token_classification/cross_ner_news_0.tsv index c4868f62..f16b22b7 100644 --- a/tests/expected_results/token_classification/cross_ner_news_0.tsv +++ b/tests/expected_results/token_classification/cross_ner_news_0.tsv @@ -1,5 +1,5 @@ -cross_ner/0 B-researcher I-book B-scientist B-metrics I-politicalparty B-university B-scientist B-metrics B-ORG B-metrics B-metrics I-musicalinstrument I-chemicalelement B-field I-band B-metrics B-protein I-academicjournal B-university I-academicjournal I-programlang I-PER I-song B-university B-university B-university I-literarygenre B-scientist B-metrics I-university I-song B-university B-university B-protein B-organisation B-enzyme B-enzyme I-academicjournal I-musicalinstrument I-song I-song B-LOC B-university B-university I-politician B-university B-metrics B-metrics B-metrics I-academicjournal I-chemicalcompound B-album I-programlang I-programlang B-field B-award B-award B-protein I-chemicalelement B-field I-chemicalelement B-astronomicalobject I-politicalparty B-astronomicalobject B-protein I-musicalinstrument I-musicalinstrument B-university B-award I-musicalinstrument B-album I-song B-field I-programlang B-metrics I-musicalinstrument B-award B-award B-programlang B-field B-protein O O O O O O O O B-event I-event I-event O B-musicalartist B-musicalartist B-musicalartist O O O O B-location I-location I-location I-location O O B-location O O O B-location I-location O O B-location O O O O O B-location B-location B-location O O B-location O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O -cross_ner/1 I-chemicalelement B-award B-event I-politicalparty B-field I-astronomicalobject I-academicjournal B-university B-metrics B-metrics B-LOC B-university B-university I-field B-musicalinstrument I-musicalinstrument I-literarygenre B-metrics I-chemicalelement I-protein B-metrics I-person B-scientist I-politicalparty I-academicjournal I-astronomicalobject I-literarygenre I-politician I-astronomicalobject B-university B-country I-person B-politician I-event B-protein B-protein I-politicalparty I-field I-field I-politicalparty B-university B-politician B-enzyme I-task B-award B-event B-university I-field B-university I-field B-university I-field I-location I-enzyme B-university B-university B-university I-programlang I-task B-university I-programlang 
B-university B-university B-politician B-university B-protein B-metrics B-field B-metrics B-field B-metrics B-metrics B-university B-university B-event B-politician I-field I-field I-chemicalelement B-award B-award B-musicalartist B-musicalartist B-musicalartist O O O O O O O O O O O O O B-organisation B-organisation I-organisation O O O O O O O O O O O O O O O O O O B-album I-album O B-album I-album I-album I-album I-album I-album O B-album B-album B-album O B-album I-album O B-album I-album I-album I-album I-album O B-album I-album I-album I-album O B-album B-album B-album B-album O B-album B-album I-album I-album O O O O O O O -cross_ner/2 I-chemicalelement I-chemicalelement B-field B-metrics B-university B-field B-university I-song B-university I-programlang B-enzyme B-enzyme B-protein I-person B-chemicalelement B-chemicalelement B-university B-field B-politician I-person B-university B-university I-location B-ORG B-university B-university B-university B-university B-scientist B-university I-politicalparty B-university B-university B-university B-field I-task I-academicjournal B-university B-scientist B-university B-ORG B-university I-song B-university B-enzyme B-enzyme B-university I-politicalparty I-task I-politician I-politician B-enzyme B-university B-enzyme B-university B-enzyme I-song B-university B-enzyme I-song I-politicalparty I-politician B-scientist B-university B-country I-politician B-LOC I-task I-task I-task B-university B-university B-university B-field I-task B-LOC B-university B-politician I-song B-university B-album O O O O O O O O O O O O B-song I-song I-song I-song O O B-album I-album I-album I-album I-album O O B-song B-song O O B-misc B-misc B-misc O O O B-album B-album B-album I-album I-album O O B-song I-song I-song I-song O O B-musicalartist B-musicalartist I-musicalartist I-musicalartist I-musicalartist O O O B-song I-song I-song I-song I-song I-song I-song O O B-musicalartist I-musicalartist I-musicalartist I-musicalartist I-musicalartist O O O O B-album B-album B-album I-album I-album O O -cross_ner/3 I-chemicalelement I-chemicalelement I-person B-country B-university I-astronomicalobject I-person B-album I-academicjournal B-LOC B-country B-organisation I-song B-university B-scientist B-protein B-organisation I-song B-metrics B-LOC B-metrics B-LOC I-academicjournal B-university B-scientist I-academicjournal I-programlang I-song I-song B-scientist B-scientist B-university B-metrics I-person B-metrics B-metrics I-song I-protein B-LOC B-LOC B-LOC B-field B-field B-scientist B-scientist B-scientist B-field B-field B-LOC I-song I-song B-field B-literarygenre B-LOC B-album B-field B-field B-scientist B-album B-LOC B-LOC B-scientist B-field B-field I-song I-song I-song B-field B-LOC B-LOC B-album B-scientist B-scientist B-field B-literarygenre I-song B-LOC B-scientist B-scientist B-field I-field O O O O O B-award I-award O O B-award I-award I-award I-award I-award O O B-award I-award I-award I-award I-award I-award I-award I-award O O B-award I-award I-award O O O B-award I-award I-award I-award I-award O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O -cross_ner/4 I-chemicalelement I-song I-song I-song I-song B-enzyme I-song B-researcher B-university I-programlang I-programlang I-chemicalelement B-scientist B-university I-university B-country I-chemicalelement I-song I-enzyme B-university B-university B-university B-scientist B-organisation I-task B-field I-task I-song B-field I-politician I-politician B-country I-politician 
I-politician I-politician B-scientist B-scientist I-politician I-song I-band I-politician I-song I-song B-scientist B-university B-field B-university B-metrics B-chemicalelement I-song B-scientist I-MISC I-politician B-university B-scientist I-enzyme B-university I-chemicalelement I-musicalinstrument B-scientist B-scientist B-metrics B-LOC B-university B-metrics B-field B-university B-LOC B-university B-politician B-LOC B-LOC B-scientist B-scientist B-scientist B-scientist B-chemicalelement I-song I-musicalinstrument I-field B-university B-band B-band O O O O O O O O O B-musicalartist I-musicalartist I-musicalartist O B-band I-band I-band O O O O O B-song I-song I-song I-song I-song O B-musicalartist B-musicalartist B-musicalartist B-musicalartist I-musicalartist I-musicalartist I-musicalartist O B-band B-band O O B-musicalartist I-musicalartist I-musicalartist O O B-song B-song B-song B-song O B-musicalartist B-musicalartist I-musicalartist I-musicalartist I-musicalartist O B-band B-band B-band O B-musicalartist I-musicalartist O O O B-song I-song I-song I-song I-song O O O O O O O O O O +cross_ner/0 B-researcher I-book B-scientist B-metrics I-politicalparty B-university B-scientist B-metrics B-ORG B-metrics B-metrics I-musicalinstrument I-chemicalelement B-field I-band B-metrics B-protein I-academicjournal B-university I-academicjournal I-programlang I-PER I-song B-university B-university B-university I-literarygenre B-scientist B-metrics I-university I-song B-university B-university B-protein B-organisation B-enzyme B-enzyme I-academicjournal I-musicalinstrument I-song I-song B-LOC B-university B-university I-politician B-university B-metrics B-metrics B-metrics I-academicjournal I-chemicalcompound B-album I-programlang I-programlang B-field B-award B-award B-protein I-chemicalelement B-field I-chemicalelement B-astronomicalobject I-politicalparty B-astronomicalobject B-protein I-musicalinstrument I-musicalinstrument B-university B-award I-musicalinstrument B-album I-song B-field I-programlang B-metrics I-musicalinstrument B-award B-award B-programlang B-field B-protein O O O O O O O O B-event I-event I-event O B-musicalartist B-musicalartist B-musicalartist O O O O B-location I-location I-location I-location O O B-location O O O B-location I-location O O B-location O O O O O B-location B-location B-location O O B-location O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O As part of the 2010 leg of the My Christmas Tour, Bocelli gave two concerts in The O2 Arena, in London, and the Manchester Arena, in Manchester, and a concert at 3Arena, in Dublin, in late November 2010. 
+cross_ner/1 I-chemicalelement B-award B-event I-politicalparty B-field I-astronomicalobject I-academicjournal B-university B-metrics B-metrics B-LOC B-university B-university I-field B-musicalinstrument I-musicalinstrument I-literarygenre B-metrics I-chemicalelement I-protein B-metrics I-person B-scientist I-politicalparty I-academicjournal I-astronomicalobject I-literarygenre I-politician I-astronomicalobject B-university B-country I-person B-politician I-event B-protein B-protein I-politicalparty I-field I-field I-politicalparty B-university B-politician B-enzyme I-task B-award B-event B-university I-field B-university I-field B-university I-field I-location I-enzyme B-university B-university B-university I-programlang I-task B-university I-programlang B-university B-university B-politician B-university B-protein B-metrics B-field B-metrics B-field B-metrics B-metrics B-university B-university B-event B-politician I-field I-field I-chemicalelement B-award B-award B-musicalartist B-musicalartist B-musicalartist O O O O O O O O O O O O O B-organisation B-organisation I-organisation O O O O O O O O O O O O O O O O O O B-album I-album O B-album I-album I-album I-album I-album I-album O B-album B-album B-album O B-album I-album O B-album I-album I-album I-album I-album O B-album I-album I-album I-album O B-album B-album B-album B-album O B-album B-album I-album I-album O O O O O O O Squarepusher continues to push new boundaries to this day, where he still calls Warp Records his home, having released numerous albums to critical acclaim in the years to follow, such as Go Plastic, Do You Know Squarepusher, Ultravisitor, Hello Everything, Just a Souvenir, Solo Electric Bass 1, Ufabulum and Damogen Furies. +cross_ner/2 I-chemicalelement I-chemicalelement B-field B-metrics B-university B-field B-university I-song B-university I-programlang B-enzyme B-enzyme B-protein I-person B-chemicalelement B-chemicalelement B-university B-field B-politician I-person B-university B-university I-location B-ORG B-university B-university B-university B-university B-scientist B-university I-politicalparty B-university B-university B-university B-field I-task I-academicjournal B-university B-scientist B-university B-ORG B-university I-song B-university B-enzyme B-enzyme B-university I-politicalparty I-task I-politician I-politician B-enzyme B-university B-enzyme B-university B-enzyme I-song B-university B-enzyme I-song I-politicalparty I-politician B-scientist B-university B-country I-politician B-LOC I-task I-task I-task B-university B-university B-university B-field I-task B-LOC B-university B-politician I-song B-university B-album O O O O O O O O O O O O B-song I-song I-song I-song O O B-album I-album I-album I-album I-album O O B-song B-song O O B-misc B-misc B-misc O O O B-album B-album B-album I-album I-album O O B-song I-song I-song I-song O O B-musicalartist B-musicalartist I-musicalartist I-musicalartist I-musicalartist O O O B-song I-song I-song I-song I-song I-song I-song O O B-musicalartist I-musicalartist I-musicalartist I-musicalartist I-musicalartist O O O O B-album B-album B-album I-album I-album O O During the 1990s, many releases included recordings of classical compositions : Pictures at an Exhibition ( on Turn of the Tides ), Largo ( from Xerxes ) ( on Tyranny of Beauty ), Symphony in A Minor ( by J. S. Bach ), and Concerto in A Major / Adagio ( by Wolfgang Amadeus Mozart ) ( both on Ambient Monkeys ). 
+cross_ner/3 I-chemicalelement I-chemicalelement I-person B-country B-university I-astronomicalobject I-person B-album I-academicjournal B-LOC B-country B-organisation I-song B-university B-scientist B-protein B-organisation I-song B-metrics B-LOC B-metrics B-LOC I-academicjournal B-university B-scientist I-academicjournal I-programlang I-song I-song B-scientist B-scientist B-university B-metrics I-person B-metrics B-metrics I-song I-protein B-LOC B-LOC B-LOC B-field B-field B-scientist B-scientist B-scientist B-field B-field B-LOC I-song I-song B-field B-literarygenre B-LOC B-album B-field B-field B-scientist B-album B-LOC B-LOC B-scientist B-field B-field I-song I-song I-song B-field B-LOC B-LOC B-album B-scientist B-scientist B-field B-literarygenre I-song B-LOC B-scientist B-scientist B-field I-field O O O O O B-award I-award O O B-award I-award I-award I-award I-award O O B-award I-award I-award I-award I-award I-award I-award I-award O O B-award I-award I-award O O O B-award I-award I-award I-award I-award O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O He has also won three Grammy Awards, 14 Academy of Country Music awards, 11 Country Music Association ( CMA ) awards, 10 American Music Awards, and three People's Choice Awards. +cross_ner/4 I-chemicalelement I-song I-song I-song I-song B-enzyme I-song B-researcher B-university I-programlang I-programlang I-chemicalelement B-scientist B-university I-university B-country I-chemicalelement I-song I-enzyme B-university B-university B-university B-scientist B-organisation I-task B-field I-task I-song B-field I-politician I-politician B-country I-politician I-politician I-politician B-scientist B-scientist I-politician I-song I-band I-politician I-song I-song B-scientist B-university B-field B-university B-metrics B-chemicalelement I-song B-scientist I-MISC I-politician B-university B-scientist I-enzyme B-university I-chemicalelement I-musicalinstrument B-scientist B-scientist B-metrics B-LOC B-university B-metrics B-field B-university B-LOC B-university B-politician B-LOC B-LOC B-scientist B-scientist B-scientist B-scientist B-chemicalelement I-song I-musicalinstrument I-field B-university B-band B-band O O O O O O O O O B-musicalartist I-musicalartist I-musicalartist O B-band I-band I-band O O O O O B-song I-song I-song I-song I-song O B-musicalartist B-musicalartist B-musicalartist B-musicalartist I-musicalartist I-musicalartist I-musicalartist O B-band B-band O O B-musicalartist I-musicalartist I-musicalartist O O B-song B-song B-song B-song O B-musicalartist B-musicalartist I-musicalartist I-musicalartist I-musicalartist O B-band B-band B-band O B-musicalartist I-musicalartist O O O B-song I-song I-song I-song I-song O O O O O O O O O O ABBA were soon recognised and embraced by other acts : Evan Dando of the Lemonheads recorded a cover version of Knowing Me, Knowing You ; SineĢad O'Connor and Boyzone's Stephen Gately have recorded Chiquitita ; Tanita Tikaram, Blancmange and Steven Wilson paid tribute to The Day Before You Came. 
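For the token-classification fixtures, the appended input makes each row self-contained: the predicted tag sequence and the sentence it was produced for now sit side by side. A short alignment sketch follows; whitespace tokenization is an assumption here (the real pipeline's tokenizer may split differently), which is why the length test is explicit rather than assumed:

def pair_tags_with_tokens(tag_field, input_text):
    tags = tag_field.split()
    tokens = input_text.split()
    if len(tags) != len(tokens):
        # the tag sequence can be longer than the whitespace tokens
        # (subword tokenization, padding), so surface the mismatch
        return None
    return list(zip(tokens, tags))

Rows like conll2003/1 above, where a five-token input ("LONDON 1996 - 08 - 30") sits next to a much longer tag sequence, show why the mismatch branch matters.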
diff --git a/tests/expected_results/token_classification/cross_ner_news_1.tsv b/tests/expected_results/token_classification/cross_ner_news_1.tsv index 7ca866a8..728ff118 100644 --- a/tests/expected_results/token_classification/cross_ner_news_1.tsv +++ b/tests/expected_results/token_classification/cross_ner_news_1.tsv @@ -1,5 +1,5 @@ -cross_ner/0 I-chemicalelement B-university B-university I-task I-programlang B-university I-academicjournal B-university B-scientist I-academicjournal B-album B-university B-university I-academicjournal B-discipline B-metrics B-university I-politicalparty I-person B-country I-task B-award B-country I-field B-scientist I-academicjournal I-band B-country B-scientist I-programlang I-band B-LOC B-metrics I-programlang B-scientist I-field I-field I-field B-album B-album I-song B-album O O O O O B-organisation B-organisation B-organisation B-organisation B-organisation B-organisation B-organisation B-organisation B-organisation O O O O O O O O O O O O O O O O O O O O O O O O O O O O -cross_ner/1 I-programlang I-literarygenre B-protein I-programlang B-university I-task I-literarygenre I-task I-task I-programlang B-university I-field I-field I-programlang B-university B-university B-protein B-politician I-task B-astronomicalobject I-task I-song I-song I-field B-LOC I-programlang B-university B-university B-protein B-university B-university B-protein I-programlang I-task I-task I-task B-astronomicalobject I-song I-field I-field I-field I-field B-location B-location B-location B-location O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O -cross_ner/2 I-location I-field I-song B-programlang I-song I-programlang I-literarygenre I-chemicalelement I-song B-task I-misc B-metrics B-protein I-programlang I-programlang B-university B-university B-scientist I-musicalinstrument B-enzyme I-programlang I-location I-event I-field B-university B-protein I-programlang B-protein B-university I-field B-university B-university B-protein B-university B-protein I-song I-song B-scientist I-song I-chemicalelement I-chemicalelement I-musicalinstrument B-misc I-misc O O O O B-person I-person O O O O O O O B-organisation O B-organisation O O O O O O O O O O O O O O O O O O O O O O O O -cross_ner/3 I-chemicalelement I-song I-enzyme B-protein B-university I-literarygenre B-protein I-academicjournal I-discipline I-discipline B-university B-university I-enzyme B-university I-chemicalelement B-university B-university B-university I-university B-university I-astronomicalobject I-academicjournal B-university B-protein B-university B-university B-university B-protein B-university I-enzyme B-university B-protein B-university B-protein B-metrics B-protein B-university I-field I-university B-protein B-protein I-chemicalelement O O O O O O O O O O O O O O O B-organisation O B-organisation O B-organisation O O O O O O B-organisation O O O O O O O O O O O O B-organisation O O -cross_ner/4 I-chemicalelement B-protein B-scientist B-protein B-scientist I-astronomicalobject B-university B-university B-metrics I-programlang B-metrics B-literarygenre I-programlang B-university B-university I-university I-university I-book I-programlang I-musicalinstrument B-scientist I-field I-programlang B-protein B-metrics B-protein B-protein I-discipline I-person B-university I-song B-metrics I-field I-literarygenre B-metrics B-metrics I-song B-university B-musicalinstrument I-misc I-astronomicalobject B-scientist O O B-organisation O O O O O O O O B-location I-location O B-organisation O O O O O O O O 
O O O O O O O B-location O O B-person I-person I-person I-person O O O O O +cross_ner/0 I-chemicalelement B-university B-university I-task I-programlang B-university I-academicjournal B-university B-scientist I-academicjournal B-album B-university B-university I-academicjournal B-discipline B-metrics B-university I-politicalparty I-person B-country I-task B-award B-country I-field B-scientist I-academicjournal I-band B-country B-scientist I-programlang I-band B-LOC B-metrics I-programlang B-scientist I-field I-field I-field B-album B-album I-song B-album O O O O O B-organisation B-organisation B-organisation B-organisation B-organisation B-organisation B-organisation B-organisation B-organisation O O O O O O O O O O O O O O O O O O O O O O O O O O O O CRICKET - LEICESTERSHIRE TAKE OVER AT TOP AFTER INNINGS VICTORY. +cross_ner/1 I-programlang I-literarygenre B-protein I-programlang B-university I-task I-literarygenre I-task I-task I-programlang B-university I-field I-field I-programlang B-university B-university B-protein B-politician I-task B-astronomicalobject I-task I-song I-song I-field B-LOC I-programlang B-university B-university B-protein B-university B-university B-protein I-programlang I-task I-task I-task B-astronomicalobject I-song I-field I-field I-field I-field B-location B-location B-location B-location O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O LONDON 1996 - 08 - 30 +cross_ner/2 I-location I-field I-song B-programlang I-song I-programlang I-literarygenre I-chemicalelement I-song B-task I-misc B-metrics B-protein I-programlang I-programlang B-university B-university B-scientist I-musicalinstrument B-enzyme I-programlang I-location I-event I-field B-university B-protein I-programlang B-protein B-university I-field B-university B-university B-protein B-university B-protein I-song I-song B-scientist I-song I-chemicalelement I-chemicalelement I-musicalinstrument B-misc I-misc O O O O B-person I-person O O O O O O O B-organisation O B-organisation O O O O O O O O O O O O O O O O O O O O O O O O West Indian all - rounder Phil Simmons took four for 38 on Friday as Leicestershire beat Somerset by an innings and 39 runs in two days to take over at the head of the county championship. +cross_ner/3 I-chemicalelement I-song I-enzyme B-protein B-university I-literarygenre B-protein I-academicjournal I-discipline I-discipline B-university B-university I-enzyme B-university I-chemicalelement B-university B-university B-university I-university B-university I-astronomicalobject I-academicjournal B-university B-protein B-university B-university B-university B-protein B-university I-enzyme B-university B-protein B-university B-protein B-metrics B-protein B-university I-field I-university B-protein B-protein I-chemicalelement O O O O O O O O O O O O O O O B-organisation O B-organisation O B-organisation O O O O O O B-organisation O O O O O O O O O O O O B-organisation O O Their stay on top, though, may be short - lived as title rivals Essex, Derbyshire and Surrey all closed in on victory while Kent made up for lost time in their rain - affected match against Nottinghamshire. 
+cross_ner/4 I-chemicalelement B-protein B-scientist B-protein B-scientist I-astronomicalobject B-university B-university B-metrics I-programlang B-metrics B-literarygenre I-programlang B-university B-university I-university I-university I-book I-programlang I-musicalinstrument B-scientist I-field I-programlang B-protein B-metrics B-protein B-protein I-discipline I-person B-university I-song B-metrics I-field I-literarygenre B-metrics B-metrics I-song B-university B-musicalinstrument I-misc I-astronomicalobject B-scientist O O B-organisation O O O O O O O O B-location I-location O B-organisation O O O O O O O O O O O O O O O B-location O O B-person I-person I-person I-person O O O O O After bowling Somerset out for 83 on the opening morning at Grace Road, Leicestershire extended their first innings by 94 runs before being bowled out for 296 with England discard Andy Caddick taking three for 83. From e7db2543270e343d0f248335e21b08fd24ec845a Mon Sep 17 00:00:00 2001 From: mehrad Date: Fri, 25 Feb 2022 14:12:23 -0800 Subject: [PATCH 102/102] predict: remove code for file evaluation --- genienlp/predict.py | 108 ++++++++++++-------------------------------- 1 file changed, 30 insertions(+), 78 deletions(-) diff --git a/genienlp/predict.py b/genienlp/predict.py index 2471ca4f..c92a8213 100644 --- a/genienlp/predict.py +++ b/genienlp/predict.py @@ -43,7 +43,6 @@ except RuntimeError: pass -import sys import torch @@ -63,17 +62,17 @@ set_seed, split_folder_on_disk, ) -from .validate import GenerationOutput, generate_with_model +from .validate import generate_with_model logger = logging.getLogger(__name__) def parse_argv(parser): - parser.add_argument('--path', type=str, required='--pred_file' not in sys.argv, help='Folder to load the model from') + parser.add_argument('--path', type=str, required=True, help='Folder to load the model from') parser.add_argument( '--evaluate', type=str, - required='--pred_file' not in sys.argv, + required=True, choices=['train', 'valid', 'test'], help='Which dataset to do predictions for (train, dev or test)', ) @@ -106,12 +105,6 @@ def parse_argv(parser): parser.add_argument('--cache', default='.cache', type=str, help='where to save cached files') parser.add_argument('--subsample', default=20000000, type=int, help='subsample the eval/test datasets') - parser.add_argument( - '--pred_file', - type=str, - help='If provided, we just compute evaluation metrics on this file and bypass model prediction. 
File should be in tsv format with id, pred, answer columns', - ) - parser.add_argument( '--pred_languages', type=str, @@ -564,20 +557,39 @@ def run(args, device): prediction_file.write('\n'.join(lines) + '\n') if len(generation_output.answers) > 0: - compute_metrics_on_file( - task_scores, - prediction_file_name, - results_file_name, - task, + metrics_to_compute = task.metrics + metrics_to_compute += args.extra_metrics + metrics_to_compute = [metric for metric in task.metrics if metric not in ['loss']] + if args.main_metric_only: + metrics_to_compute = [metrics_to_compute[0]] + metrics = calculate_and_reduce_metrics( + generation_output, + metrics_to_compute, args, tgt_lang, - confidence_scores=generation_output.confidence_scores, ) - log_final_results(args, task_scores) + with open(results_file_name, 'w' + ('' if args.overwrite else '+')) as results_file: + results_file.write(json.dumps(metrics) + '\n') + + if not args.silent: + for i, (c, p, a) in enumerate( + zip(generation_output.contexts, generation_output.predictions, generation_output.answers) + ): + log_string = ( + f'\nContext {i + 1}: {c}\nPrediction {i + 1} ({len(p)} outputs): {p}\nAnswer {i + 1}: {a}\n' + ) + if args.calibrator_paths is not None: + log_string += f'Confidence {i + 1} : ' + for score in generation_output.confidence_scores: + log_string += f'{score[i]:.3f}, ' + log_string += '\n' + logger.info(log_string) + logger.info(metrics) + + task_scores[task].append((len(generation_output.answers), metrics[task.metrics[0]])) -def log_final_results(args, task_scores): decaScore = [] for task in task_scores.keys(): decaScore.append( @@ -592,55 +604,6 @@ def log_final_results(args, task_scores): logger.info(f'\nSummary: | {sum(decaScore)} | {" | ".join([str(x) for x in decaScore])} |\n') -def compute_metrics_on_file(task_scores, pred_file, results_file_name, task, args, tgt_lang, confidence_scores=None): - generation_output = GenerationOutput() - ids, contexts, preds, targets = [], [], [], [] - with open(pred_file) as fin: - for line in fin: - id_, *pred, target, context = line.strip('\n').split('\t') - ids.append(id_) - contexts.append(context) - preds.append(pred) - targets.append(target) - - generation_output.example_ids = ids - generation_output.contexts = contexts - generation_output.predictions = preds - generation_output.answers = targets - generation_output.confidence_scores = confidence_scores - - metrics_to_compute = task.metrics - metrics_to_compute += args.extra_metrics - metrics_to_compute = [metric for metric in task.metrics if metric not in ['loss']] - if args.main_metric_only: - metrics_to_compute = [metrics_to_compute[0]] - metrics = calculate_and_reduce_metrics( - generation_output, - metrics_to_compute, - args, - tgt_lang, - ) - - with open(results_file_name, 'w' + ('' if args.overwrite else '+')) as results_file: - results_file.write(json.dumps(metrics) + '\n') - - if not args.silent: - for i, (c, p, a) in enumerate( - zip(generation_output.contexts, generation_output.predictions, generation_output.answers) - ): - log_string = f'\nContext {i + 1}: {c}\nPrediction {i + 1} ({len(p)} outputs): {p}\nAnswer {i + 1}: {a}\n' - if args.calibrator_paths is not None: - log_string += f'Confidence {i + 1} : ' - for score in generation_output.confidence_scores: - log_string += f'{score[i]:.3f}, ' - log_string += '\n' - logger.info(log_string) - - logger.info(metrics) - - task_scores[task].append((len(generation_output.answers), metrics[task.metrics[0]])) - - def main(args): load_config_json(args) 
check_and_update_generation_args(args) @@ -667,17 +630,6 @@ def main(args): task.metrics = new_metrics - if args.pred_file and os.path.exists(args.pred_file): - task_scores = defaultdict(list) - eval_dir = os.path.join(args.eval_dir, args.evaluate) - os.makedirs(eval_dir, exist_ok=True) - tgt_lang = args.pred_tgt_languages[0] - for task in args.tasks: - results_file_name = os.path.join(eval_dir, task.name + '.results.json') - compute_metrics_on_file(task_scores, args.pred_file, results_file_name, task, args, tgt_lang) - log_final_results(args, task_scores) - return - logger.info(f'Loading from {args.best_checkpoint}') devices = get_devices(args.devices)
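One detail of the metric selection that this patch inlines into run(): `metrics_to_compute = task.metrics` binds the task's own list, so the following `+=` also extends task.metrics in place, and the final comprehension over task.metrics sees the extra metrics only through that side effect. A side-effect-free sketch with the same result; this is not the repository's code, and the function and parameter names are illustrative:

def select_metrics(task_metrics, extra_metrics, main_metric_only=False):
    # build a fresh list instead of aliasing the task's own metrics list
    metrics_to_compute = [m for m in task_metrics + extra_metrics if m != 'loss']
    if main_metric_only:
        metrics_to_compute = metrics_to_compute[:1]
    return metrics_to_compute

# e.g. select_metrics(['em', 'loss', 'bleu'], ['jga']) -> ['em', 'bleu', 'jga']

Concatenating into a new list leaves task.metrics untouched for later iterations of the task loop while still dropping 'loss' and honoring extra_metrics and main_metric_only.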