diff --git a/src/pie_modules/taskmodules/common/taskmodule_with_document_converter.py b/src/pie_modules/taskmodules/common/taskmodule_with_document_converter.py index 39c1dfdb2..6cb5b332c 100644 --- a/src/pie_modules/taskmodules/common/taskmodule_with_document_converter.py +++ b/src/pie_modules/taskmodules/common/taskmodule_with_document_converter.py @@ -1,5 +1,5 @@ from abc import ABC, abstractmethod -from typing import Generic, Iterable, Iterator, Optional, Sequence, TypeVar, Union +from typing import Generic, Iterable, Iterator, Optional, Sequence, Type, TypeVar, Union from pytorch_ie import Document, TaskEncoding, TaskModule from pytorch_ie.core.taskmodule import ( @@ -7,7 +7,6 @@ TaskEncodingDataset, TaskEncodingSequence, ) -from pytorch_ie.documents import TextDocumentWithLabeledSpansAndBinaryRelations from typing_extensions import TypeAlias DocumentType = TypeVar("DocumentType", bound=Document) @@ -40,6 +39,13 @@ class TaskModuleWithDocumentConverter( TaskOutputType, ], ): + @property + def document_type(self) -> Optional[Type[Document]]: + if super().document_type is not None: + raise NotImplementedError(f"please overwrite document_type for {type(self).__name__}") + else: + return None + @abstractmethod def _convert_document(self, document: DocumentType) -> ConvertedDocumentType: """Convert a document of the taskmodule document type to the expected document type of the @@ -100,7 +106,7 @@ def _integrate_predictions_from_converted_document( self, document: DocumentType, converted_document: ConvertedDocumentType, - ) -> TextDocumentWithLabeledSpansAndBinaryRelations: + ) -> None: """Convert the predictions at the respective layers of the converted_document and add them to the original document predictions.