From d858734c091c2ee4e6447875c56a263116ec474c Mon Sep 17 00:00:00 2001 From: Phil Yeres Date: Tue, 25 Feb 2020 09:51:59 -0500 Subject: [PATCH] Minor refactor for task structure consistency and clean up in preprocessing (#1004) * change task data field *_data -> *_data_text for consistency --- jiant/preprocess.py | 5 +--- jiant/tasks/qa.py | 56 ++++++++++++++++++++++---------------------- jiant/tasks/tasks.py | 2 +- 3 files changed, 30 insertions(+), 33 deletions(-) diff --git a/jiant/preprocess.py b/jiant/preprocess.py index d8eb33cf1..27e3b6e18 100644 --- a/jiant/preprocess.py +++ b/jiant/preprocess.py @@ -405,10 +405,7 @@ def build_tasks( ) # Delete in-memory data - we'll lazy-load from disk later. - # TODO: delete task.{split}_data_text as well? - task.train_data = None - task.val_data = None - task.test_data = None + # TODO: delete task.{split}_data_text? log.info("\tFinished indexing tasks") diff --git a/jiant/tasks/qa.py b/jiant/tasks/qa.py index df9772b76..24da6d68f 100644 --- a/jiant/tasks/qa.py +++ b/jiant/tasks/qa.py @@ -400,9 +400,9 @@ def __init__(self, path, max_seq_len, name, **kw): self.path = path self.max_seq_len = max_seq_len - self.train_data = None - self.val_data = None - self.test_data = None + self.train_data_text = None + self.val_data_text = None + self.test_data_text = None self.f1_metric = F1SpanMetric() self.em_metric = ExactMatchSpanMetric() @@ -421,25 +421,25 @@ def get_metrics(self, reset: bool = False) -> Dict: return collected_metrics def load_data(self): - self.train_data = self._load_file(os.path.join(self.path, "orig", "train.jsonl.gz")) + self.train_data_text = self._load_file(os.path.join(self.path, "orig", "train.jsonl.gz")) # Shuffle val_data to ensure diversity in periodic validation with val_data_limit - self.val_data = self._load_file( + self.val_data_text = self._load_file( os.path.join(self.path, "orig", "dev.jsonl.gz"), shuffle=True ) - self.test_data = self._load_file(os.path.join(self.path, "orig", "test.jsonl.gz")) + self.test_data_text = self._load_file(os.path.join(self.path, "orig", "test.jsonl.gz")) self.sentences = ( - [example["passage"] for example in self.train_data] - + [example["question"] for example in self.train_data] - + [example["passage"] for example in self.val_data] - + [example["question"] for example in self.val_data] + [example["passage"] for example in self.train_data_text] + + [example["question"] for example in self.train_data_text] + + [example["passage"] for example in self.val_data_text] + + [example["question"] for example in self.val_data_text] ) self.example_counts = { - "train": len(self.train_data), - "val": len(self.val_data), - "test": len(self.test_data), + "train": len(self.train_data_text), + "val": len(self.val_data_text), + "test": len(self.test_data_text), } def get_sentences(self) -> Iterable[Sequence[str]]: @@ -582,9 +582,9 @@ def __init__(self, path, max_seq_len, name="qamr", **kw): super(QAMRTask, self).__init__(name, **kw) self.max_seq_len = max_seq_len - self.train_data = None - self.val_data = None - self.test_data = None + self.train_data_text = None + self.val_data_text = None + self.test_data_text = None self.f1_metric = F1SpanMetric() self.em_metric = ExactMatchSpanMetric() @@ -708,33 +708,33 @@ def load_wiki_dict(cls, path): def load_data(self): wiki_dict = self.load_wiki_dict(os.path.join(self.path, "qamr/data/wiki-sentences.tsv")) - self.train_data = self.process_dataset( + self.train_data_text = self.process_dataset( self.load_tsv_dataset( path=os.path.join(self.path, "qamr/data/filtered/train.tsv"), wiki_dict=wiki_dict ) ) - self.val_data = self.process_dataset( + self.val_data_text = self.process_dataset( self.load_tsv_dataset( path=os.path.join(self.path, "qamr/data/filtered/dev.tsv"), wiki_dict=wiki_dict ), shuffle=True, ) - self.test_data = self.process_dataset( + self.test_data_text = self.process_dataset( self.load_tsv_dataset( path=os.path.join(self.path, "qamr/data/filtered/test.tsv"), wiki_dict=wiki_dict ) ) self.sentences = ( - [example["passage"] for example in self.train_data] - + [example["question"] for example in self.train_data] - + [example["passage"] for example in self.val_data] - + [example["question"] for example in self.val_data] + [example["passage"] for example in self.train_data_text] + + [example["question"] for example in self.train_data_text] + + [example["passage"] for example in self.val_data_text] + + [example["question"] for example in self.val_data_text] ) self.example_counts = { - "train": len(self.train_data), - "val": len(self.val_data), - "test": len(self.test_data), + "train": len(self.train_data_text), + "val": len(self.val_data_text), + "test": len(self.test_data_text), } @staticmethod @@ -765,8 +765,8 @@ def collapse_contiguous_indices(ls): def remap_ptb_passage_and_answer_spans(ptb_tokens, answer_span, moses, tokenizer_name): # Start with PTB tokenized tokens - # The answer_span is also in ptb_token space. We first want to detokenize, and convert everything to - # space-tokenization space. + # The answer_span is also in ptb_token space. We first want to detokenize, and convert + # everything to space-tokenization space. # Detokenize the passage. Everything we do will be based on the detokenized input, # INCLUDING evaluation. diff --git a/jiant/tasks/tasks.py b/jiant/tasks/tasks.py index 6c677c73d..a21a6cb34 100644 --- a/jiant/tasks/tasks.py +++ b/jiant/tasks/tasks.py @@ -1798,7 +1798,7 @@ def __init__(self, path, max_seq_len, name, n_classes, **kw): self.train_data_text = None self.val_data_text = None - self.test_data = None + self.test_data_text = None self.acc_scorer = BooleanAccuracy() self.gender_parity_scorer = GenderParity() self.val_metric = "%s_accuracy" % name