diff --git a/jiant/evaluate.py b/jiant/evaluate.py index 7932b54d0..03b927231 100644 --- a/jiant/evaluate.py +++ b/jiant/evaluate.py @@ -95,8 +95,7 @@ def evaluate( n_task_examples = 0 task_preds = [] # accumulate DataFrames assert split in ["train", "val", "test"] - dataset = getattr(task, "%s_data" % split) - generator = iterator(dataset, num_epochs=1, shuffle=False) + generator = iterator(task.get_instance_iterable(split), num_epochs=1, shuffle=False) for batch_idx, batch in enumerate(generator): with torch.no_grad(): if isinstance(cuda_device, int): diff --git a/jiant/preprocess.py b/jiant/preprocess.py index 768ccfa5b..4e08dd00d 100644 --- a/jiant/preprocess.py +++ b/jiant/preprocess.py @@ -416,22 +416,36 @@ def build_tasks( target_tasks = [] for task in tasks: # Replace lists of instances with lazy generators from disk. - task.val_data = _get_instance_generator(task.name, "val", preproc_dir) - task.test_data = _get_instance_generator(task.name, "test", preproc_dir) + task.set_instance_iterable( + split_name="val", + instance_iterable=_get_instance_generator(task.name, "val", preproc_dir), + ) + task.set_instance_iterable( + split_name="test", + instance_iterable=_get_instance_generator(task.name, "test", preproc_dir), + ) # When using pretrain_data_fraction, we need modified iterators for use # only on training datasets at pretraining time. if task.name in pretrain_task_names: log.info("\tCreating trimmed pretraining-only version of " + task.name + " train.") - task.train_data = _get_instance_generator( - task.name, "train", preproc_dir, fraction=args.pretrain_data_fraction + task.set_instance_iterable( + split_name="train", + instance_iterable=_get_instance_generator( + task.name, "train", preproc_dir, fraction=args.pretrain_data_fraction + ), + phase="pretrain", ) pretrain_tasks.append(task) # When using target_train_data_fraction, we need modified iterators # only for training datasets at do_target_task_training time. if task.name in target_task_names: log.info("\tCreating trimmed target-only version of " + task.name + " train.") - task.train_data = _get_instance_generator( - task.name, "train", preproc_dir, fraction=args.target_train_data_fraction + task.set_instance_iterable( + split_name="train", + instance_iterable=_get_instance_generator( + task.name, "train", preproc_dir, fraction=args.target_train_data_fraction + ), + phase="target_train", ) target_tasks.append(task) diff --git a/jiant/tasks/tasks.py b/jiant/tasks/tasks.py index 0ea652412..c4203974a 100644 --- a/jiant/tasks/tasks.py +++ b/jiant/tasks/tasks.py @@ -3,7 +3,7 @@ import json import logging as log import os -from typing import Any, Dict, Iterable, List, Sequence, Type +from typing import Any, Dict, Iterable, List, Sequence, Type, Union, Generator import numpy as np import pandas as pd @@ -35,6 +35,7 @@ tokenize_and_truncate, load_pair_nli_jsonl, ) +from jiant.utils.serialize import RepeatableIterator from jiant.utils.tokenizers import get_tokenizer from jiant.utils.retokenize import get_aligner_fn from jiant.tasks.registry import register_task # global task registry @@ -228,6 +229,7 @@ def __init__(self, name, tokenizer_name): self.sentences = None self.example_counts = None self.contributes_to_aggregate_score = True + self._instance_iterables = {} def load_data(self): """ Load data from path and create splits. """ @@ -293,6 +295,41 @@ def handle_preds(self, preds, batch): """ return preds + def set_instance_iterable( + self, split_name: str, instance_iterable: Iterable, phase: str = None + ): + """Takes a data instance iterable and stores it in a private field of this Task instance + + Parameters + ---------- + split_name : string + instance_iterable : Iterable + phase : str + + """ + self._instance_iterables[(split_name, phase)] = instance_iterable + + def get_instance_iterable( + self, split_name: str, phase: str = None + ) -> Union[RepeatableIterator, Generator]: + """Returns an instance iterable for the specified split name and phase. + + Parameters + ---------- + split_name : string + phase : string + + Returns + ------- + Union[RepeatableIterator, Generator] + + """ + if not self._instance_iterables: + raise ValueError("set_instance_iterable must be called before get_instance_iterable") + if split_name == "train" and phase is None: + raise ValueError("phase must be specified to get relevant training data") + return self._instance_iterables[(split_name, phase)] + class ClassificationTask(Task): """ General classification task """ diff --git a/jiant/trainer.py b/jiant/trainer.py index 6ec80b2c5..4717aeb6b 100644 --- a/jiant/trainer.py +++ b/jiant/trainer.py @@ -321,8 +321,12 @@ def _setup_training( ): os.mkdir(os.path.join(self._serialization_dir, task.name)) - # Adding task-specific smart iterator to speed up training - instance = [i for i in itertools.islice(task.train_data, 1)][0] + instance = [ + i + for i in itertools.islice( + task.get_instance_iterable(split_name="train", phase=phase), 1 + ) + ][0] pad_dict = instance.get_padding_lengths() sorting_keys = [] for field in pad_dict: @@ -335,7 +339,9 @@ def _setup_training( biggest_batch_first=True, ) task_info["iterator"] = iterator - task_info["tr_generator"] = iterator(task.train_data, num_epochs=None) + task_info["tr_generator"] = iterator( + task.get_instance_iterable(split_name="train", phase=phase), num_epochs=None + ) n_training_examples = task.n_train_examples # Warning: This won't be precise when training_data_fraction is set, since each @@ -833,7 +839,7 @@ def _calculate_validation_performance( else: max_data_points = task.n_val_examples val_generator = BasicIterator(batch_size, instances_per_epoch=max_data_points)( - task.val_data, num_epochs=1, shuffle=False + task.get_instance_iterable(split_name="val"), num_epochs=1, shuffle=False ) n_val_batches = math.ceil(max_data_points / batch_size) all_val_metrics["%s_loss" % task.name] = 0.0 diff --git a/tests/tasks/test_tasks.py b/tests/tasks/test_tasks.py index d73d35561..47a24dffd 100644 --- a/tests/tasks/test_tasks.py +++ b/tests/tasks/test_tasks.py @@ -1,21 +1,49 @@ import logging +import unittest +from jiant.tasks import Task from jiant.tasks.registry import REGISTRY -def test_instantiate_all_tasks(): - """ - All tasks should be able to be instantiated without needing to access actual data +class TestTasks(unittest.TestCase): + def test_instantiate_all_tasks(self): + """ + All tasks should be able to be instantiated without needing to access actual data - Test may change if task instantiation signature changes. - """ - logger = logging.getLogger() - logger.setLevel(level=logging.CRITICAL) - for name, (cls, _, kw) in REGISTRY.items(): - cls( - "dummy_path", - max_seq_len=1, - name="dummy_name", - tokenizer_name="dummy_tokenizer_name", - **kw, - ) + Test may change if task instantiation signature changes. + """ + logger = logging.getLogger() + logger.setLevel(level=logging.CRITICAL) + for name, (cls, _, kw) in REGISTRY.items(): + cls( + "dummy_path", + max_seq_len=1, + name="dummy_name", + tokenizer_name="dummy_tokenizer_name", + **kw, + ) + + def test_tasks_get_train_instance_iterable_without_phase(self): + task = Task(name="dummy_name", tokenizer_name="dummy_tokenizer_name") + train_iterable = [1, 2, 3] + task.set_instance_iterable("train", train_iterable, "target_train") + self.assertRaises(ValueError, task.get_instance_iterable, "train") + + def test_tasks_set_and_get_instance_iterables(self): + task = Task(name="dummy_name", tokenizer_name="dummy_tokenizer_name") + val_iterable = [1, 2, 3] + test_iterable = [4, 5, 6] + train_pretrain_iterable = [7, 8] + train_target_train_iterable = [9] + task.set_instance_iterable("val", val_iterable) + task.set_instance_iterable("test", test_iterable) + task.set_instance_iterable("train", train_pretrain_iterable, "pretrain") + task.set_instance_iterable("train", train_target_train_iterable, "target_train") + retreived_val_iterable = task.get_instance_iterable("val") + retreived_test_iterable = task.get_instance_iterable("test") + retreived_train_pretrain_iterable = task.get_instance_iterable("train", "pretrain") + retreived_train_target_iterable = task.get_instance_iterable("train", "target_train") + self.assertListEqual(val_iterable, retreived_val_iterable) + self.assertListEqual(test_iterable, retreived_test_iterable) + self.assertListEqual(train_pretrain_iterable, retreived_train_pretrain_iterable) + self.assertListEqual(train_target_train_iterable, retreived_train_target_iterable) diff --git a/tests/test_write_preds.py b/tests/test_write_preds.py index f61086a52..794349e24 100644 --- a/tests/test_write_preds.py +++ b/tests/test_write_preds.py @@ -90,127 +90,130 @@ def setUp(self): ] ) indexers = {"bert_cased": SingleIdTokenIndexer("bert-xe-cased")} - self.wic.val_data = [ - Instance( - { - "sent1_str": MetadataField("Room and board."), - "sent2_str": MetadataField("He nailed boards"), - "idx": LabelField(0, skip_indexing=True), - "idx2": NumericField(2), - "idx1": NumericField(3), - "inputs": self.sentence_to_text_field( - [ - "[CLS]", - "Room", - "and", - "Board", - ".", - "[SEP]", - "He", - "nailed", - "boards", - "[SEP]", - ], - indexers, - ), - "labels": LabelField(0, skip_indexing=1), - } - ), - Instance( - { - "sent1_str": MetadataField("C ##ir ##culate a rumor ."), - "sent2_str": MetadataField("This letter is being circulated"), - "idx": LabelField(1, skip_indexing=True), - "idx2": NumericField(2), - "idx1": NumericField(3), - "inputs": self.sentence_to_text_field( - [ - "[CLS]", - "C", - "##ir", - "##culate", - "a", - "rumor", - "[SEP]", - "This", - "##let", - "##ter", - "is", - "being", - "c", - "##ir", - "##culated", - "[SEP]", - ], - indexers, - ), - "labels": LabelField(0, skip_indexing=1), - } - ), - Instance( - { - "sent1_str": MetadataField("Hook a fish'"), - "sent2_str": MetadataField("He hooked a snake accidentally"), - "idx": LabelField(2, skip_indexing=True), - "idx2": NumericField(2), - "idx1": NumericField(3), - "inputs": self.sentence_to_text_field( - [ - "[CLS]", - "Hook", - "a", - "fish", - "[SEP]", - "He", - "hooked", - "a", - "snake", - "accidentally", - "[SEP]", - ], - indexers, - ), - "labels": LabelField(1, skip_indexing=1), - } - ), - Instance( - { - "sent1_str": MetadataField("For recreation he wrote poetry."), - "sent2_str": MetadataField("Drug abuse is often regarded as recreation ."), - "idx": LabelField(3, skip_indexing=True), - "idx2": NumericField(2), - "idx1": NumericField(3), - "inputs": self.sentence_to_text_field( - [ - "[CLS]", - "For", - "re", - "##creation", - "he", - "wrote", - "poetry", - "[SEP]", - "Drug", - "abuse", - "is", - "often", - "re", - "##garded", - "as", - "re", - "##creation", - "[SEP]", - ], - indexers, - ), - "labels": LabelField(1, skip_indexing=1), - } - ), - ] + self.wic.set_instance_iterable( + "val", + [ + Instance( + { + "sent1_str": MetadataField("Room and board."), + "sent2_str": MetadataField("He nailed boards"), + "idx": LabelField(0, skip_indexing=True), + "idx2": NumericField(2), + "idx1": NumericField(3), + "inputs": self.sentence_to_text_field( + [ + "[CLS]", + "Room", + "and", + "Board", + ".", + "[SEP]", + "He", + "nailed", + "boards", + "[SEP]", + ], + indexers, + ), + "labels": LabelField(0, skip_indexing=1), + } + ), + Instance( + { + "sent1_str": MetadataField("C ##ir ##culate a rumor ."), + "sent2_str": MetadataField("This letter is being circulated"), + "idx": LabelField(1, skip_indexing=True), + "idx2": NumericField(2), + "idx1": NumericField(3), + "inputs": self.sentence_to_text_field( + [ + "[CLS]", + "C", + "##ir", + "##culate", + "a", + "rumor", + "[SEP]", + "This", + "##let", + "##ter", + "is", + "being", + "c", + "##ir", + "##culated", + "[SEP]", + ], + indexers, + ), + "labels": LabelField(0, skip_indexing=1), + } + ), + Instance( + { + "sent1_str": MetadataField("Hook a fish'"), + "sent2_str": MetadataField("He hooked a snake accidentally"), + "idx": LabelField(2, skip_indexing=True), + "idx2": NumericField(2), + "idx1": NumericField(3), + "inputs": self.sentence_to_text_field( + [ + "[CLS]", + "Hook", + "a", + "fish", + "[SEP]", + "He", + "hooked", + "a", + "snake", + "accidentally", + "[SEP]", + ], + indexers, + ), + "labels": LabelField(1, skip_indexing=1), + } + ), + Instance( + { + "sent1_str": MetadataField("For recreation he wrote poetry."), + "sent2_str": MetadataField("Drug abuse is often regarded as recreation ."), + "idx": LabelField(3, skip_indexing=True), + "idx2": NumericField(2), + "idx1": NumericField(3), + "inputs": self.sentence_to_text_field( + [ + "[CLS]", + "For", + "re", + "##creation", + "he", + "wrote", + "poetry", + "[SEP]", + "Drug", + "abuse", + "is", + "often", + "re", + "##garded", + "as", + "re", + "##creation", + "[SEP]", + ], + indexers, + ), + "labels": LabelField(1, skip_indexing=1), + } + ), + ], + ) self.val_preds = {"sts-b": stsb_val_preds, "wic": wic_val_preds} - self.vocab = vocabulary.Vocabulary.from_instances(self.wic.val_data) + self.vocab = vocabulary.Vocabulary.from_instances(self.wic.get_instance_iterable("val")) self.vocab.add_token_to_namespace("True", "wic_tags") - for data in self.wic.val_data: + for data in self.wic.get_instance_iterable("val"): data.index_fields(self.vocab) self.glue_tasks = [self.stsb, self.wic] self.args = mock.Mock()