Skip to content

Commit

Permalink
add requirement to overwrite document_type if super method returns a …
Browse files Browse the repository at this point in the history
…type
  • Loading branch information
ArneBinder committed Sep 18, 2024
1 parent b14d034 commit b531c05
Showing 1 changed file with 9 additions and 3 deletions.
Original file line number Diff line number Diff line change
@@ -1,13 +1,12 @@
from abc import ABC, abstractmethod
from typing import Generic, Iterable, Iterator, Optional, Sequence, TypeVar, Union
from typing import Generic, Iterable, Iterator, Optional, Sequence, Type, TypeVar, Union

from pytorch_ie import Document, TaskEncoding, TaskModule
from pytorch_ie.core.taskmodule import (
IterableTaskEncodingDataset,
TaskEncodingDataset,
TaskEncodingSequence,
)
from pytorch_ie.documents import TextDocumentWithLabeledSpansAndBinaryRelations
from typing_extensions import TypeAlias

DocumentType = TypeVar("DocumentType", bound=Document)
Expand Down Expand Up @@ -40,6 +39,13 @@ class TaskModuleWithDocumentConverter(
TaskOutputType,
],
):
@property
def document_type(self) -> Optional[Type[Document]]:
if super().document_type is not None:
raise NotImplementedError(f"please overwrite document_type for {type(self).__name__}")
else:
return None

@abstractmethod
def _convert_document(self, document: DocumentType) -> ConvertedDocumentType:
"""Convert a document of the taskmodule document type to the expected document type of the
Expand Down Expand Up @@ -100,7 +106,7 @@ def _integrate_predictions_from_converted_document(
self,
document: DocumentType,
converted_document: ConvertedDocumentType,
) -> TextDocumentWithLabeledSpansAndBinaryRelations:
) -> None:
"""Convert the predictions at the respective layers of the converted_document and add them
to the original document predictions.
Expand Down

0 comments on commit b531c05

Please sign in to comment.