Skip to content

Commit

Permalink
piqa (#1216)
Browse files Browse the repository at this point in the history
  • Loading branch information
wh629 authored Oct 26, 2020
1 parent 961bd57 commit 442a2b0
Show file tree
Hide file tree
Showing 6 changed files with 126 additions and 0 deletions.
1 change: 1 addition & 0 deletions guides/tasks/supported_tasks.md
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
| MRPC | mrpc ||| mrpc | GLUE |
| Natural Questions | mrqa_natural_questions ||| mrqa_natural_questions | [MRQA](https://mrqa.github.io/) version of task |
| NewsQA | newsqa ||| newsqa | |
| PIQA | piqa ||| piqa | [PIQA](https://yonatanbisk.com/piqa/) |
| QAMR | qamr ||| qamr | |
| QA-SRL | qasrl ||| qasrl | |
| Quoref | quoref ||| quoref | |
Expand Down
1 change: 1 addition & 0 deletions jiant/scripts/download_data/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
"qasrl",
"newsqa",
"mrqa_natural_questions",
"piqa",
}
DIRECT_DOWNLOAD_TASKS = set(
list(SQUAD_TASKS) + list(DIRECT_SUPERGLUE_TASKS_TO_DATA_URLS) + list(OTHER_DOWNLOAD_TASKS)
Expand Down
43 changes: 43 additions & 0 deletions jiant/scripts/download_data/dl_datasets/files_tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,10 @@ def download_task_data_and_write_config(task_name: str, task_data_path: str, tas
download_mrqa_natural_questions_data_and_write_config(
task_name=task_name, task_data_path=task_data_path, task_config_path=task_config_path
)
elif task_name == "piqa":
download_piqa_data_and_write_config(
task_name=task_name, task_data_path=task_data_path, task_config_path=task_config_path
)
else:
raise KeyError(task_name)

Expand Down Expand Up @@ -590,3 +594,42 @@ def download_mrqa_natural_questions_data_and_write_config(
},
path=task_config_path,
)


def download_piqa_data_and_write_config(task_name: str, task_data_path: str, task_config_path: str):
    """Download the PIQA data files and write the matching task config.

    Args:
        task_name: Name stored under both the "task" and "name" config keys.
        task_data_path: Directory the data files are saved into; created if
            it does not exist yet.
        task_config_path: Path of the JSON config file to write.
    """
    base_url = "https://yonatanbisk.com/piqa/data/"
    # Config "paths" key -> remote file name (also used as the local name).
    # Insertion order fixes both the download order and the config key order.
    file_name_by_key = {
        "train": "train.jsonl",
        "train_labels": "train-labels.lst",
        "val": "valid.jsonl",
        "val_labels": "valid-labels.lst",
        "test": "tests.jsonl",
    }

    os.makedirs(task_data_path, exist_ok=True)

    local_paths = {}
    for key, file_name in file_name_by_key.items():
        local_path = os.path.join(task_data_path, file_name)
        download_utils.download_file(base_url + file_name, local_path)
        local_paths[key] = local_path

    py_io.write_json(
        data={"task": task_name, "paths": local_paths, "name": task_name},
        path=task_config_path,
    )
1 change: 1 addition & 0 deletions jiant/tasks/evaluate/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -931,6 +931,7 @@ def get_evaluation_scheme_for_task(task) -> BaseEvaluationScheme:
tasks.XnliTask,
tasks.MCScriptTask,
tasks.ArctTask,
tasks.PiqaTask,
),
):
return SimpleAccuracyEvaluationScheme()
Expand Down
78 changes: 78 additions & 0 deletions jiant/tasks/lib/piqa.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
from dataclasses import dataclass

from jiant.tasks.lib.templates.shared import labels_to_bimap
from jiant.tasks.lib.templates import multiple_choice as mc_template
from jiant.utils.python.io import read_json_lines, read_file_lines


@dataclass
class Example(mc_template.Example):
    """A single PIQA example built on the shared multiple-choice template.

    All fields come from ``mc_template.Example``; this subclass only links
    the example back to its task class.
    """

    @property
    def task(self):
        """Return the task class this example belongs to."""
        # PiqaTask is defined later in this module; the name is resolved when
        # the property is accessed, not when the class body is executed.
        return PiqaTask


@dataclass
class TokenizedExample(mc_template.TokenizedExample):
    """PIQA-specific name for the multiple-choice template's tokenized example.

    Adds no fields or behavior to ``mc_template.TokenizedExample``.
    """


@dataclass
class DataRow(mc_template.DataRow):
    """PIQA-specific name for the multiple-choice template's data row.

    Adds no fields or behavior to ``mc_template.DataRow``.
    """


@dataclass
class Batch(mc_template.Batch):
    """PIQA-specific name for the multiple-choice template's batch.

    Adds no fields or behavior to ``mc_template.Batch``.
    """


class PiqaTask(mc_template.AbstractMultipleChoiceTask):
    """PIQA as a two-way multiple-choice task.

    Each data record supplies a ``goal`` (used as the prompt) and two
    candidate solutions, ``sol1`` and ``sol2`` (used as the choices). Train
    and validation labels live in separate files with one label per line,
    located through ``self.path_dict``.
    """

    Example = Example
    # Point at this module's TokenizedExample dataclass rather than Example.
    TokenizedExample = TokenizedExample
    DataRow = DataRow
    Batch = Batch

    CHOICE_KEYS = [0, 1]
    CHOICE_TO_ID, ID_TO_CHOICE = labels_to_bimap(CHOICE_KEYS)
    NUM_CHOICES = len(CHOICE_KEYS)

    def get_train_examples(self):
        """Return the training examples paired with their labels."""
        return self._read_labeled_examples(
            data_path=self.train_path,
            labels_path=self.path_dict["train_labels"],
            set_type="train",
        )

    def get_val_examples(self):
        """Return the validation examples paired with their labels."""
        return self._read_labeled_examples(
            data_path=self.val_path,
            labels_path=self.path_dict["val_labels"],
            set_type="val",
        )

    def get_test_examples(self):
        """Return the test examples, each carrying a placeholder label."""
        # No label file is read for the test split, so read the data once and
        # pair every record with None; _create_examples ignores that slot for
        # set_type == "test".
        return self._create_examples(
            lines=((record, None) for record in read_json_lines(self.test_path)),
            set_type="test",
        )

    @classmethod
    def _read_labeled_examples(cls, data_path, labels_path, set_type):
        """Zip a JSON-lines data file with its label file and build examples."""
        # NOTE(review): zip stops at the shorter input, so a data/label length
        # mismatch is silently truncated -- confirm the files always align.
        return cls._create_examples(
            lines=zip(
                read_json_lines(data_path),
                read_file_lines(labels_path, strip_lines=True),
            ),
            set_type=set_type,
        )

    @classmethod
    def _create_examples(cls, lines, set_type):
        """Build ``Example`` objects from ``(record, label_string)`` pairs.

        Args:
            lines: Iterable of ``(record, label_string)`` pairs. ``record``
                must provide the keys ``goal``, ``sol1`` and ``sol2``.
                ``label_string`` is converted with ``int`` unless
                ``set_type`` is ``"test"``, in which case it is ignored.
            set_type: Split name; used as the guid prefix.

        Returns:
            A list of ``Example`` instances in input order.
        """
        is_test = set_type == "test"
        # Test examples get the last choice key as a dummy label.
        placeholder_label = cls.CHOICE_KEYS[-1]

        examples = []
        for index, (record, label_string) in enumerate(lines):
            label = placeholder_label if is_test else int(label_string)
            examples.append(
                Example(
                    guid="%s-%s" % (set_type, index),
                    prompt=record["goal"],
                    choice_list=[record["sol1"], record["sol2"]],
                    label=label,
                )
            )
        return examples
2 changes: 2 additions & 0 deletions jiant/tasks/retrieval.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@
from jiant.tasks.lib.xquad import XquadTask
from jiant.tasks.lib.mcscript import MCScriptTask
from jiant.tasks.lib.arct import ArctTask
from jiant.tasks.lib.piqa import PiqaTask

from jiant.tasks.core import Task
from jiant.utils.python.io import read_json
Expand Down Expand Up @@ -139,6 +140,7 @@
"xquad": XquadTask,
"mcscript": MCScriptTask,
"arct": ArctTask,
"piqa": PiqaTask,
}


Expand Down

0 comments on commit 442a2b0

Please sign in to comment.