Skip to content

Commit

Permalink
Merge pull request #41 from ArneBinder/exportable_brat_document_types
Browse files Browse the repository at this point in the history
make BRAT document types exportable
  • Loading branch information
ArneBinder authored Nov 8, 2023
2 parents 2f91405 + 1d508de commit 49c8c29
Show file tree
Hide file tree
Showing 3 changed files with 43 additions and 31 deletions.
35 changes: 7 additions & 28 deletions dataset_builders/pie/brat/brat.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,17 @@
import dataclasses
import logging
from collections import defaultdict
from typing import Any, Dict, List, Optional, Tuple, Union

import datasets
from pytorch_ie.annotations import BinaryRelation, LabeledMultiSpan, LabeledSpan
from pytorch_ie.core import Annotation, AnnotationList, annotation_field
from pytorch_ie.documents import TextBasedDocument
from pytorch_ie.core import Annotation

from pie_datasets import GeneratorBasedBuilder
from pie_datasets.document.types import (
Attribute,
BratDocument,
BratDocumentWithMergedSpans,
)

logger = logging.getLogger(__name__)

Expand All @@ -24,33 +27,9 @@ def ld2dl(
return {k: [dic[k] for dic in list_fo_dicts] for k in keys}


@dataclasses.dataclass(eq=True, frozen=True)
class Attribute(Annotation):
annotation: Annotation
label: str
value: Optional[str] = None
score: Optional[float] = dataclasses.field(default=None, compare=False)


@dataclasses.dataclass
class BratDocument(TextBasedDocument):
spans: AnnotationList[LabeledMultiSpan] = annotation_field(target="text")
relations: AnnotationList[BinaryRelation] = annotation_field(target="spans")
span_attributes: AnnotationList[Attribute] = annotation_field(target="spans")
relation_attributes: AnnotationList[Attribute] = annotation_field(target="relations")


@dataclasses.dataclass
class BratDocumentWithMergedSpans(TextBasedDocument):
spans: AnnotationList[LabeledSpan] = annotation_field(target="text")
relations: AnnotationList[BinaryRelation] = annotation_field(target="spans")
span_attributes: AnnotationList[Attribute] = annotation_field(target="spans")
relation_attributes: AnnotationList[Attribute] = annotation_field(target="relations")


def example_to_document(
example: Dict[str, Any], merge_fragmented_spans: bool = False
) -> BratDocument:
) -> Union[BratDocument, BratDocumentWithMergedSpans]:
if merge_fragmented_spans:
doc = BratDocumentWithMergedSpans(text=example["context"], id=example["file_name"])
else:
Expand Down
30 changes: 30 additions & 0 deletions src/pie_datasets/document/types.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
import dataclasses
from typing import Optional

from pytorch_ie.annotations import BinaryRelation, LabeledMultiSpan, LabeledSpan
from pytorch_ie.core import Annotation, AnnotationList, annotation_field
from pytorch_ie.documents import TextBasedDocument


@dataclasses.dataclass(eq=True, frozen=True)
class Attribute(Annotation):
annotation: Annotation
label: str
value: Optional[str] = None
score: Optional[float] = dataclasses.field(default=None, compare=False)


@dataclasses.dataclass
class BratDocument(TextBasedDocument):
spans: AnnotationList[LabeledMultiSpan] = annotation_field(target="text")
relations: AnnotationList[BinaryRelation] = annotation_field(target="spans")
span_attributes: AnnotationList[Attribute] = annotation_field(target="spans")
relation_attributes: AnnotationList[Attribute] = annotation_field(target="relations")


@dataclasses.dataclass
class BratDocumentWithMergedSpans(TextBasedDocument):
spans: AnnotationList[LabeledSpan] = annotation_field(target="text")
relations: AnnotationList[BinaryRelation] = annotation_field(target="spans")
span_attributes: AnnotationList[Attribute] = annotation_field(target="spans")
relation_attributes: AnnotationList[Attribute] = annotation_field(target="relations")
9 changes: 6 additions & 3 deletions tests/dataset_builders/pie/test_brat.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,14 @@

from dataset_builders.pie.brat.brat import (
BratDatasetLoader,
BratDocument,
BratDocumentWithMergedSpans,
document_to_example,
example_to_document,
)
from pie_datasets.document.types import (
Attribute,
BratDocument,
BratDocumentWithMergedSpans,
)
from tests.dataset_builders.common import PIE_BASE_PATH, PIE_DS_FIXTURE_DATA_PATH

datasets.disable_caching()
Expand Down Expand Up @@ -40,7 +43,7 @@ def resolve_annotation(annotation: Annotation) -> Any:
annotation.label,
resolve_annotation(annotation.tail),
)
elif isinstance(annotation, Annotation) and str(type(annotation)).endswith("brat.Attribute'>"):
elif isinstance(annotation, Attribute):
result = (resolve_annotation(annotation.annotation), annotation.label)
if annotation.value is not None:
return result + (annotation.value,)
Expand Down

0 comments on commit 49c8c29

Please sign in to comment.