Fix target_train_data_fraction overriding pretrain_data_fraction (#1070)
* add private instance generator field to Task w/ getter and setter

* update build_tasks to use new instance generator Task field setter

* update trainer to get phase-appropriate instance generators

* update evaluate to use new Task instance generator getter

* update pred writing task to use new instance generator getter/setter

* add tests for new Task instance generator functionality
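Taken together, the bullets above mean each split's instance iterable now lives in a private dict on Task keyed by (split_name, phase), so the pretrain and target-train fractions of the same train split no longer overwrite each other. A minimal usage sketch of the new accessors, mirroring the new tests in the diff below (the toy lists stand in for real instance generators):

    from jiant.tasks import Task

    task = Task(name="dummy_name", tokenizer_name="dummy_tokenizer_name")
    # Register phase-specific train iterables; toy lists stand in for lazy instance generators.
    task.set_instance_iterable(split_name="train", instance_iterable=[1, 2], phase="pretrain")
    task.set_instance_iterable(split_name="train", instance_iterable=[3], phase="target_train")

    # Each phase retrieves its own iterable; neither call clobbers the other.
    assert task.get_instance_iterable(split_name="train", phase="pretrain") == [1, 2]
    assert task.get_instance_iterable(split_name="train", phase="target_train") == [3]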
pyeres authored Apr 23, 2020
1 parent bc786ad commit 14fae87
Showing 6 changed files with 234 additions and 147 deletions.
jiant/evaluate.py: 3 changes (1 addition, 2 deletions)

@@ -95,8 +95,7 @@ def evaluate(
         n_task_examples = 0
         task_preds = []  # accumulate DataFrames
         assert split in ["train", "val", "test"]
-        dataset = getattr(task, "%s_data" % split)
-        generator = iterator(dataset, num_epochs=1, shuffle=False)
+        generator = iterator(task.get_instance_iterable(split), num_epochs=1, shuffle=False)
         for batch_idx, batch in enumerate(generator):
             with torch.no_grad():
                 if isinstance(cuda_device, int):
jiant/preprocess.py: 26 changes (20 additions, 6 deletions)

@@ -416,22 +416,36 @@ def build_tasks(
     target_tasks = []
     for task in tasks:
         # Replace lists of instances with lazy generators from disk.
-        task.val_data = _get_instance_generator(task.name, "val", preproc_dir)
-        task.test_data = _get_instance_generator(task.name, "test", preproc_dir)
+        task.set_instance_iterable(
+            split_name="val",
+            instance_iterable=_get_instance_generator(task.name, "val", preproc_dir),
+        )
+        task.set_instance_iterable(
+            split_name="test",
+            instance_iterable=_get_instance_generator(task.name, "test", preproc_dir),
+        )
         # When using pretrain_data_fraction, we need modified iterators for use
         # only on training datasets at pretraining time.
         if task.name in pretrain_task_names:
             log.info("\tCreating trimmed pretraining-only version of " + task.name + " train.")
-            task.train_data = _get_instance_generator(
-                task.name, "train", preproc_dir, fraction=args.pretrain_data_fraction
+            task.set_instance_iterable(
+                split_name="train",
+                instance_iterable=_get_instance_generator(
+                    task.name, "train", preproc_dir, fraction=args.pretrain_data_fraction
+                ),
+                phase="pretrain",
             )
             pretrain_tasks.append(task)
         # When using target_train_data_fraction, we need modified iterators
         # only for training datasets at do_target_task_training time.
         if task.name in target_task_names:
             log.info("\tCreating trimmed target-only version of " + task.name + " train.")
-            task.train_data = _get_instance_generator(
-                task.name, "train", preproc_dir, fraction=args.target_train_data_fraction
+            task.set_instance_iterable(
+                split_name="train",
+                instance_iterable=_get_instance_generator(
+                    task.name, "train", preproc_dir, fraction=args.target_train_data_fraction
+                ),
+                phase="target_train",
            )
            target_tasks.append(task)
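For context on the bug being fixed: previously both branches above assigned to the single task.train_data attribute, so for a task listed in both pretrain_task_names and target_task_names the second assignment won and the pretrain_data_fraction generator was silently replaced by the target_train_data_fraction one. Roughly, reconstructed from the removed lines above:

    # Old behavior (sketch): one shared attribute, so the later assignment replaced
    # the earlier one and pretrain_data_fraction was effectively ignored.
    task.train_data = _get_instance_generator(
        task.name, "train", preproc_dir, fraction=args.pretrain_data_fraction
    )
    # ... later, in the target-train branch:
    task.train_data = _get_instance_generator(
        task.name, "train", preproc_dir, fraction=args.target_train_data_fraction
    )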
jiant/tasks/tasks.py: 39 changes (38 additions, 1 deletion)

@@ -3,7 +3,7 @@
 import json
 import logging as log
 import os
-from typing import Any, Dict, Iterable, List, Sequence, Type
+from typing import Any, Dict, Iterable, List, Sequence, Type, Union, Generator
 
 import numpy as np
 import pandas as pd

@@ -35,6 +35,7 @@
     tokenize_and_truncate,
     load_pair_nli_jsonl,
 )
+from jiant.utils.serialize import RepeatableIterator
 from jiant.utils.tokenizers import get_tokenizer
 from jiant.utils.retokenize import get_aligner_fn
 from jiant.tasks.registry import register_task  # global task registry

@@ -228,6 +229,7 @@ def __init__(self, name, tokenizer_name):
         self.sentences = None
         self.example_counts = None
         self.contributes_to_aggregate_score = True
+        self._instance_iterables = {}
 
     def load_data(self):
         """ Load data from path and create splits. """

@@ -293,6 +295,41 @@ def handle_preds(self, preds, batch):
         """
         return preds
 
+    def set_instance_iterable(
+        self, split_name: str, instance_iterable: Iterable, phase: str = None
+    ):
+        """Takes a data instance iterable and stores it in a private field of this Task instance
+
+        Parameters
+        ----------
+        split_name : string
+        instance_iterable : Iterable
+        phase : str
+        """
+        self._instance_iterables[(split_name, phase)] = instance_iterable
+
+    def get_instance_iterable(
+        self, split_name: str, phase: str = None
+    ) -> Union[RepeatableIterator, Generator]:
+        """Returns an instance iterable for the specified split name and phase.
+
+        Parameters
+        ----------
+        split_name : string
+        phase : string
+
+        Returns
+        -------
+        Union[RepeatableIterator, Generator]
+        """
+        if not self._instance_iterables:
+            raise ValueError("set_instance_iterable must be called before get_instance_iterable")
+        if split_name == "train" and phase is None:
+            raise ValueError("phase must be specified to get relevant training data")
+        return self._instance_iterables[(split_name, phase)]
+
+
 class ClassificationTask(Task):
     """ General classification task """
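Putting the two accessors together: the backing store is a plain dict keyed by (split_name, phase), with phase defaulting to None for val and test. After the build_tasks calls above, a task that participates in both phases would hold entries shaped like this (the generator names are placeholders for the lazy generators built in preprocess.py):

    # Sketch of task._instance_iterables for a task used in both phases.
    {
        ("val", None): val_generator,
        ("test", None): test_generator,
        ("train", "pretrain"): pretrain_train_generator,
        ("train", "target_train"): target_train_generator,
    }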
jiant/trainer.py: 14 changes (10 additions, 4 deletions)

@@ -321,8 +321,12 @@ def _setup_training(
             ):
                 os.mkdir(os.path.join(self._serialization_dir, task.name))
 
-            # Adding task-specific smart iterator to speed up training
-            instance = [i for i in itertools.islice(task.train_data, 1)][0]
+            instance = [
+                i
+                for i in itertools.islice(
+                    task.get_instance_iterable(split_name="train", phase=phase), 1
+                )
+            ][0]
             pad_dict = instance.get_padding_lengths()
             sorting_keys = []
             for field in pad_dict:

@@ -335,7 +339,9 @@
                 biggest_batch_first=True,
             )
             task_info["iterator"] = iterator
-            task_info["tr_generator"] = iterator(task.train_data, num_epochs=None)
+            task_info["tr_generator"] = iterator(
+                task.get_instance_iterable(split_name="train", phase=phase), num_epochs=None
+            )
 
             n_training_examples = task.n_train_examples
             # Warning: This won't be precise when training_data_fraction is set, since each

@@ -833,7 +839,7 @@ def _calculate_validation_performance(
         else:
             max_data_points = task.n_val_examples
         val_generator = BasicIterator(batch_size, instances_per_epoch=max_data_points)(
-            task.val_data, num_epochs=1, shuffle=False
+            task.get_instance_iterable(split_name="val"), num_epochs=1, shuffle=False
         )
         n_val_batches = math.ceil(max_data_points / batch_size)
         all_val_metrics["%s_loss" % task.name] = 0.0
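Note how the two kinds of call sites differ: the training lookups pass the current phase explicitly, while the validation lookup omits it and relies on the phase=None default, which matches the ("val", None) entry set in preprocess.py. A quick sketch of the lookup rules implied by the Task methods above:

    task.get_instance_iterable(split_name="val")                       # hits the ("val", None) entry
    task.get_instance_iterable(split_name="train", phase="pretrain")   # hits ("train", "pretrain")
    task.get_instance_iterable(split_name="train")                     # raises ValueError: train needs a phase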
tests/tasks/test_tasks.py: 58 changes (43 additions, 15 deletions)

@@ -1,21 +1,49 @@
 import logging
+import unittest
 
+from jiant.tasks import Task
 from jiant.tasks.registry import REGISTRY
 
 
-def test_instantiate_all_tasks():
-    """
-    All tasks should be able to be instantiated without needing to access actual data
-    Test may change if task instantiation signature changes.
-    """
-    logger = logging.getLogger()
-    logger.setLevel(level=logging.CRITICAL)
-    for name, (cls, _, kw) in REGISTRY.items():
-        cls(
-            "dummy_path",
-            max_seq_len=1,
-            name="dummy_name",
-            tokenizer_name="dummy_tokenizer_name",
-            **kw,
-        )
+class TestTasks(unittest.TestCase):
+    def test_instantiate_all_tasks(self):
+        """
+        All tasks should be able to be instantiated without needing to access actual data
+        Test may change if task instantiation signature changes.
+        """
+        logger = logging.getLogger()
+        logger.setLevel(level=logging.CRITICAL)
+        for name, (cls, _, kw) in REGISTRY.items():
+            cls(
+                "dummy_path",
+                max_seq_len=1,
+                name="dummy_name",
+                tokenizer_name="dummy_tokenizer_name",
+                **kw,
+            )
+
+    def test_tasks_get_train_instance_iterable_without_phase(self):
+        task = Task(name="dummy_name", tokenizer_name="dummy_tokenizer_name")
+        train_iterable = [1, 2, 3]
+        task.set_instance_iterable("train", train_iterable, "target_train")
+        self.assertRaises(ValueError, task.get_instance_iterable, "train")
+
+    def test_tasks_set_and_get_instance_iterables(self):
+        task = Task(name="dummy_name", tokenizer_name="dummy_tokenizer_name")
+        val_iterable = [1, 2, 3]
+        test_iterable = [4, 5, 6]
+        train_pretrain_iterable = [7, 8]
+        train_target_train_iterable = [9]
+        task.set_instance_iterable("val", val_iterable)
+        task.set_instance_iterable("test", test_iterable)
+        task.set_instance_iterable("train", train_pretrain_iterable, "pretrain")
+        task.set_instance_iterable("train", train_target_train_iterable, "target_train")
+        retrieved_val_iterable = task.get_instance_iterable("val")
+        retrieved_test_iterable = task.get_instance_iterable("test")
+        retrieved_train_pretrain_iterable = task.get_instance_iterable("train", "pretrain")
+        retrieved_train_target_iterable = task.get_instance_iterable("train", "target_train")
+        self.assertListEqual(val_iterable, retrieved_val_iterable)
+        self.assertListEqual(test_iterable, retrieved_test_iterable)
+        self.assertListEqual(train_pretrain_iterable, retrieved_train_pretrain_iterable)
+        self.assertListEqual(train_target_train_iterable, retrieved_train_target_iterable)
(The diff for the sixth changed file did not load and is not shown here.)
