diff --git a/src/pie_datasets/document/types.py b/src/pie_datasets/document/types.py
deleted file mode 100644
index 0ea32ed5..00000000
--- a/src/pie_datasets/document/types.py
+++ /dev/null
@@ -1,15 +0,0 @@
-import dataclasses
-
-from pytorch_ie.annotations import BinaryRelation, LabeledSpan
-from pytorch_ie.core import AnnotationList, annotation_field
-from pytorch_ie.documents import TokenBasedDocument
-
-
-@dataclasses.dataclass
-class TokenDocumentWithLabeledSpans(TokenBasedDocument):
-    labeled_spans: AnnotationList[LabeledSpan] = annotation_field(target="tokens")
-
-
-@dataclasses.dataclass
-class TokenDocumentWithLabeledSpansAndBinaryRelations(TokenDocumentWithLabeledSpans):
-    binary_relations: AnnotationList[BinaryRelation] = annotation_field(target="labeled_spans")
diff --git a/tests/dataset_builders/common.py b/tests/dataset_builders/common.py
index 70af75a7..e04169c4 100644
--- a/tests/dataset_builders/common.py
+++ b/tests/dataset_builders/common.py
@@ -1,3 +1,4 @@
+import dataclasses
 import json
 import logging
 import os
@@ -5,6 +6,10 @@
 from pathlib import Path
 from typing import List, Optional
 
+from pytorch_ie.annotations import BinaryRelation, LabeledSpan
+from pytorch_ie.core import AnnotationList, annotation_field
+from pytorch_ie.documents import TokenBasedDocument
+
 from tests import FIXTURES_ROOT
 
 DATASET_BUILDER_BASE_PATH = Path("dataset_builders")
@@ -68,3 +73,13 @@ def _load_json(fn: str):
     with open(fn) as f:
         ex = json.load(f)
     return ex
+
+
+@dataclasses.dataclass
+class TestTokenDocumentWithLabeledSpans(TokenBasedDocument):
+    labeled_spans: AnnotationList[LabeledSpan] = annotation_field(target="tokens")
+
+
+@dataclasses.dataclass
+class TestTokenDocumentWithLabeledSpansAndBinaryRelations(TestTokenDocumentWithLabeledSpans):
+    binary_relations: AnnotationList[BinaryRelation] = annotation_field(target="labeled_spans")
diff --git a/tests/dataset_builders/pie/test_cdcp.py b/tests/dataset_builders/pie/test_cdcp.py
index 86e6d3c1..643d7b6f 100644
--- a/tests/dataset_builders/pie/test_cdcp.py
+++ b/tests/dataset_builders/pie/test_cdcp.py
@@ -20,9 +20,12 @@
     example_to_document,
 )
 from pie_datasets import DatasetDict
-from pie_datasets.document.types import TokenDocumentWithLabeledSpansAndBinaryRelations
 from tests import FIXTURES_ROOT
-from tests.dataset_builders.common import PIE_BASE_PATH, _deep_compare
+from tests.dataset_builders.common import (
+    PIE_BASE_PATH,
+    TestTokenDocumentWithLabeledSpansAndBinaryRelations,
+    _deep_compare,
+)
 
 disable_caching()
 
@@ -306,7 +309,7 @@ def tokenizer() -> PreTrainedTokenizer:
 @pytest.fixture(scope="module")
 def tokenized_documents_with_labeled_spans_and_binary_relations(
     dataset_of_text_documents_with_labeled_spans_and_binary_relations, tokenizer
-) -> List[TokenDocumentWithLabeledSpansAndBinaryRelations]:
+) -> List[TestTokenDocumentWithLabeledSpansAndBinaryRelations]:
     # get a document to check
     doc = dataset_of_text_documents_with_labeled_spans_and_binary_relations["train"][0]
     # Note, that this is a list of documents, because the document may be split into chunks
@@ -315,7 +318,7 @@ def tokenized_documents_with_labeled_spans_and_binary_relations(
         doc,
         tokenizer=tokenizer,
         return_overflowing_tokens=True,
-        result_document_type=TokenDocumentWithLabeledSpansAndBinaryRelations,
+        result_document_type=TestTokenDocumentWithLabeledSpansAndBinaryRelations,
         verbose=True,
     )
     return tokenized_docs
@@ -433,7 +436,7 @@ def test_tokenized_documents_with_entities_and_relations_all(
             doc,
             tokenizer=tokenizer,
             return_overflowing_tokens=True,
-            result_document_type=TokenDocumentWithLabeledSpansAndBinaryRelations,
+            result_document_type=TestTokenDocumentWithLabeledSpansAndBinaryRelations,
             verbose=True,
         )
         # we just ensure that we get at least one tokenized document
diff --git a/tests/dataset_builders/pie/test_scidtb_argmin.py b/tests/dataset_builders/pie/test_scidtb_argmin.py
index b7bced97..f26eda63 100644
--- a/tests/dataset_builders/pie/test_scidtb_argmin.py
+++ b/tests/dataset_builders/pie/test_scidtb_argmin.py
@@ -16,9 +16,12 @@
     example_to_document,
 )
 from pie_datasets import DatasetDict
-from pie_datasets.document.types import TokenDocumentWithLabeledSpansAndBinaryRelations
 from tests import FIXTURES_ROOT
-from tests.dataset_builders.common import HF_DS_FIXTURE_DATA_PATH, PIE_BASE_PATH
+from tests.dataset_builders.common import (
+    HF_DS_FIXTURE_DATA_PATH,
+    PIE_BASE_PATH,
+    TestTokenDocumentWithLabeledSpansAndBinaryRelations,
+)
 
 disable_caching()
 
@@ -292,7 +295,7 @@ def tokenizer() -> PreTrainedTokenizer:
 @pytest.fixture(scope="module")
 def tokenized_documents_with_labeled_spans_and_binary_relations(
     dataset_of_text_documents_with_labeled_spans_and_binary_relations, tokenizer
-) -> List[TokenDocumentWithLabeledSpansAndBinaryRelations]:
+) -> List[TestTokenDocumentWithLabeledSpansAndBinaryRelations]:
     # get a document to check
     doc = dataset_of_text_documents_with_labeled_spans_and_binary_relations["train"][0]
     # Note, that this is a list of documents, because the document may be split into chunks
@@ -301,7 +304,7 @@ def tokenized_documents_with_labeled_spans_and_binary_relations(
         doc,
         tokenizer=tokenizer,
         return_overflowing_tokens=True,
-        result_document_type=TokenDocumentWithLabeledSpansAndBinaryRelations,
+        result_document_type=TestTokenDocumentWithLabeledSpansAndBinaryRelations,
         verbose=True,
     )
     return tokenized_docs
@@ -373,7 +376,7 @@ def test_tokenized_documents_with_entities_and_relations_all(
             doc,
             tokenizer=tokenizer,
             return_overflowing_tokens=True,
-            result_document_type=TokenDocumentWithLabeledSpansAndBinaryRelations,
+            result_document_type=TestTokenDocumentWithLabeledSpansAndBinaryRelations,
             verbose=True,
         )
         # we just ensure that we get at least one tokenized document
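
For reference, a minimal sketch (not part of the patch) of how the relocated test-only document types are constructed and populated. It assumes pytorch_ie's dataclass document API as used in the diff (`TokenBasedDocument` taking a `tokens` tuple, annotation lists supporting `append`/`extend`); the tokens, span offsets, and labels below are hypothetical illustration values, not taken from either dataset:

from pytorch_ie.annotations import BinaryRelation, LabeledSpan

from tests.dataset_builders.common import (
    TestTokenDocumentWithLabeledSpansAndBinaryRelations,
)

# A token-based document; `tokens` is the annotation target of the spans.
doc = TestTokenDocumentWithLabeledSpansAndBinaryRelations(
    tokens=("The", "plan", "works", "because", "it", "cuts", "costs", ".")
)

# Spans index into the token sequence (hypothetical offsets and labels).
claim = LabeledSpan(start=0, end=3, label="claim")
premise = LabeledSpan(start=3, end=8, label="premise")
doc.labeled_spans.extend([claim, premise])

# Relations connect spans that are already attached to the document
# (hypothetical relation label).
doc.binary_relations.append(BinaryRelation(head=premise, tail=claim, label="support"))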