From 4c95331744f2bc75b867226ded50ac3546ca1210 Mon Sep 17 00:00:00 2001 From: Clara Vania Date: Mon, 26 Oct 2020 12:35:05 -0400 Subject: [PATCH] Fix tokenization for arc_easy and arc_challenge (#1214) Co-authored-by: jeswan <57466294+jeswan@users.noreply.github.com> --- jiant/tasks/lib/arc_challenge.py | 4 ++-- jiant/tasks/lib/arc_easy.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/jiant/tasks/lib/arc_challenge.py b/jiant/tasks/lib/arc_challenge.py index 4a17c34cc..1bc91c798 100644 --- a/jiant/tasks/lib/arc_challenge.py +++ b/jiant/tasks/lib/arc_challenge.py @@ -61,7 +61,7 @@ def _create_examples(cls, lines, set_type): label = line["answerKey"] if label in potential_label_map: label = potential_label_map[label] - choice_list = [d["text"] for d in line["question"]["choices"]] + choice_list = [d for d in line["choices"]["text"]] filler_choice_list = ["." for i in range(NUM_CHOICES - len(choice_list))] choice_list = choice_list + filler_choice_list assert len(choice_list) == NUM_CHOICES @@ -69,7 +69,7 @@ def _create_examples(cls, lines, set_type): examples.append( Example( guid="%s-%s" % (set_type, i), - prompt=line["question"]["stem"], + prompt=line["question"], choice_list=choice_list, label=label, ) diff --git a/jiant/tasks/lib/arc_easy.py b/jiant/tasks/lib/arc_easy.py index b04f1443b..9511e124e 100644 --- a/jiant/tasks/lib/arc_easy.py +++ b/jiant/tasks/lib/arc_easy.py @@ -61,7 +61,7 @@ def _create_examples(cls, lines, set_type): label = line["answerKey"] if label in potential_label_map: label = potential_label_map[label] - choice_list = [d["text"] for d in line["question"]["choices"]] + choice_list = [d for d in line["choices"]["text"]] filler_choice_list = ["." for i in range(NUM_CHOICES - len(choice_list))] choice_list = choice_list + filler_choice_list assert len(choice_list) == NUM_CHOICES @@ -69,7 +69,7 @@ def _create_examples(cls, lines, set_type): examples.append( Example( guid="%s-%s" % (set_type, i), - prompt=line["question"]["stem"], + prompt=line["question"], choice_list=choice_list, label=label, )