From e8e00d486375dee01d1a2368f8fe836680bdabd1 Mon Sep 17 00:00:00 2001
From: Arne Binder
Date: Thu, 30 Nov 2023 16:37:52 +0100
Subject: [PATCH] add test_squad_f1.py (wip)

---
 tests/metrics/test_squad_f1.py | 119 +++++++++++++++++++++++++++++++++
 1 file changed, 119 insertions(+)
 create mode 100644 tests/metrics/test_squad_f1.py

diff --git a/tests/metrics/test_squad_f1.py b/tests/metrics/test_squad_f1.py
new file mode 100644
index 000000000..197358175
--- /dev/null
+++ b/tests/metrics/test_squad_f1.py
@@ -0,0 +1,119 @@
+from pie_modules.annotations import ExtractiveAnswer, Question
+from pie_modules.documents import ExtractiveQADocument
+from pie_modules.metrics import SQuADF1
+
+
+def test_squad_f1_exact_match():
+    metric = SQuADF1()
+
+    # create a test document whose gold and predicted answers are the very same span
+    # (start=8, end=23 selects "a test document"; see the str() asserts below)
+    doc = ExtractiveQADocument(text="This is a test document.")
+    # add the question that the answers below refer to
+    q1 = Question(text="What is this?")
+    doc.questions.append(q1)
+    # add a gold answer for q1
+    doc.answers.append(ExtractiveAnswer(question=q1, start=8, end=23))
+    assert str(doc.answers[0]) == "a test document"
+    # add a predicted answer for q1 with the same offsets as the gold answer
+    doc.answers.predictions.append(ExtractiveAnswer(question=q1, start=8, end=23, score=0.9))
+    assert str(doc.answers.predictions[0]) == str(doc.answers[0])
+
+    metric._update(doc)
+
+    # assert internal state; the asserted keys have the form "text=<text>,question=<question>"
+    assert metric.exact_scores == {"text=This is a test document.,question=What is this?": 1}
+    assert metric.f1_scores == {"text=This is a test document.,question=What is this?": 1.0}
+    assert metric.has_answer_qids == ["text=This is a test document.,question=What is this?"]
+    assert metric.no_answer_qids == []
+    assert metric.qas_id_to_has_answer == {
+        "text=This is a test document.,question=What is this?": True
+    }
+
+    metric_values = metric._compute()  # expected below: 100 x the per-question scores
+    assert metric_values == {
+        "HasAns_exact": 100.0,
+        "HasAns_f1": 100.0,
+        "HasAns_total": 1,
+        "exact": 100.0,
+        "f1": 100.0,
+        "total": 1,
+    }
+
+
+def test_squad_f1_exact_match_added_article():
+    metric = SQuADF1()
+
+    # create a test document whose prediction lacks the leading article "a" of the gold answer
+    doc = ExtractiveQADocument(text="This is a test document.")
+    # add the question that the answers below refer to
+    q1 = Question(text="What is this?")
+    doc.questions.append(q1)
+    # add a gold answer for q1
+    doc.answers.append(ExtractiveAnswer(question=q1, start=8, end=23))
+    assert str(doc.answers[0]) == "a test document"
+    # add a predicted answer for q1 that starts two characters later than the gold answer
+    doc.answers.predictions.append(ExtractiveAnswer(question=q1, start=10, end=23, score=0.9))
+    assert str(doc.answers.predictions[0]) == "test document"
+    # the spans are not the same!
+    assert str(doc.answers.predictions[0]) != str(doc.answers[0])
+
+    metric._update(doc)
+    # assert internal state: still a full match, presumably since articles are normalized away
+    assert metric.exact_scores == {"text=This is a test document.,question=What is this?": 1}
+    assert metric.f1_scores == {"text=This is a test document.,question=What is this?": 1.0}
+    assert metric.has_answer_qids == ["text=This is a test document.,question=What is this?"]
+    assert metric.no_answer_qids == []
+    assert metric.qas_id_to_has_answer == {
+        "text=This is a test document.,question=What is this?": True
+    }
+
+    metric_values = metric._compute()  # expected below: 100 x the per-question scores
+    assert metric_values == {
+        "HasAns_exact": 100.0,
+        "HasAns_f1": 100.0,
+        "HasAns_total": 1,
+        "exact": 100.0,
+        "f1": 100.0,
+        "total": 1,
+    }
+
+
+def test_squad_f1_span_mismatch():
+    metric = SQuADF1()
+
+    # create a test document whose prediction is only the last word of the gold answer
+    doc = ExtractiveQADocument(text="This is a test document.")
+    # add the question that the answers below refer to
+    q1 = Question(text="What is this?")
+    doc.questions.append(q1)
+    # add a gold answer for q1
+    doc.answers.append(ExtractiveAnswer(question=q1, start=8, end=23))
+    assert str(doc.answers[0]) == "a test document"
+    # add a predicted answer for q1 that covers just "document"
+    doc.answers.predictions.append(ExtractiveAnswer(question=q1, start=15, end=23, score=0.9))
+    assert str(doc.answers.predictions[0]) == "document"
+    # the spans are not the same!
+    assert str(doc.answers.predictions[0]) != str(doc.answers[0])
+
+    metric._update(doc)
+    # assert internal state: no exact match and an F1 of 2/3 (token overlap? TODO confirm)
+    assert metric.exact_scores == {"text=This is a test document.,question=What is this?": 0}
+    assert metric.f1_scores == {
+        "text=This is a test document.,question=What is this?": 0.6666666666666666
+    }
+    assert metric.has_answer_qids == ["text=This is a test document.,question=What is this?"]
+    assert metric.no_answer_qids == []
+    assert metric.qas_id_to_has_answer == {
+        "text=This is a test document.,question=What is this?": True
+    }
+
+    metric_values = metric._compute()  # expected below: 100 x the per-question scores
+    assert metric_values == {
+        "HasAns_exact": 0.0,
+        "HasAns_f1": 66.66666666666666,
+        "HasAns_total": 1,
+        "exact": 0.0,
+        "f1": 66.66666666666666,
+        "total": 1,
+    }