Commit

Merge pull request #81 from ArneBinder/remove-document-types
move document types to `tests.dataset_builders.common`
ArneBinder authored Nov 27, 2023
2 parents 0c2813e + 931122d commit 4310e2f
Showing 4 changed files with 31 additions and 25 deletions.
15 changes: 0 additions & 15 deletions src/pie_datasets/document/types.py

This file was deleted.

15 changes: 15 additions & 0 deletions tests/dataset_builders/common.py
@@ -1,10 +1,15 @@
+import dataclasses
 import json
 import logging
 import os
 import re
 from pathlib import Path
 from typing import List, Optional
 
+from pytorch_ie.annotations import BinaryRelation, LabeledSpan
+from pytorch_ie.core import AnnotationList, annotation_field
+from pytorch_ie.documents import TokenBasedDocument
+
 from tests import FIXTURES_ROOT
 
 DATASET_BUILDER_BASE_PATH = Path("dataset_builders")
@@ -68,3 +73,13 @@ def _load_json(fn: str):
     with open(fn) as f:
         ex = json.load(f)
     return ex
+
+
+@dataclasses.dataclass
+class TestTokenDocumentWithLabeledSpans(TokenBasedDocument):
+    labeled_spans: AnnotationList[LabeledSpan] = annotation_field(target="tokens")
+
+
+@dataclasses.dataclass
+class TestTokenDocumentWithLabeledSpansAndBinaryRelations(TestTokenDocumentWithLabeledSpans):
+    binary_relations: AnnotationList[BinaryRelation] = annotation_field(target="labeled_spans")
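
For orientation, these test-only document types behave like any other pytorch-ie token-based document. Below is a minimal usage sketch; the tokens, labels, and offsets are invented for illustration and do not come from this PR:

from pytorch_ie.annotations import BinaryRelation, LabeledSpan

from tests.dataset_builders.common import (
    TestTokenDocumentWithLabeledSpansAndBinaryRelations,
)

# A token-based document is constructed from a tuple of token strings.
doc = TestTokenDocumentWithLabeledSpansAndBinaryRelations(
    tokens=("The", "law", "is", "unfair", "because", "fees", "are", "hidden", ".")
)

# Spans target the "tokens" layer: token offsets, end-exclusive.
claim = LabeledSpan(start=0, end=4, label="claim")
premise = LabeledSpan(start=5, end=8, label="premise")
doc.labeled_spans.append(claim)
doc.labeled_spans.append(premise)

# Relations target the "labeled_spans" layer and reference the span objects.
doc.binary_relations.append(BinaryRelation(head=premise, tail=claim, label="reason"))

# Spans resolve back to the underlying tokens via their offsets.
assert doc.tokens[claim.start : claim.end] == ("The", "law", "is", "unfair")

Moving these classes out of src/pie_datasets/document/types.py and into tests.dataset_builders.common keeps test-only document types out of the library's public API, which is the point of this PR.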
13 changes: 8 additions & 5 deletions tests/dataset_builders/pie/test_cdcp.py
@@ -20,9 +20,12 @@
     example_to_document,
 )
 from pie_datasets import DatasetDict
-from pie_datasets.document.types import TokenDocumentWithLabeledSpansAndBinaryRelations
 from tests import FIXTURES_ROOT
-from tests.dataset_builders.common import PIE_BASE_PATH, _deep_compare
+from tests.dataset_builders.common import (
+    PIE_BASE_PATH,
+    TestTokenDocumentWithLabeledSpansAndBinaryRelations,
+    _deep_compare,
+)
 
 disable_caching()

@@ -306,7 +309,7 @@ def tokenizer() -> PreTrainedTokenizer:
 @pytest.fixture(scope="module")
 def tokenized_documents_with_labeled_spans_and_binary_relations(
     dataset_of_text_documents_with_labeled_spans_and_binary_relations, tokenizer
-) -> List[TokenDocumentWithLabeledSpansAndBinaryRelations]:
+) -> List[TestTokenDocumentWithLabeledSpansAndBinaryRelations]:
     # get a document to check
     doc = dataset_of_text_documents_with_labeled_spans_and_binary_relations["train"][0]
     # Note, that this is a list of documents, because the document may be split into chunks
@@ -315,7 +318,7 @@ def tokenized_documents_with_labeled_spans_and_binary_relations(
         doc,
         tokenizer=tokenizer,
         return_overflowing_tokens=True,
-        result_document_type=TokenDocumentWithLabeledSpansAndBinaryRelations,
+        result_document_type=TestTokenDocumentWithLabeledSpansAndBinaryRelations,
         verbose=True,
     )
     return tokenized_docs
@@ -433,7 +436,7 @@ def test_tokenized_documents_with_entities_and_relations_all(
         doc,
         tokenizer=tokenizer,
         return_overflowing_tokens=True,
-        result_document_type=TokenDocumentWithLabeledSpansAndBinaryRelations,
+        result_document_type=TestTokenDocumentWithLabeledSpansAndBinaryRelations,
         verbose=True,
     )
     # we just ensure that we get at least one tokenized document
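
The fixture's comment about receiving a list of documents comes from the tokenizer's overflow handling: with return_overflowing_tokens=True, a document longer than the model's maximum length is split into several chunks, and each chunk becomes its own token-based document. A standalone sketch of the underlying Hugging Face behaviour follows; the model name, text, and max_length are illustrative only:

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")

# With return_overflowing_tokens=True, input that exceeds max_length is
# split into several encodings instead of being truncated to a single one.
encoding = tokenizer(
    "a deliberately long piece of text " * 50,
    max_length=32,
    truncation=True,
    return_overflowing_tokens=True,
)

# One entry per chunk; this is why the fixture returns a *list* of
# tokenized documents for a single input document.
print(len(encoding["input_ids"]))  # prints a number > 1 for this input

The tokenization helper called in these fixtures (its name is collapsed in the diff above) additionally maps the labeled_spans and binary_relations annotations onto each chunk's token offsets; the result_document_type argument selects the document type for those chunk documents, which is why it now points to the test-local class.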
13 changes: 8 additions & 5 deletions tests/dataset_builders/pie/test_scidtb_argmin.py
@@ -16,9 +16,12 @@
     example_to_document,
 )
 from pie_datasets import DatasetDict
-from pie_datasets.document.types import TokenDocumentWithLabeledSpansAndBinaryRelations
 from tests import FIXTURES_ROOT
-from tests.dataset_builders.common import HF_DS_FIXTURE_DATA_PATH, PIE_BASE_PATH
+from tests.dataset_builders.common import (
+    HF_DS_FIXTURE_DATA_PATH,
+    PIE_BASE_PATH,
+    TestTokenDocumentWithLabeledSpansAndBinaryRelations,
+)
 
 disable_caching()

@@ -292,7 +295,7 @@ def tokenizer() -> PreTrainedTokenizer:
 @pytest.fixture(scope="module")
 def tokenized_documents_with_labeled_spans_and_binary_relations(
     dataset_of_text_documents_with_labeled_spans_and_binary_relations, tokenizer
-) -> List[TokenDocumentWithLabeledSpansAndBinaryRelations]:
+) -> List[TestTokenDocumentWithLabeledSpansAndBinaryRelations]:
     # get a document to check
     doc = dataset_of_text_documents_with_labeled_spans_and_binary_relations["train"][0]
     # Note, that this is a list of documents, because the document may be split into chunks
@@ -301,7 +304,7 @@ def tokenized_documents_with_labeled_spans_and_binary_relations(
         doc,
         tokenizer=tokenizer,
         return_overflowing_tokens=True,
-        result_document_type=TokenDocumentWithLabeledSpansAndBinaryRelations,
+        result_document_type=TestTokenDocumentWithLabeledSpansAndBinaryRelations,
         verbose=True,
     )
     return tokenized_docs
@@ -373,7 +376,7 @@ def test_tokenized_documents_with_entities_and_relations_all(
         doc,
         tokenizer=tokenizer,
         return_overflowing_tokens=True,
-        result_document_type=TokenDocumentWithLabeledSpansAndBinaryRelations,
+        result_document_type=TestTokenDocumentWithLabeledSpansAndBinaryRelations,
         verbose=True,
     )
     # we just ensure that we get at least one tokenized document
