diff --git a/src/pie_modules/documents.py b/src/pie_modules/documents.py
index ae74f16c1..a7a77cb7e 100644
--- a/src/pie_modules/documents.py
+++ b/src/pie_modules/documents.py
@@ -1,5 +1,6 @@
 import dataclasses
 
+from pytorch_ie.annotations import BinaryRelation, LabeledSpan
 from pytorch_ie.core import AnnotationList, annotation_field
 from pytorch_ie.documents import TextBasedDocument, TokenBasedDocument
 
@@ -7,7 +8,7 @@
 
 
 @dataclasses.dataclass
-class ExtractiveQADocument(TextBasedDocument):
+class TextDocumentWithQuestionsAndExtractiveAnswers(TextBasedDocument):
     """A text based PIE document with annotations for extractive question answering."""
 
     questions: AnnotationList[Question] = annotation_field()
@@ -20,7 +21,7 @@ class ExtractiveQADocument(TextBasedDocument):
 
 
 @dataclasses.dataclass
-class TokenizedExtractiveQADocument(TokenBasedDocument):
+class TokenDocumentWithQuestionsAndExtractiveAnswers(TokenBasedDocument):
     """A tokenized PIE document with annotations for extractive question answering."""
 
     questions: AnnotationList[Question] = annotation_field()
@@ -30,3 +31,37 @@ class TokenizedExtractiveQADocument(TokenBasedDocument):
     answers: AnnotationList[ExtractiveAnswer] = annotation_field(
         named_targets={"base": "tokens", "questions": "questions"}
     )
+
+
+# backwards compatibility
+ExtractiveQADocument = TextDocumentWithQuestionsAndExtractiveAnswers
+TokenizedExtractiveQADocument = TokenDocumentWithQuestionsAndExtractiveAnswers
+
+
+@dataclasses.dataclass
+class TokenDocumentWithLabeledSpans(TokenBasedDocument):
+    labeled_spans: AnnotationList[LabeledSpan] = annotation_field(target="tokens")
+
+
+@dataclasses.dataclass
+class TokenDocumentWithLabeledPartitions(TokenBasedDocument):
+    labeled_partitions: AnnotationList[LabeledSpan] = annotation_field(target="tokens")
+
+
+@dataclasses.dataclass
+class TokenDocumentWithLabeledSpansAndLabeledPartitions(
+    TokenDocumentWithLabeledSpans, TokenDocumentWithLabeledPartitions
+):
+    pass
+
+
+@dataclasses.dataclass
+class TokenDocumentWithLabeledSpansAndBinaryRelations(TokenDocumentWithLabeledSpans):
+    binary_relations: AnnotationList[BinaryRelation] = annotation_field(target="labeled_spans")
+
+
+@dataclasses.dataclass
+class TokenDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions(
+    TokenDocumentWithLabeledSpansAndBinaryRelations, TokenDocumentWithLabeledPartitions
+):
+    pass
diff --git a/src/pie_modules/metrics/squad_f1.py b/src/pie_modules/metrics/squad_f1.py
index 3a8ce00c6..7ba474027 100644
--- a/src/pie_modules/metrics/squad_f1.py
+++ b/src/pie_modules/metrics/squad_f1.py
@@ -8,7 +8,7 @@
 import pandas as pd
 from pytorch_ie.core import DocumentMetric
 
-from pie_modules.documents import ExtractiveQADocument
+from pie_modules.documents import TextDocumentWithQuestionsAndExtractiveAnswers
 
 logger = logging.getLogger(__name__)
 
@@ -42,7 +42,7 @@ def reset(self):
         self.has_answer_qids = []
         self.no_answer_qids = []
 
-    def _update(self, document: ExtractiveQADocument):
+    def _update(self, document: TextDocumentWithQuestionsAndExtractiveAnswers):
         gold_answers_for_questions = defaultdict(list)
         predicted_answers_for_questions = defaultdict(list)
         for ann in document.answers:
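Note: the extractive QA document classes above are renamed to the layer-descriptive naming scheme already used by pytorch_ie.documents, and the old names are kept as module-level aliases, so existing imports keep working. A minimal sketch of what this means for downstream code (the example text, question, and answer offsets are invented for illustration):

    from pie_modules.annotations import ExtractiveAnswer, Question
    from pie_modules.documents import (
        ExtractiveQADocument,  # old name, kept as a backwards-compatibility alias
        TextDocumentWithQuestionsAndExtractiveAnswers,
    )

    # the alias resolves to the very same class object
    assert ExtractiveQADocument is TextDocumentWithQuestionsAndExtractiveAnswers

    doc = TextDocumentWithQuestionsAndExtractiveAnswers(text="This is a test document.")
    question = Question(text="What is this?")
    doc.questions.append(question)
    # answers are character spans into the document text, tied to a question
    doc.answers.append(ExtractiveAnswer(question=question, start=10, end=23))
    assert str(doc.answers[0]) == "test document"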
diff --git a/src/pie_modules/taskmodules/extractive_question_answering.py b/src/pie_modules/taskmodules/extractive_question_answering.py
index f58b0103c..d68b0d112 100644
--- a/src/pie_modules/taskmodules/extractive_question_answering.py
+++ b/src/pie_modules/taskmodules/extractive_question_answering.py
@@ -13,7 +13,10 @@
 
 from pie_modules.annotations import ExtractiveAnswer, Question
 from pie_modules.document.processing import tokenize_document
-from pie_modules.documents import ExtractiveQADocument, TokenizedExtractiveQADocument
+from pie_modules.documents import (
+    TextDocumentWithQuestionsAndExtractiveAnswers,
+    TokenDocumentWithQuestionsAndExtractiveAnswers,
+)
 
 logger = logging.getLogger(__name__)
 
@@ -29,7 +32,7 @@ class TargetEncoding:
 
 
 TaskEncodingType: TypeAlias = TaskEncoding[
-    ExtractiveQADocument,
+    TextDocumentWithQuestionsAndExtractiveAnswers,
     InputEncoding,
     TargetEncoding,
 ]
@@ -67,7 +70,7 @@ class ExtractiveQuestionAnsweringTaskModule(TaskModule):
         tokenize_kwargs: Additional keyword arguments for the tokenizer. Defaults to None.
     """
 
-    DOCUMENT_TYPE = ExtractiveQADocument
+    DOCUMENT_TYPE = TextDocumentWithQuestionsAndExtractiveAnswers
 
     def __init__(
         self,
@@ -124,7 +127,7 @@ def encode_input(
             truncation="only_second",
             max_length=self.max_length,
             return_overflowing_tokens=True,
-            result_document_type=TokenizedExtractiveQADocument,
+            result_document_type=TokenDocumentWithQuestionsAndExtractiveAnswers,
             strict_span_conversion=False,
             verbose=False,
             **self.tokenize_kwargs,
diff --git a/src/pie_modules/taskmodules/token_classification.py b/src/pie_modules/taskmodules/token_classification.py
index 4e43483fc..edfef845d 100644
--- a/src/pie_modules/taskmodules/token_classification.py
+++ b/src/pie_modules/taskmodules/token_classification.py
@@ -8,7 +8,6 @@
 """
 
 import copy
-import dataclasses
 import logging
 from typing import (
     Any,
@@ -25,14 +24,13 @@
 
 import torch
 import torch.nn.functional as F
-from pytorch_ie import AnnotationLayer, annotation_field
+from pytorch_ie import AnnotationLayer
 from pytorch_ie.annotations import LabeledSpan
 from pytorch_ie.core import TaskEncoding, TaskModule
 from pytorch_ie.documents import (
     TextDocument,
     TextDocumentWithLabeledSpans,
     TextDocumentWithLabeledSpansAndLabeledPartitions,
-    TokenBasedDocument,
 )
 from pytorch_ie.models.transformer_token_classification import (
     ModelOutputType,
@@ -47,6 +45,10 @@
     token_based_document_to_text_based,
     tokenize_document,
 )
+from pie_modules.documents import (
+    TokenDocumentWithLabeledSpans,
+    TokenDocumentWithLabeledSpansAndLabeledPartitions,
+)
 
 DocumentType: TypeAlias = TextDocument
 
@@ -71,16 +73,6 @@
 logger = logging.getLogger(__name__)
 
 
-@dataclasses.dataclass
-class TokenDocumentWithLabeledSpans(TokenBasedDocument):
-    labeled_spans: AnnotationLayer[LabeledSpan] = annotation_field(target="tokens")
-
-
-@dataclasses.dataclass
-class TokenDocumentWithLabeledSpansAndLabeledPartitions(TokenDocumentWithLabeledSpans):
-    labeled_partitions: AnnotationLayer[LabeledSpan] = annotation_field(target="tokens")
-
-
 @TaskModule.register()
 class TokenClassificationTaskModule(TaskModuleType):
     """Taskmodule for span prediction (e.g. NER) as token classification.
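Note: TokenDocumentWithLabeledSpans and TokenDocumentWithLabeledSpansAndLabeledPartitions were previously defined inside the token classification taskmodule; they now live in pie_modules.documents next to the other token-based document types, so other modules can reuse them. A minimal usage sketch (token values, labels, and offsets are invented for illustration):

    from pytorch_ie.annotations import LabeledSpan

    from pie_modules.documents import TokenDocumentWithLabeledSpansAndLabeledPartitions

    doc = TokenDocumentWithLabeledSpansAndLabeledPartitions(
        tokens=("Jane", "lives", "in", "Berlin", "."),
    )
    # both annotation layers target "tokens", so spans are token offsets
    doc.labeled_spans.append(LabeledSpan(start=0, end=1, label="PER"))
    doc.labeled_spans.append(LabeledSpan(start=3, end=4, label="LOC"))
    # a partition covering the whole (single) sentence
    doc.labeled_partitions.append(LabeledSpan(start=0, end=5, label="sentence"))
    assert str(doc.labeled_spans[0]) == "('Jane',)"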
diff --git a/tests/metrics/test_squad_f1.py b/tests/metrics/test_squad_f1.py
index 7e06544a3..337081e23 100644
--- a/tests/metrics/test_squad_f1.py
+++ b/tests/metrics/test_squad_f1.py
@@ -1,7 +1,7 @@
 import logging
 
 from pie_modules.annotations import ExtractiveAnswer, Question
-from pie_modules.documents import ExtractiveQADocument
+from pie_modules.documents import TextDocumentWithQuestionsAndExtractiveAnswers
 from pie_modules.metrics import SQuADF1
 
 
@@ -10,7 +10,7 @@ def test_squad_f1_exact_match(caplog):
 
     # create a test document
     # sample edit
-    doc = ExtractiveQADocument(text="This is a test document.")
+    doc = TextDocumentWithQuestionsAndExtractiveAnswers(text="This is a test document.")
     # add a question
     q1 = Question(text="What is this?")
     doc.questions.append(q1)
@@ -62,7 +62,7 @@ def test_squad_f1_exact_match_added_article():
     metric = SQuADF1()
 
     # create a test document
-    doc = ExtractiveQADocument(
+    doc = TextDocumentWithQuestionsAndExtractiveAnswers(
         text="This is a test document.", id="eqa_doc_with_exact_match_added_article"
     )
     # add a question
@@ -100,7 +100,7 @@ def test_squad_f1_partly_span_mismatch():
     metric = SQuADF1()
 
     # create a test document
-    doc = ExtractiveQADocument(
+    doc = TextDocumentWithQuestionsAndExtractiveAnswers(
         text="This is a test document.", id="eqa_doc_with_partly_span_mismatch"
     )
     # add a question
@@ -138,7 +138,7 @@ def test_squad_f1_full_span_mismatch():
     metric = SQuADF1()
 
     # create a test document
-    doc = ExtractiveQADocument(
+    doc = TextDocumentWithQuestionsAndExtractiveAnswers(
         text="This is a test document.", id="eqa_doc_with_full_span_mismatch"
     )
     # add a question
@@ -176,7 +176,7 @@ def test_squad_f1_no_predicted_answers():
     metric = SQuADF1()
 
     # create a test document
-    doc = ExtractiveQADocument(
+    doc = TextDocumentWithQuestionsAndExtractiveAnswers(
         text="This is a test document.", id="eqa_doc_without_predicted_answers"
     )
     # add a question
@@ -209,7 +209,9 @@ def test_squad_f1_no_gold_answers():
     metric = SQuADF1()
 
     # create a test document
-    doc = ExtractiveQADocument(text="This is a test document.", id="eqa_doc_without_gold_answers")
+    doc = TextDocumentWithQuestionsAndExtractiveAnswers(
+        text="This is a test document.", id="eqa_doc_without_gold_answers"
+    )
     # add a question
     q1 = Question(text="What is this?")
     doc.questions.append(q1)
@@ -240,7 +242,7 @@ def test_squad_f1_empty_document():
     metric = SQuADF1()
 
     # create a test document
-    doc = ExtractiveQADocument(text="", id="eqa_doc_with_empty_text")
+    doc = TextDocumentWithQuestionsAndExtractiveAnswers(text="", id="eqa_doc_with_empty_text")
     # add a question
     q1 = Question(text="What is this?")
     doc.questions.append(q1)
diff --git a/tests/models/test_extractive_question_answering.py b/tests/models/test_extractive_question_answering.py
index c9dde8642..9d8b6c1cb 100644
--- a/tests/models/test_extractive_question_answering.py
+++ b/tests/models/test_extractive_question_answering.py
@@ -6,7 +6,7 @@
 from pytorch_lightning import Trainer
 
 from pie_modules.annotations import ExtractiveAnswer, Question
-from pie_modules.documents import ExtractiveQADocument
+from pie_modules.documents import TextDocumentWithQuestionsAndExtractiveAnswers
 from pie_modules.models.simple_extractive_question_answering import (
     SimpleExtractiveQuestionAnsweringModel,
 )
@@ -20,15 +20,19 @@
 
 @pytest.fixture
 def documents():
-    document0 = ExtractiveQADocument(text="This is a test document", id="doc0")
+    document0 = TextDocumentWithQuestionsAndExtractiveAnswers(
+        text="This is a test document", id="doc0"
+    )
     document0.questions.append(Question(text="What is the first word?"))
     document0.answers.append(ExtractiveAnswer(question=document0.questions[0], start=0, end=3))
 
-    document1 = ExtractiveQADocument(text="Oranges are orange in color.", id="doc1")
+    document1 = TextDocumentWithQuestionsAndExtractiveAnswers(
+        text="Oranges are orange in color.", id="doc1"
+    )
     document1.questions.append(Question(text="What color are oranges?"))
     document1.answers.append(ExtractiveAnswer(question=document1.questions[0], start=23, end=27))
 
-    document2 = ExtractiveQADocument(
+    document2 = TextDocumentWithQuestionsAndExtractiveAnswers(
         text="This is a test document that has two questions attached to it.", id="doc2"
     )
     document2.questions.append(Question(text="What type of document is this?"))
diff --git a/tests/taskmodules/test_extractive_question_answering.py b/tests/taskmodules/test_extractive_question_answering.py
index 1b145dbc9..77b78cea3 100644
--- a/tests/taskmodules/test_extractive_question_answering.py
+++ b/tests/taskmodules/test_extractive_question_answering.py
@@ -4,7 +4,7 @@
 from pytorch_ie.core import AnnotationList
 
 from pie_modules.annotations import ExtractiveAnswer, Question
-from pie_modules.documents import ExtractiveQADocument
+from pie_modules.documents import TextDocumentWithQuestionsAndExtractiveAnswers
 from pie_modules.taskmodules.extractive_question_answering import (
     ExtractiveQuestionAnsweringTaskModule,
 )
@@ -12,7 +12,9 @@
 
 @pytest.fixture()
 def document():
-    document = ExtractiveQADocument(text="This is a test document", id="doc0")
+    document = TextDocumentWithQuestionsAndExtractiveAnswers(
+        text="This is a test document", id="doc0"
+    )
     document.questions.append(Question(text="What is the first word?"))
     document.answers.append(ExtractiveAnswer(question=document.questions[0], start=0, end=4))
     assert str(document.answers[0]) == "This"
@@ -21,7 +23,9 @@ def document():
 
 @pytest.fixture()
 def document1():
-    document1 = ExtractiveQADocument(text="This is the second document", id="doc1")
+    document1 = TextDocumentWithQuestionsAndExtractiveAnswers(
+        text="This is the second document", id="doc1"
+    )
     document1.questions.append(Question(text="Which document is this?"))
     document1.answers.append(ExtractiveAnswer(question=document1.questions[0], start=13, end=18))
     assert str(document1.answers[0]) == "second"
@@ -30,14 +34,16 @@ def document1():
 
 @pytest.fixture()
 def document_with_no_answer():
-    document = ExtractiveQADocument(text="This is a test document", id="document_with_no_answer")
+    document = TextDocumentWithQuestionsAndExtractiveAnswers(
+        text="This is a test document", id="document_with_no_answer"
+    )
     document.questions.append(Question(text="What is the first word?"))
     return document
 
 
 @pytest.fixture()
 def document_with_multiple_answers():
-    document = ExtractiveQADocument(
+    document = TextDocumentWithQuestionsAndExtractiveAnswers(
         text="This is a test document", id="document_with_multiple_answers"
     )
     document.questions.append(Question(text="What is the first word?"))
diff --git a/tests/test_documents.py b/tests/test_documents.py
new file mode 100644
index 000000000..feee80498
--- /dev/null
+++ b/tests/test_documents.py
@@ -0,0 +1,227 @@
+from pytorch_ie.annotations import BinaryRelation, LabeledSpan
+
+from pie_modules.annotations import ExtractiveAnswer, Question
+from pie_modules.documents import (
+    TextDocumentWithQuestionsAndExtractiveAnswers,
+    TokenDocumentWithLabeledPartitions,
+    TokenDocumentWithLabeledSpans,
+    TokenDocumentWithLabeledSpansAndBinaryRelations,
+    TokenDocumentWithLabeledSpansAndLabeledPartitions,
+    TokenDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions,
+    TokenDocumentWithQuestionsAndExtractiveAnswers,
+)
+
+
+def test_token_document_with_labeled_spans():
+    doc = TokenDocumentWithLabeledSpans(
+        tokens=("This", "is", "a", "sentence", "."), id="token_document_with_labeled_spans"
+    )
+    e1 = LabeledSpan(start=0, end=1, label="entity")
+    doc.labeled_spans.append(e1)
+    assert str(e1) == "('This',)"
+    e2 = LabeledSpan(start=2, end=4, label="entity")
+    doc.labeled_spans.append(e2)
+    assert str(e2) == "('a', 'sentence')"
+
+    # test (de-)serialization
+    doc_copy = doc.copy()
+    assert doc == doc_copy
+
+
+def test_token_document_with_labeled_partitions():
+    doc = TokenDocumentWithLabeledPartitions(
+        tokens=(
+            "This",
+            "is",
+            "a",
+            "sentence",
+            ".",
+            "And",
+            "this",
+            "is",
+            "another",
+            "sentence",
+            ".",
+        ),
+        id="token_document_with_labeled_partitions",
+    )
+    sent1 = LabeledSpan(start=0, end=5, label="sentence")
+    doc.labeled_partitions.append(sent1)
+    assert str(sent1) == "('This', 'is', 'a', 'sentence', '.')"
+    sent2 = LabeledSpan(start=5, end=11, label="sentence")
+    doc.labeled_partitions.append(sent2)
+    assert str(sent2) == "('And', 'this', 'is', 'another', 'sentence', '.')"
+
+    # test (de-)serialization
+    doc_copy = doc.copy()
+    assert doc == doc_copy
+
+
+def test_token_document_with_labeled_spans_and_labeled_partitions():
+    doc = TokenDocumentWithLabeledSpansAndLabeledPartitions(
+        tokens=(
+            "This",
+            "is",
+            "a",
+            "sentence",
+            ".",
+            "And",
+            "this",
+            "is",
+            "another",
+            "sentence",
+            ".",
+        ),
+        id="token_document_with_labeled_spans_and_labeled_partitions",
+    )
+    e1 = LabeledSpan(start=0, end=1, label="entity")
+    doc.labeled_spans.append(e1)
+    assert str(e1) == "('This',)"
+    e2 = LabeledSpan(start=2, end=4, label="entity")
+    doc.labeled_spans.append(e2)
+    assert str(e2) == "('a', 'sentence')"
+    sent1 = LabeledSpan(start=0, end=5, label="sentence")
+    doc.labeled_partitions.append(sent1)
+    assert str(sent1) == "('This', 'is', 'a', 'sentence', '.')"
+    sent2 = LabeledSpan(start=5, end=11, label="sentence")
+    doc.labeled_partitions.append(sent2)
+    assert str(sent2) == "('And', 'this', 'is', 'another', 'sentence', '.')"
+
+    # test (de-)serialization
+    doc_copy = doc.copy()
+    assert doc == doc_copy
+
+
+def test_token_document_with_labeled_spans_and_binary_relations():
+    doc = TokenDocumentWithLabeledSpansAndBinaryRelations(
+        tokens=(
+            "This",
+            "is",
+            "a",
+            "sentence",
+            ".",
+            "And",
+            "this",
+            "is",
+            "another",
+            "sentence",
+            ".",
+        ),
+        id="token_document_with_labeled_spans_and_binary_relations",
+    )
+    e1 = LabeledSpan(start=0, end=1, label="entity")
+    doc.labeled_spans.append(e1)
+    assert str(e1) == "('This',)"
+    e2 = LabeledSpan(start=2, end=4, label="entity")
+    doc.labeled_spans.append(e2)
+    assert str(e2) == "('a', 'sentence')"
+    r1 = BinaryRelation(head=e1, tail=e2, label="relation")
+    doc.binary_relations.append(r1)
+    assert str(r1.head) == "('This',)"
+    assert str(r1.tail) == "('a', 'sentence')"
+    assert r1.label == "relation"
+
+    # test (de-)serialization
+    doc_copy = doc.copy()
+    assert doc == doc_copy
+
+
+def test_token_document_with_labeled_spans_binary_relations_and_labeled_partitions():
+    doc = TokenDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions(
+        tokens=(
+            "This",
+            "is",
+            "a",
+            "sentence",
+            ".",
+            "And",
+            "this",
+            "is",
+            "another",
+            "sentence",
+            ".",
+        ),
+        id="token_document_with_labeled_spans_binary_relations_and_labeled_partitions",
+    )
+    e1 = LabeledSpan(start=0, end=1, label="entity")
+    doc.labeled_spans.append(e1)
+    assert str(e1) == "('This',)"
+    e2 = LabeledSpan(start=2, end=4, label="entity")
+    doc.labeled_spans.append(e2)
+    assert str(e2) == "('a', 'sentence')"
+    r1 = BinaryRelation(head=e1, tail=e2, label="relation")
+    doc.binary_relations.append(r1)
+    assert str(r1.head) == "('This',)"
+    assert str(r1.tail) == "('a', 'sentence')"
+    assert r1.label == "relation"
+    sent1 = LabeledSpan(start=0, end=5, label="sentence")
+    doc.labeled_partitions.append(sent1)
+    assert str(sent1) == "('This', 'is', 'a', 'sentence', '.')"
+    sent2 = LabeledSpan(start=5, end=11, label="sentence")
+    doc.labeled_partitions.append(sent2)
+    assert str(sent2) == "('And', 'this', 'is', 'another', 'sentence', '.')"
+
+    # test (de-)serialization
+    doc_copy = doc.copy()
+    assert doc == doc_copy
+
+
+def test_extractive_qa_document():
+    doc = TextDocumentWithQuestionsAndExtractiveAnswers(
+        text="This is a sentence. And that is another sentence.", id="extractive_qa_document"
+    )
+    q1 = Question(text="What is this?")
+    doc.questions.append(q1)
+    q2 = Question(text="What is that?")
+    doc.questions.append(q2)
+
+    a1 = ExtractiveAnswer(start=8, end=18, question=q1)
+    doc.answers.append(a1)
+    assert str(a1.question) == "What is this?"
+    assert str(a1) == "a sentence"
+
+    a2 = ExtractiveAnswer(start=32, end=48, question=q2)
+    doc.answers.append(a2)
+    assert str(a2.question) == "What is that?"
+    assert str(a2) == "another sentence"
+
+    # test (de-)serialization
+    doc_copy = doc.copy()
+    assert doc == doc_copy
+
+
+def test_tokenized_extractive_qa_document():
+    doc = TokenDocumentWithQuestionsAndExtractiveAnswers(
+        tokens=(
+            "This",
+            "is",
+            "a",
+            "sentence",
+            ".",
+            "And",
+            "that",
+            "is",
+            "another",
+            "sentence",
+            ".",
+        ),
+        id="tokenized_extractive_qa_document",
+    )
+    q1 = Question(text="What is this?")
+    doc.questions.append(q1)
+    q2 = Question(text="What is that?")
+    doc.questions.append(q2)
+
+    a1 = ExtractiveAnswer(start=2, end=4, question=q1)
+    doc.answers.append(a1)
+    assert str(a1.question) == "What is this?"
+    assert str(a1) == "('a', 'sentence')"
+
+    a2 = ExtractiveAnswer(start=8, end=10, question=q2)
+    doc.answers.append(a2)
+    assert str(a2.question) == "What is that?"
+    assert str(a2) == "('another', 'sentence')"
+
+    # test (de-)serialization
+    doc_copy = doc.copy()
+    assert doc == doc_copy
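Note: the relocated token document types are what the taskmodules pass to tokenize_document as result_document_type, which converts a text-based document into one or more token-based documents and maps the annotations onto token offsets. Below is a hedged sketch of that conversion, assuming a HuggingFace tokenizer and that tokenize_document is called with the document, the tokenizer, and the keyword arguments visible in the taskmodule diff above (result_document_type, strict_span_conversion, verbose); the exact call pattern beyond those arguments, and the example text and spans, are assumptions for illustration:

    from pytorch_ie.annotations import LabeledSpan
    from pytorch_ie.documents import TextDocumentWithLabeledSpans
    from transformers import AutoTokenizer

    from pie_modules.document.processing import tokenize_document
    from pie_modules.documents import TokenDocumentWithLabeledSpans

    # a text-based document with two character-offset entity spans
    text_doc = TextDocumentWithLabeledSpans(text="Jane lives in Berlin.")
    text_doc.labeled_spans.append(LabeledSpan(start=0, end=4, label="PER"))
    text_doc.labeled_spans.append(LabeledSpan(start=14, end=20, label="LOC"))

    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
    # returns a list because long inputs may be split into several tokenized documents
    tokenized_docs = tokenize_document(
        text_doc,
        tokenizer=tokenizer,
        result_document_type=TokenDocumentWithLabeledSpans,
        strict_span_conversion=False,
        verbose=False,
    )
    # the character spans have been mapped onto token offsets in the result
    print(tokenized_docs[0].labeled_spans)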