Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

implement TaskModuleWithDocumentConverter #114

Merged
merged 9 commits into from
Sep 18, 2024
1 change: 1 addition & 0 deletions src/pie_modules/taskmodules/common/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from .interfaces import AnnotationEncoderDecoder, DecodingException
from .mixins import BatchableMixin
from .taskmodule_with_document_converter import TaskModuleWithDocumentConverter
from .utils import get_first_occurrence_index
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
from abc import ABC, abstractmethod
from typing import Generic, Iterable, Iterator, Optional, Sequence, Type, TypeVar, Union

from pytorch_ie import Document, TaskEncoding, TaskModule
from pytorch_ie.core.taskmodule import (
IterableTaskEncodingDataset,
TaskEncodingDataset,
TaskEncodingSequence,
)
from typing_extensions import TypeAlias

# Generic type parameters shared by the converter taskmodule below.
# DocumentType: the "outer" document type that callers pass in.
DocumentType = TypeVar("DocumentType", bound=Document)
# ConvertedDocumentType: the document type expected by the wrapped taskmodule.
ConvertedDocumentType = TypeVar("ConvertedDocumentType", bound=Document)
InputEncodingType = TypeVar("InputEncodingType")
TargetEncodingType = TypeVar("TargetEncodingType")
# TaskEncoding: defined below
TaskBatchEncodingType = TypeVar("TaskBatchEncodingType")
# ModelBatchEncoding: defined in models
ModelBatchOutputType = TypeVar("ModelBatchOutputType")
TaskOutputType = TypeVar("TaskOutputType")

# Shorthand for a TaskEncoding parameterized with the generic document/encoding types.
TaskEncodingType: TypeAlias = TaskEncoding[
    DocumentType,
    InputEncodingType,
    TargetEncodingType,
]


class TaskModuleWithDocumentConverter(
    TaskModule,
    ABC,
    Generic[
        ConvertedDocumentType,
        DocumentType,
        InputEncodingType,
        TargetEncodingType,
        TaskBatchEncodingType,
        ModelBatchOutputType,
        TaskOutputType,
    ],
):
    """Mixin that lets a taskmodule consume a different document type.

    Incoming documents are converted via :meth:`_convert_document` before they
    reach the wrapped taskmodule. On :meth:`decode`, predictions produced on
    the converted documents are copied back onto the original documents via
    :meth:`_integrate_predictions_from_converted_document`, and the originals
    are returned.
    """

    @property
    def document_type(self) -> Optional[Type[Document]]:
        # If the wrapped taskmodule declares a document type, subclasses must
        # override this property to announce the (different) outer type.
        if super().document_type is None:
            return None
        raise NotImplementedError(f"please overwrite document_type for {type(self).__name__}")

    @abstractmethod
    def _convert_document(self, document: DocumentType) -> ConvertedDocumentType:
        """Convert a document of the taskmodule document type to the expected document
        type of the wrapped taskmodule.

        Args:
            document: the input document

        Returns: the converted document
        """

    def _prepare(self, documents: Sequence[DocumentType]) -> None:
        # convert lazily; the wrapped taskmodule iterates the documents itself
        super()._prepare(documents=map(self._convert_document, documents))

    def convert_document(self, document: DocumentType) -> ConvertedDocumentType:
        """Convert *document* and remember the original under
        ``metadata["original_document"]`` so that :meth:`decode` can restore it."""
        converted = self._convert_document(document)
        if "original_document" in converted.metadata:
            # NOTE(review): "has already and entry" looks like a typo for "an entry",
            # but the wording is kept as-is because tests assert the exact message.
            raise ValueError(
                f"metadata of converted_document has already and entry 'original_document', "
                f"this is not allowed. Please adjust '{type(self).__name__}._convert_document()' "
                f"to produce documents without that key in metadata."
            )
        converted.metadata["original_document"] = document
        return converted

    def encode(
        self, documents: Union[DocumentType, Iterable[DocumentType]], **kwargs
    ) -> Union[
        Sequence[TaskEncodingType],
        TaskEncodingSequence[TaskEncodingType, DocumentType],
        Iterator[TaskEncodingType],
        TaskEncodingDataset[TaskEncodingType],
        IterableTaskEncodingDataset[TaskEncodingType],
    ]:
        """Convert the input document(s) and delegate encoding to the wrapped taskmodule."""
        to_encode: Union[DocumentType, Iterable[DocumentType]]
        if isinstance(documents, Document):
            # a single document is converted directly
            to_encode = self.convert_document(documents)
        else:
            # an iterable is materialized into a list of converted documents
            to_encode = list(map(self.convert_document, documents))
        return super().encode(documents=to_encode, **kwargs)

    def decode(self, **kwargs) -> Sequence[DocumentType]:
        """Decode via the wrapped taskmodule, move the predictions back onto the
        original documents, and return those originals."""
        originals = []
        for converted in super().decode(**kwargs):
            original = converted.metadata["original_document"]
            self._integrate_predictions_from_converted_document(
                converted_document=converted, document=original
            )
            originals.append(original)
        return originals

    @abstractmethod
    def _integrate_predictions_from_converted_document(
        self,
        document: DocumentType,
        converted_document: ConvertedDocumentType,
    ) -> None:
        """Convert the predictions at the respective layers of the converted_document
        and add them to the original document predictions.

        Args:
            document: document to attach the converted predictions to
            converted_document: the document returned by the wrapped taskmodule,
                including predictions
        """
163 changes: 163 additions & 0 deletions tests/taskmodules/common/test_taskmodule_with_document_converter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,163 @@
from typing import Optional, Type

import pytest
from pytorch_ie import Document
from pytorch_ie.documents import TextDocumentWithLabeledSpansAndBinaryRelations
from typing_extensions import TypeAlias

from pie_modules.taskmodules import RETextClassificationWithIndicesTaskModule
from pie_modules.taskmodules.common import TaskModuleWithDocumentConverter
from tests.conftest import TestDocument

DocumentType: TypeAlias = TestDocument
ConvertedDocumentType: TypeAlias = TextDocumentWithLabeledSpansAndBinaryRelations


class MyRETaskModuleWithDocConverter(
    TaskModuleWithDocumentConverter, RETextClassificationWithIndicesTaskModule
):
    """RE taskmodule that accepts ``TestDocument`` inputs by converting them to
    ``TextDocumentWithLabeledSpansAndBinaryRelations`` and back."""

    @property
    def document_type(self) -> Optional[Type[Document]]:
        return TestDocument

    def _convert_document(self, document: DocumentType) -> ConvertedDocumentType:
        # rename the annotation layers to the ones the wrapped taskmodule expects
        converted = document.as_type(
            TextDocumentWithLabeledSpansAndBinaryRelations,
            field_mapping={"entities": "labeled_spans", "relations": "binary_relations"},
        )
        # remember which converted span corresponds to which original entity
        # (as_type keeps the layer order, so zip pairs them up correctly)
        converted.metadata["new2old_span"] = dict(
            zip(converted.labeled_spans, document.entities)
        )
        return converted

    def _integrate_predictions_from_converted_document(
        self, document: DocumentType, converted_document: ConvertedDocumentType
    ) -> None:
        span_mapping = converted_document.metadata["new2old_span"]
        # re-attach each predicted relation to the original entity annotations
        for predicted_rel in converted_document.binary_relations.predictions:
            document.relations.predictions.append(
                predicted_rel.copy(
                    head=span_mapping[predicted_rel.head],
                    tail=span_mapping[predicted_rel.tail],
                )
            )


@pytest.fixture(scope="module")
def taskmodule(documents):
    """A converter taskmodule prepared on the shared test documents."""
    tm = MyRETaskModuleWithDocConverter(tokenizer_name_or_path="bert-base-cased")
    tm.prepare(documents)
    return tm


def test_taskmodule(taskmodule):
    """The prepared taskmodule exists and reports the expected document type."""
    assert taskmodule is not None
    assert TestDocument == taskmodule.document_type


@pytest.fixture(scope="module")
def task_encodings(taskmodule, documents):
    """Task encodings (with targets) for all test documents."""
    encodings = taskmodule.encode(documents, encode_target=True)
    return encodings


def test_task_encodings(task_encodings):
    """Encoding the test documents yields one encoding per candidate relation pair."""
    expected_num_encodings = 7
    assert len(task_encodings) == expected_num_encodings


def test_decode(taskmodule, task_encodings):
    """Decoding attaches predicted relations to the *original* documents while
    leaving the gold relations untouched."""
    # simulated model output: one predicted label (+ probability) per encoding
    predicted_label_indices = [0, 1, 3, 0, 0, 2, 0]
    predicted_probabilities = [0.1738, 0.6643, 0.2101, 0.0801, 0.0319, 0.81, 0.3079]
    task_outputs = []
    for label_idx, prob in zip(predicted_label_indices, predicted_probabilities):
        task_outputs.append(
            {"labels": [taskmodule.id_to_label[label_idx]], "probabilities": [prob]}
        )

    docs_with_predictions = taskmodule.decode(
        task_encodings=task_encodings, task_outputs=task_outputs
    )
    # decode() must return the original document type, not the converted one
    assert all(isinstance(doc, TestDocument) for doc in docs_with_predictions)

    # gold relations are unchanged by decoding
    all_gold_relations = [doc.relations.resolve() for doc in docs_with_predictions]
    assert all_gold_relations == [
        [("per:employee_of", (("PER", "Entity A"), ("ORG", "B")))],
        [
            ("per:employee_of", (("PER", "Entity G"), ("ORG", "H"))),
            ("per:founder", (("PER", "Entity G"), ("ORG", "I"))),
            ("org:founded_by", (("ORG", "I"), ("ORG", "H"))),
        ],
        [
            ("per:employee_of", (("PER", "Entity M"), ("ORG", "N"))),
            ("per:founder", (("PER", "it"), ("ORG", "O"))),
            ("org:founded_by", (("ORG", "O"), ("PER", "it"))),
        ],
    ]

    # predictions follow the simulated label indices above
    all_predicted_relations = [
        doc.relations.predictions.resolve() for doc in docs_with_predictions
    ]
    assert all_predicted_relations == [
        [("no_relation", (("PER", "Entity A"), ("ORG", "B")))],
        [
            ("org:founded_by", (("PER", "Entity G"), ("ORG", "H"))),
            ("per:founder", (("PER", "Entity G"), ("ORG", "I"))),
            ("no_relation", (("ORG", "I"), ("ORG", "H"))),
        ],
        [
            ("no_relation", (("PER", "Entity M"), ("ORG", "N"))),
            ("per:employee_of", (("PER", "it"), ("ORG", "O"))),
            ("no_relation", (("ORG", "O"), ("PER", "it"))),
        ],
    ]


class MyRETaskModuleWithDocConverterWithoutDocType(
    TaskModuleWithDocumentConverter, RETextClassificationWithIndicesTaskModule
):
    # Test stub: deliberately does NOT override document_type, so accessing it
    # must raise NotImplementedError (see test_missing_document_type_overwrite).

    def _convert_document(self, document: DocumentType) -> ConvertedDocumentType:
        """No-op stub; never called in the test."""
        pass

    def _integrate_predictions_from_converted_document(
        self, document: DocumentType, converted_document: ConvertedDocumentType
    ) -> None:
        """No-op stub; never called in the test."""
        pass


def test_missing_document_type_overwrite():
    """A subclass that does not override document_type raises on access."""
    taskmodule = MyRETaskModuleWithDocConverterWithoutDocType(
        tokenizer_name_or_path="bert-base-cased"
    )

    with pytest.raises(NotImplementedError) as excinfo:
        taskmodule.document_type
    expected_message = (
        "please overwrite document_type for MyRETaskModuleWithDocConverterWithoutDocType"
    )
    assert str(excinfo.value) == expected_message


class MyRETaskModuleWithWrongDocConverter(
    TaskModuleWithDocumentConverter, RETextClassificationWithIndicesTaskModule
):
    """Test stub whose converter illegally sets the reserved metadata key
    'original_document' itself, which encode() must reject."""

    @property
    def document_type(self) -> Optional[Type[Document]]:
        return TestDocument

    def _convert_document(self, document: DocumentType) -> ConvertedDocumentType:
        converted = TextDocumentWithLabeledSpansAndBinaryRelations(text="dummy")
        # deliberately occupy the reserved key to trigger the ValueError
        converted.metadata["original_document"] = None
        return converted

    def _integrate_predictions_from_converted_document(
        self, document: DocumentType, converted_document: ConvertedDocumentType
    ) -> None:
        """No-op stub; never reached in the test."""
        pass


def test_wrong_doc_converter(documents):
    """encode() refuses converted documents that already carry 'original_document'."""
    taskmodule = MyRETaskModuleWithWrongDocConverter(tokenizer_name_or_path="bert-base-cased")
    taskmodule.prepare(documents)

    with pytest.raises(ValueError) as excinfo:
        taskmodule.encode(documents, encode_target=True)
    # NOTE(review): "has already and entry" mirrors a typo in the production
    # message; both should be fixed together ("an entry") in a follow-up.
    expected_message = (
        "metadata of converted_document has already and entry 'original_document', "
        "this is not allowed. Please adjust "
        "'MyRETaskModuleWithWrongDocConverter._convert_document()' to produce "
        "documents without that key in metadata."
    )
    assert str(excinfo.value) == expected_message
Loading