Skip to content

Commit

Permalink
Merge pull request #21 from ArneBinder/token_document_types
Browse files Browse the repository at this point in the history
add token document types
  • Loading branch information
ArneBinder authored Dec 20, 2023
2 parents 26ee964 + bd492c7 commit e85a98b
Show file tree
Hide file tree
Showing 8 changed files with 307 additions and 38 deletions.
39 changes: 37 additions & 2 deletions src/pie_modules/documents.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
import dataclasses

from pytorch_ie.annotations import BinaryRelation, LabeledSpan
from pytorch_ie.core import AnnotationList, annotation_field
from pytorch_ie.documents import TextBasedDocument, TokenBasedDocument

from pie_modules.annotations import ExtractiveAnswer, Question


@dataclasses.dataclass
class ExtractiveQADocument(TextBasedDocument):
class TextDocumentWithQuestionsAndExtractiveAnswers(TextBasedDocument):
"""A text based PIE document with annotations for extractive question answering."""

questions: AnnotationList[Question] = annotation_field()
Expand All @@ -20,7 +21,7 @@ class ExtractiveQADocument(TextBasedDocument):


@dataclasses.dataclass
class TokenizedExtractiveQADocument(TokenBasedDocument):
class TokenDocumentWithQuestionsAndExtractiveAnswers(TokenBasedDocument):
"""A tokenized PIE document with annotations for extractive question answering."""

questions: AnnotationList[Question] = annotation_field()
Expand All @@ -30,3 +31,37 @@ class TokenizedExtractiveQADocument(TokenBasedDocument):
answers: AnnotationList[ExtractiveAnswer] = annotation_field(
named_targets={"base": "tokens", "questions": "questions"}
)


# Backwards compatibility: keep the previous class names available as aliases
# of the renamed document types, so existing imports of the old names still work.
ExtractiveQADocument = TextDocumentWithQuestionsAndExtractiveAnswers
TokenizedExtractiveQADocument = TokenDocumentWithQuestionsAndExtractiveAnswers


@dataclasses.dataclass
class TokenDocumentWithLabeledSpans(TokenBasedDocument):
    """A token based PIE document with a `labeled_spans` annotation field."""

    # LabeledSpan annotations; the field is declared with target="tokens".
    labeled_spans: AnnotationList[LabeledSpan] = annotation_field(target="tokens")


@dataclasses.dataclass
class TokenDocumentWithLabeledPartitions(TokenBasedDocument):
    """A token based PIE document with a `labeled_partitions` annotation field."""

    # LabeledSpan annotations; the field is declared with target="tokens".
    labeled_partitions: AnnotationList[LabeledSpan] = annotation_field(target="tokens")


@dataclasses.dataclass
class TokenDocumentWithLabeledSpansAndLabeledPartitions(
    TokenDocumentWithLabeledSpans, TokenDocumentWithLabeledPartitions
):
    """A token based PIE document with `labeled_spans` and `labeled_partitions`.

    Both annotation fields are inherited from the two base classes; this class
    declares no fields of its own.
    """

    pass


@dataclasses.dataclass
class TokenDocumentWithLabeledSpansAndBinaryRelations(TokenDocumentWithLabeledSpans):
    """A token based PIE document with `labeled_spans` and `binary_relations`.

    `labeled_spans` is inherited from TokenDocumentWithLabeledSpans.
    """

    # BinaryRelation annotations; the field is declared with target="labeled_spans".
    binary_relations: AnnotationList[BinaryRelation] = annotation_field(target="labeled_spans")


@dataclasses.dataclass
class TokenDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions(
    TokenDocumentWithLabeledSpansAndBinaryRelations, TokenDocumentWithLabeledPartitions
):
    """A token based PIE document with `labeled_spans`, `binary_relations` and
    `labeled_partitions`.

    All annotation fields are inherited from the two base classes; this class
    declares no fields of its own.
    """

    pass
4 changes: 2 additions & 2 deletions src/pie_modules/metrics/squad_f1.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import pandas as pd
from pytorch_ie.core import DocumentMetric

from pie_modules.documents import ExtractiveQADocument
from pie_modules.documents import TextDocumentWithQuestionsAndExtractiveAnswers

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -42,7 +42,7 @@ def reset(self):
self.has_answer_qids = []
self.no_answer_qids = []

def _update(self, document: ExtractiveQADocument):
def _update(self, document: TextDocumentWithQuestionsAndExtractiveAnswers):
gold_answers_for_questions = defaultdict(list)
predicted_answers_for_questions = defaultdict(list)
for ann in document.answers:
Expand Down
11 changes: 7 additions & 4 deletions src/pie_modules/taskmodules/extractive_question_answering.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,10 @@

from pie_modules.annotations import ExtractiveAnswer, Question
from pie_modules.document.processing import tokenize_document
from pie_modules.documents import ExtractiveQADocument, TokenizedExtractiveQADocument
from pie_modules.documents import (
TextDocumentWithQuestionsAndExtractiveAnswers,
TokenDocumentWithQuestionsAndExtractiveAnswers,
)

logger = logging.getLogger(__name__)

Expand All @@ -29,7 +32,7 @@ class TargetEncoding:


TaskEncodingType: TypeAlias = TaskEncoding[
ExtractiveQADocument,
TextDocumentWithQuestionsAndExtractiveAnswers,
InputEncoding,
TargetEncoding,
]
Expand Down Expand Up @@ -67,7 +70,7 @@ class ExtractiveQuestionAnsweringTaskModule(TaskModule):
tokenize_kwargs: Additional keyword arguments for the tokenizer. Defaults to None.
"""

DOCUMENT_TYPE = ExtractiveQADocument
DOCUMENT_TYPE = TextDocumentWithQuestionsAndExtractiveAnswers

def __init__(
self,
Expand Down Expand Up @@ -124,7 +127,7 @@ def encode_input(
truncation="only_second",
max_length=self.max_length,
return_overflowing_tokens=True,
result_document_type=TokenizedExtractiveQADocument,
result_document_type=TokenDocumentWithQuestionsAndExtractiveAnswers,
strict_span_conversion=False,
verbose=False,
**self.tokenize_kwargs,
Expand Down
18 changes: 5 additions & 13 deletions src/pie_modules/taskmodules/token_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
"""

import copy
import dataclasses
import logging
from typing import (
Any,
Expand All @@ -25,14 +24,13 @@

import torch
import torch.nn.functional as F
from pytorch_ie import AnnotationLayer, annotation_field
from pytorch_ie import AnnotationLayer
from pytorch_ie.annotations import LabeledSpan
from pytorch_ie.core import TaskEncoding, TaskModule
from pytorch_ie.documents import (
TextDocument,
TextDocumentWithLabeledSpans,
TextDocumentWithLabeledSpansAndLabeledPartitions,
TokenBasedDocument,
)
from pytorch_ie.models.transformer_token_classification import (
ModelOutputType,
Expand All @@ -47,6 +45,10 @@
token_based_document_to_text_based,
tokenize_document,
)
from pie_modules.documents import (
TokenDocumentWithLabeledSpans,
TokenDocumentWithLabeledSpansAndLabeledPartitions,
)

DocumentType: TypeAlias = TextDocument

Expand All @@ -71,16 +73,6 @@
logger = logging.getLogger(__name__)


@dataclasses.dataclass
class TokenDocumentWithLabeledSpans(TokenBasedDocument):
labeled_spans: AnnotationLayer[LabeledSpan] = annotation_field(target="tokens")


@dataclasses.dataclass
class TokenDocumentWithLabeledSpansAndLabeledPartitions(TokenDocumentWithLabeledSpans):
labeled_partitions: AnnotationLayer[LabeledSpan] = annotation_field(target="tokens")


@TaskModule.register()
class TokenClassificationTaskModule(TaskModuleType):
"""Taskmodule for span prediction (e.g. NER) as token classification.
Expand Down
18 changes: 10 additions & 8 deletions tests/metrics/test_squad_f1.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import logging

from pie_modules.annotations import ExtractiveAnswer, Question
from pie_modules.documents import ExtractiveQADocument
from pie_modules.documents import TextDocumentWithQuestionsAndExtractiveAnswers
from pie_modules.metrics import SQuADF1


Expand All @@ -10,7 +10,7 @@ def test_squad_f1_exact_match(caplog):

# create a test document
# sample edit
doc = ExtractiveQADocument(text="This is a test document.")
doc = TextDocumentWithQuestionsAndExtractiveAnswers(text="This is a test document.")
# add a question
q1 = Question(text="What is this?")
doc.questions.append(q1)
Expand Down Expand Up @@ -62,7 +62,7 @@ def test_squad_f1_exact_match_added_article():
metric = SQuADF1()

# create a test document
doc = ExtractiveQADocument(
doc = TextDocumentWithQuestionsAndExtractiveAnswers(
text="This is a test document.", id="eqa_doc_with_exact_match_added_article"
)
# add a question
Expand Down Expand Up @@ -100,7 +100,7 @@ def test_squad_f1_partly_span_mismatch():
metric = SQuADF1()

# create a test document
doc = ExtractiveQADocument(
doc = TextDocumentWithQuestionsAndExtractiveAnswers(
text="This is a test document.", id="eqa_doc_with_partly_span_mismatch"
)
# add a question
Expand Down Expand Up @@ -138,7 +138,7 @@ def test_squad_f1_full_span_mismatch():
metric = SQuADF1()

# create a test document
doc = ExtractiveQADocument(
doc = TextDocumentWithQuestionsAndExtractiveAnswers(
text="This is a test document.", id="eqa_doc_with_full_span_mismatch"
)
# add a question
Expand Down Expand Up @@ -176,7 +176,7 @@ def test_squad_f1_no_predicted_answers():
metric = SQuADF1()

# create a test document
doc = ExtractiveQADocument(
doc = TextDocumentWithQuestionsAndExtractiveAnswers(
text="This is a test document.", id="eqa_doc_without_predicted_answers"
)
# add a question
Expand Down Expand Up @@ -209,7 +209,9 @@ def test_squad_f1_no_gold_answers():
metric = SQuADF1()

# create a test document
doc = ExtractiveQADocument(text="This is a test document.", id="eqa_doc_without_gold_answers")
doc = TextDocumentWithQuestionsAndExtractiveAnswers(
text="This is a test document.", id="eqa_doc_without_gold_answers"
)
# add a question
q1 = Question(text="What is this?")
doc.questions.append(q1)
Expand Down Expand Up @@ -240,7 +242,7 @@ def test_squad_f1_empty_document():
metric = SQuADF1()

# create a test document
doc = ExtractiveQADocument(text="", id="eqa_doc_with_empty_text")
doc = TextDocumentWithQuestionsAndExtractiveAnswers(text="", id="eqa_doc_with_empty_text")
# add a question
q1 = Question(text="What is this?")
doc.questions.append(q1)
Expand Down
12 changes: 8 additions & 4 deletions tests/models/test_extractive_question_answering.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from pytorch_lightning import Trainer

from pie_modules.annotations import ExtractiveAnswer, Question
from pie_modules.documents import ExtractiveQADocument
from pie_modules.documents import TextDocumentWithQuestionsAndExtractiveAnswers
from pie_modules.models.simple_extractive_question_answering import (
SimpleExtractiveQuestionAnsweringModel,
)
Expand All @@ -20,15 +20,19 @@

@pytest.fixture
def documents():
document0 = ExtractiveQADocument(text="This is a test document", id="doc0")
document0 = TextDocumentWithQuestionsAndExtractiveAnswers(
text="This is a test document", id="doc0"
)
document0.questions.append(Question(text="What is the first word?"))
document0.answers.append(ExtractiveAnswer(question=document0.questions[0], start=0, end=3))

document1 = ExtractiveQADocument(text="Oranges are orange in color.", id="doc1")
document1 = TextDocumentWithQuestionsAndExtractiveAnswers(
text="Oranges are orange in color.", id="doc1"
)
document1.questions.append(Question(text="What color are oranges?"))
document1.answers.append(ExtractiveAnswer(question=document1.questions[0], start=23, end=27))

document2 = ExtractiveQADocument(
document2 = TextDocumentWithQuestionsAndExtractiveAnswers(
text="This is a test document that has two questions attached to it.", id="doc2"
)
document2.questions.append(Question(text="What type of document is this?"))
Expand Down
16 changes: 11 additions & 5 deletions tests/taskmodules/test_extractive_question_answering.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,17 @@
from pytorch_ie.core import AnnotationList

from pie_modules.annotations import ExtractiveAnswer, Question
from pie_modules.documents import ExtractiveQADocument
from pie_modules.documents import TextDocumentWithQuestionsAndExtractiveAnswers
from pie_modules.taskmodules.extractive_question_answering import (
ExtractiveQuestionAnsweringTaskModule,
)


@pytest.fixture()
def document():
document = ExtractiveQADocument(text="This is a test document", id="doc0")
document = TextDocumentWithQuestionsAndExtractiveAnswers(
text="This is a test document", id="doc0"
)
document.questions.append(Question(text="What is the first word?"))
document.answers.append(ExtractiveAnswer(question=document.questions[0], start=0, end=4))
assert str(document.answers[0]) == "This"
Expand All @@ -21,7 +23,9 @@ def document():

@pytest.fixture()
def document1():
document1 = ExtractiveQADocument(text="This is the second document", id="doc1")
document1 = TextDocumentWithQuestionsAndExtractiveAnswers(
text="This is the second document", id="doc1"
)
document1.questions.append(Question(text="Which document is this?"))
document1.answers.append(ExtractiveAnswer(question=document1.questions[0], start=13, end=18))
assert str(document1.answers[0]) == "second"
Expand All @@ -30,14 +34,16 @@ def document1():

@pytest.fixture()
def document_with_no_answer():
document = ExtractiveQADocument(text="This is a test document", id="document_with_no_answer")
document = TextDocumentWithQuestionsAndExtractiveAnswers(
text="This is a test document", id="document_with_no_answer"
)
document.questions.append(Question(text="What is the first word?"))
return document


@pytest.fixture()
def document_with_multiple_answers():
document = ExtractiveQADocument(
document = TextDocumentWithQuestionsAndExtractiveAnswers(
text="This is a test document", id="document_with_multiple_answers"
)
document.questions.append(Question(text="What is the first word?"))
Expand Down
Loading

0 comments on commit e85a98b

Please sign in to comment.