Commit d858734

Minor refactor for task structure consistency and clean up in preprocessing (#1004)

* change task data field *_data -> *_data_text for consistency
pyeres authored Feb 25, 2020
1 parent 19c8b3a commit d858734
Showing 3 changed files with 30 additions and 33 deletions.
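
Why the rename matters: once every split field follows the *_data_text pattern, split data can be addressed generically, as the task.{split}_data_text TODO in preprocess.py below already assumes. A minimal sketch of that access pattern (hypothetical stand-in task, not part of this commit):

from types import SimpleNamespace

# Hypothetical stand-in task with the post-commit field names.
task = SimpleNamespace(
    train_data_text=[{"question": "..."}],
    val_data_text=[{"question": "..."}],
    test_data_text=[{"question": "..."}],
)

# Uniform *_data_text names allow generic access to split data
# instead of three hand-written attribute lookups.
for split in ("train", "val", "test"):
    data = getattr(task, f"{split}_data_text", None)
    if data is not None:
        print(f"{split}: {len(data)} examples")
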
5 changes: 1 addition & 4 deletions jiant/preprocess.py
@@ -405,10 +405,7 @@ def build_tasks(
         )

         # Delete in-memory data - we'll lazy-load from disk later.
-        # TODO: delete task.{split}_data_text as well?
-        task.train_data = None
-        task.val_data = None
-        task.test_data = None
+        # TODO: delete task.{split}_data_text?

     log.info("\tFinished indexing tasks")
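
The "lazy-load from disk later" comment refers to re-reading indexed examples from their serialized on-disk form instead of keeping them in memory after preprocessing. A minimal sketch of that idea, assuming a simple pickle-per-split layout (hypothetical helper; jiant's actual on-disk format may differ):

import os
import pickle

def lazy_load_split(index_dir, task_name, split):
    # Hypothetical: instances were written to disk during preprocessing;
    # read a split back only when it is actually needed.
    path = os.path.join(index_dir, f"{task_name}.{split}.pkl")
    with open(path, "rb") as f:
        return pickle.load(f)
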
56 changes: 28 additions & 28 deletions jiant/tasks/qa.py
@@ -400,9 +400,9 @@ def __init__(self, path, max_seq_len, name, **kw):
         self.path = path
         self.max_seq_len = max_seq_len

-        self.train_data = None
-        self.val_data = None
-        self.test_data = None
+        self.train_data_text = None
+        self.val_data_text = None
+        self.test_data_text = None

         self.f1_metric = F1SpanMetric()
         self.em_metric = ExactMatchSpanMetric()
@@ -421,25 +421,25 @@ def get_metrics(self, reset: bool = False) -> Dict:
         return collected_metrics

     def load_data(self):
-        self.train_data = self._load_file(os.path.join(self.path, "orig", "train.jsonl.gz"))
+        self.train_data_text = self._load_file(os.path.join(self.path, "orig", "train.jsonl.gz"))

         # Shuffle val_data to ensure diversity in periodic validation with val_data_limit
-        self.val_data = self._load_file(
+        self.val_data_text = self._load_file(
             os.path.join(self.path, "orig", "dev.jsonl.gz"), shuffle=True
         )

-        self.test_data = self._load_file(os.path.join(self.path, "orig", "test.jsonl.gz"))
+        self.test_data_text = self._load_file(os.path.join(self.path, "orig", "test.jsonl.gz"))

         self.sentences = (
-            [example["passage"] for example in self.train_data]
-            + [example["question"] for example in self.train_data]
-            + [example["passage"] for example in self.val_data]
-            + [example["question"] for example in self.val_data]
+            [example["passage"] for example in self.train_data_text]
+            + [example["question"] for example in self.train_data_text]
+            + [example["passage"] for example in self.val_data_text]
+            + [example["question"] for example in self.val_data_text]
         )
         self.example_counts = {
-            "train": len(self.train_data),
-            "val": len(self.val_data),
-            "test": len(self.test_data),
+            "train": len(self.train_data_text),
+            "val": len(self.val_data_text),
+            "test": len(self.test_data_text),
         }

     def get_sentences(self) -> Iterable[Sequence[str]]:
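
The shuffle=True on the dev split pairs with val_data_limit: periodic validation consumes only the first N examples, so pre-shuffling keeps that slice representative. A standalone sketch of what _load_file plausibly does, assuming one JSON object per line in a gzipped file (the real method may differ):

import gzip
import json
import random

def _load_file(path, shuffle=False):
    # Sketch: parse one JSON example per line from a gzipped file.
    with gzip.open(path, mode="rt", encoding="utf-8") as f:
        examples = [json.loads(line) for line in f]
    if shuffle:
        # Pre-shuffle so truncating to val_data_limit examples later
        # still yields a diverse validation sample.
        random.shuffle(examples)
    return examples
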
@@ -582,9 +582,9 @@ def __init__(self, path, max_seq_len, name="qamr", **kw):
         super(QAMRTask, self).__init__(name, **kw)
         self.max_seq_len = max_seq_len

-        self.train_data = None
-        self.val_data = None
-        self.test_data = None
+        self.train_data_text = None
+        self.val_data_text = None
+        self.test_data_text = None

         self.f1_metric = F1SpanMetric()
         self.em_metric = ExactMatchSpanMetric()
@@ -708,33 +708,33 @@ def load_wiki_dict(cls, path):

     def load_data(self):
         wiki_dict = self.load_wiki_dict(os.path.join(self.path, "qamr/data/wiki-sentences.tsv"))
-        self.train_data = self.process_dataset(
+        self.train_data_text = self.process_dataset(
             self.load_tsv_dataset(
                 path=os.path.join(self.path, "qamr/data/filtered/train.tsv"), wiki_dict=wiki_dict
             )
         )
-        self.val_data = self.process_dataset(
+        self.val_data_text = self.process_dataset(
             self.load_tsv_dataset(
                 path=os.path.join(self.path, "qamr/data/filtered/dev.tsv"), wiki_dict=wiki_dict
             ),
             shuffle=True,
         )
-        self.test_data = self.process_dataset(
+        self.test_data_text = self.process_dataset(
             self.load_tsv_dataset(
                 path=os.path.join(self.path, "qamr/data/filtered/test.tsv"), wiki_dict=wiki_dict
             )
         )

         self.sentences = (
-            [example["passage"] for example in self.train_data]
-            + [example["question"] for example in self.train_data]
-            + [example["passage"] for example in self.val_data]
-            + [example["question"] for example in self.val_data]
+            [example["passage"] for example in self.train_data_text]
+            + [example["question"] for example in self.train_data_text]
+            + [example["passage"] for example in self.val_data_text]
+            + [example["question"] for example in self.val_data_text]
         )
         self.example_counts = {
-            "train": len(self.train_data),
-            "val": len(self.val_data),
-            "test": len(self.test_data),
+            "train": len(self.train_data_text),
+            "val": len(self.val_data_text),
+            "test": len(self.test_data_text),
         }

     @staticmethod
@@ -765,8 +765,8 @@ def collapse_contiguous_indices(ls):

 def remap_ptb_passage_and_answer_spans(ptb_tokens, answer_span, moses, tokenizer_name):
     # Start with PTB tokenized tokens
-    # The answer_span is also in ptb_token space. We first want to detokenize, and convert everything to
-    # space-tokenization space.
+    # The answer_span is also in ptb_token space. We first want to detokenize, and convert
+    # everything to space-tokenization space.

     # Detokenize the passage. Everything we do will be based on the detokenized input,
     # INCLUDING evaluation.
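
The comments above describe remapping an answer span from PTB-token space into the space-tokenization of the detokenized passage. A rough sketch of the naive form of that remapping (hypothetical helper; it ignores edge effects such as punctuation merging across the span boundary, which the real function must handle):

def naive_remap_span(ptb_tokens, answer_span, detokenize):
    # answer_span holds inclusive token indices in PTB-token space.
    start, end = answer_span
    # Count whitespace tokens in the detokenized prefix and answer to
    # locate the span in space-tokenization space.
    n_prefix = len(detokenize(ptb_tokens[:start]).split())
    n_answer = len(detokenize(ptb_tokens[start : end + 1]).split())
    return n_prefix, n_prefix + n_answer - 1
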
2 changes: 1 addition & 1 deletion jiant/tasks/tasks.py
@@ -1798,7 +1798,7 @@ def __init__(self, path, max_seq_len, name, n_classes, **kw):

         self.train_data_text = None
         self.val_data_text = None
-        self.test_data = None
+        self.test_data_text = None
         self.acc_scorer = BooleanAccuracy()
         self.gender_parity_scorer = GenderParity()
         self.val_metric = "%s_accuracy" % name
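
The one-line tasks.py change fixes a straightforward naming mismatch: the attribute was initialized as test_data while the rest of the class follows the *_data_text convention. A minimal sketch of the failure mode (hypothetical class, not jiant code):

class BrokenTask:
    def __init__(self):
        self.train_data_text = None
        self.val_data_text = None
        self.test_data = None  # inconsistent name (pre-fix)

task = BrokenTask()
try:
    task.test_data_text  # code following the convention looks here
except AttributeError as err:
    print(err)  # 'BrokenTask' object has no attribute 'test_data_text'
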
