Skip to content

Commit

Permalink
generalize
Browse files Browse the repository at this point in the history
  • Loading branch information
ArneBinder committed Feb 9, 2024
1 parent ec19386 commit 0afff63
Show file tree
Hide file tree
Showing 2 changed files with 67 additions and 106 deletions.
113 changes: 49 additions & 64 deletions src/serializer/brat.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,7 @@
from pie_datasets.core.dataset_dict import METADATA_FILE_NAME
from pie_modules.annotations import BinaryRelation, LabeledMultiSpan, LabeledSpan
from pie_modules.documents import TextBasedDocument
from pytorch_ie import AnnotationLayer
from pytorch_ie.core import Document
from pytorch_ie.core.document import Annotation, BaseAnnotationList
from pytorch_ie.core import Annotation, AnnotationLayer, Document
from pytorch_ie.utils.hydra import serialize_document_type

from src.serializer.interface import DocumentSerializer
Expand Down Expand Up @@ -126,8 +124,9 @@ def serialize_annotations(
annotation2id: Dict[Annotation, str],
label_prefix: Optional[str] = None,
first_idx: int = 0,
) -> List[str]:
) -> Tuple[List[str], Dict[Annotation, str]]:
serialized_annotations = []
new_annotation2id: Dict[Annotation, str] = {}
for i, annotation in enumerate(annotations, start=first_idx):
annotation_id, serialized_annotation = serialize_annotation(
idx=i,
Expand All @@ -136,22 +135,20 @@ def serialize_annotations(
label_prefix=label_prefix,
)
serialized_annotations.append(f"{annotation_id}\t{serialized_annotation}")
annotation2id[annotation] = annotation_id
new_annotation2id[annotation] = annotation_id

return serialized_annotations
return serialized_annotations, new_annotation2id


def serialize_annotation_layers(
labeled_span_layer: AnnotationLayer,
binary_relation_layer: AnnotationLayer,
layers: List[AnnotationLayer],
gold_label_prefix: Optional[str] = None,
prediction_label_prefix: Optional[str] = None,
) -> List[str]:
"""Serialize annotations including labeled spans and binary relations into a list of strings.
"""Serialize annotations from given annotation layers into a list of strings.
Args:
labeled_span_layer (AnnotationLayer): Annotation layer containing labeled spans.
binary_relation_layer (AnnotationLayer): Annotation layer containing binary relations.
layers (List[AnnotationLayer]): Annotation layers to be serialized.
gold_label_prefix (Optional[str], optional): Prefix to be added to gold labels.
Defaults to None.
prediction_label_prefix (Optional[str], optional): Prefix to be added to prediction labels.
Expand All @@ -160,70 +157,61 @@ def serialize_annotation_layers(
Returns:
List[str]: List of serialized annotations.
"""
serialized_labeled_spans = []
serialized_binary_relations = []
annotation2id: Dict[Annotation, str] = {}
if gold_label_prefix is not None:
serialized_labeled_spans_gold = serialize_annotations(
annotations=labeled_span_layer,
label_prefix=gold_label_prefix,
annotation2id=annotation2id,
)
serialized_labeled_spans.extend(serialized_labeled_spans_gold)
serialized_binary_relations_gold = serialize_annotations(
annotations=binary_relation_layer,
annotation2id=annotation2id,
label_prefix=gold_label_prefix,
all_serialized_annotations = []
gold_annotation2id: Dict[Annotation, str] = {}
prediction_annotation2id: Dict[Annotation, str] = {}
for layer in layers:
serialized_annotations = []
if gold_label_prefix is not None:
serialized_gold_annotations, new_gold_ann2id = serialize_annotations(
annotations=layer,
label_prefix=gold_label_prefix,
# gold annotations can only reference other gold annotations
annotation2id=gold_annotation2id,
)
serialized_annotations.extend(serialized_gold_annotations)
gold_annotation2id.update(new_gold_ann2id)
serialized_predicted_annotations, new_pred_ann2id = serialize_annotations(
annotations=layer.predictions,
label_prefix=prediction_label_prefix,
first_idx=len(serialized_annotations),
# predicted annotations can reference both gold and predicted annotations
annotation2id={**gold_annotation2id, **prediction_annotation2id},
)
serialized_binary_relations.extend(serialized_binary_relations_gold)
else:
annotation2id = {}
serialized_labeled_spans_pred = serialize_annotations(
annotations=labeled_span_layer.predictions,
label_prefix=prediction_label_prefix,
first_idx=len(serialized_labeled_spans),
annotation2id=annotation2id,
)
serialized_labeled_spans.extend(serialized_labeled_spans_pred)
serialized_binary_relations_pred = serialize_annotations(
annotations=binary_relation_layer.predictions,
annotation2id=annotation2id,
label_prefix=prediction_label_prefix,
first_idx=len(serialized_binary_relations),
)
serialized_binary_relations.extend(serialized_binary_relations_pred)
return serialized_labeled_spans + serialized_binary_relations
prediction_annotation2id.update(new_pred_ann2id)
serialized_annotations.extend(serialized_predicted_annotations)
all_serialized_annotations.extend(serialized_annotations)
return all_serialized_annotations


class BratSerializer(DocumentSerializer):
"""BratSerializer serialize documents into the Brat format. It requires "entity_layer" and
"relation_layer" parameters which defines the entity and relation annotation layer names. If
document processor is provided then these parameters must align with the respective entity and
relation annotation layer of resulting document from the document processor. BratSerializer
additionally provides the functionality to include both gold and predicted annotations in the
resulting annotation file, with the option to differentiate them using the label_prefix.
"""BratSerializer serialize documents into the Brat format. It requires a "layers" parameter to
specify the annotation layers to serialize.
If a gold_label_prefix is provided, the gold annotations are serialized with the given prefix.
Otherwise, only the predicted annotations are serialized. A document_processor can be provided
to process documents before serialization.
Attributes:
entity_layer: The name of the entity annotation layer.
relation_layer: The name of the relation annotation layer.
layers: The names of the annotation layers to serialize.
document_processor: A function or callable object to process documents before serialization.
prediction_label_prefix: An optional prefix for labels in predicted annotations.
gold_label_prefix: An optional prefix for labels in gold annotations.
gold_label_prefix: If provided, gold annotations are serialized and its labels are prefixed
with the given string. Otherwise, only predicted annotations are serialized.
prediction_label_prefix: If provided, labels of predicted annotations are prefixed with the
given string.
default_kwargs: Additional keyword arguments to be used as defaults during serialization.
"""

def __init__(
self,
entity_layer,
relation_layer,
layers: List[str],
document_processor=None,
prediction_label_prefix=None,
gold_label_prefix=None,
**kwargs,
):
self.document_processor = document_processor
self.entity_layer = entity_layer
self.relation_layer = relation_layer
self.layers = layers
self.prediction_label_prefix = prediction_label_prefix
self.gold_label_prefix = gold_label_prefix
self.default_kwargs = kwargs
Expand All @@ -233,8 +221,7 @@ def __call__(self, documents: Sequence[Document], **kwargs) -> Dict[str, str]:
documents = list(map(self.document_processor, documents))
return self.write_with_defaults(
documents=documents,
entity_layer=self.entity_layer,
relation_layer=self.relation_layer,
layers=self.layers,
prediction_label_prefix=self.prediction_label_prefix,
gold_label_prefix=self.gold_label_prefix,
**kwargs,
Expand All @@ -248,8 +235,7 @@ def write_with_defaults(self, **kwargs) -> Dict[str, str]:
def write(
cls,
documents: Sequence[Document],
entity_layer: str,
relation_layer: str,
layers: List[str],
path: str,
metadata_file_name: str = METADATA_FILE_NAME,
split: Optional[str] = None,
Expand All @@ -272,7 +258,7 @@ def write(
os.makedirs(realpath, exist_ok=True)
metadata_text = defaultdict(str)
for i, doc in enumerate(documents):
doc_id = doc.id or f"doc_{i}"
doc_id = getattr(doc, "id", None) or f"doc_{i}"
if not isinstance(doc, TextBasedDocument):
raise TypeError(
f"Document {doc_id} has unexpected type: {type(doc)}. "
Expand All @@ -282,8 +268,7 @@ def write(
metadata_text[f"{file_name}"] = doc.text
ann_path = os.path.join(realpath, file_name)
serialized_annotations = serialize_annotation_layers(
labeled_span_layer=doc[entity_layer],
binary_relation_layer=doc[relation_layer],
layers=[doc[layer] for layer in layers],
gold_label_prefix=gold_label_prefix,
prediction_label_prefix=prediction_label_prefix,
)
Expand Down
60 changes: 18 additions & 42 deletions tests/unit/serializer/test_brat.py
Original file line number Diff line number Diff line change
@@ -1,34 +1,29 @@
import dataclasses
import os
from dataclasses import dataclass
from typing import TypeVar

import pytest
from pie_modules.annotations import BinaryRelation, LabeledMultiSpan, LabeledSpan
from pie_modules.documents import (
TextBasedDocument,
TextDocumentWithLabeledSpansAndBinaryRelations,
TokenDocumentWithLabeledSpansAndBinaryRelations,
)
from pytorch_ie import Annotation, AnnotationLayer
from pytorch_ie.core import Document, annotation_field
from pie_modules.documents import TextBasedDocument
from pytorch_ie import Annotation, AnnotationLayer, Document
from pytorch_ie.core import annotation_field

from src.serializer import BratSerializer
from src.serializer.brat import (
serialize_annotation,
serialize_annotation_layers,
serialize_binary_relation,
serialize_labeled_multi_span,
serialize_labeled_span,
)
from src.utils import get_pylogger

log = get_pylogger(__name__)
D = TypeVar("D", bound=Document)

@dataclasses.dataclass
class TextDocumentWithLabeledSpansAndBinaryRelations(TextBasedDocument):
labeled_spans: AnnotationLayer[LabeledSpan] = annotation_field(target="text")
binary_relations: AnnotationLayer[BinaryRelation] = annotation_field(target="labeled_spans")


def test_serialize_labeled_span():
# labeled span

document = TextDocumentWithLabeledSpansAndBinaryRelations(
text="Harry lives in Berlin, Germany. He works at DFKI.", id="tmp"
)
Expand All @@ -48,8 +43,6 @@ def test_serialize_labeled_span():


def test_serialize_labeled_multi_span():

# labeled multi span
@dataclasses.dataclass
class TextDocumentWithLabeledMultiSpansAndBinaryRelations(TextBasedDocument):
labeled_multi_spans: AnnotationLayer[LabeledMultiSpan] = annotation_field(target="text")
Expand Down Expand Up @@ -173,8 +166,7 @@ def serialized_annotations(
prediction_label_prefix=None,
):
return serialize_annotation_layers(
labeled_span_layer=document.labeled_spans,
binary_relation_layer=document.binary_relations,
layers=[document.labeled_spans, document.binary_relations],
gold_label_prefix=gold_label_prefix,
prediction_label_prefix=prediction_label_prefix,
)
Expand All @@ -186,8 +178,7 @@ def serialized_annotations(
)
def test_serialize_annotations(document, gold_label_prefix, prediction_label_prefix):
serialized_annotations = serialize_annotation_layers(
labeled_span_layer=document.labeled_spans,
binary_relation_layer=document.binary_relations,
layers=[document.labeled_spans, document.binary_relations],
gold_label_prefix=gold_label_prefix,
prediction_label_prefix=prediction_label_prefix,
)
Expand Down Expand Up @@ -243,8 +234,7 @@ def test_write(tmp_path, document, serialized_annotations):
serializer = BratSerializer(
path=path,
document_processor=document_processor,
entity_layer="labeled_spans",
relation_layer="binary_relations",
layers=["labeled_spans", "binary_relations"],
)

metadata = serializer(documents=[document])
Expand All @@ -257,22 +247,9 @@ def test_write(tmp_path, document, serialized_annotations):
file.close()


@pytest.fixture
def dummy_document():
document = TokenDocumentWithLabeledSpansAndBinaryRelations(
id="tmp_1",
tokens=(),
)
return document


def test_write_with_exceptions_and_warnings(tmp_path, caplog, document, dummy_document):
def test_write_with_exceptions_and_warnings(tmp_path, caplog, document):
path = str(tmp_path)
serializer = BratSerializer(
path=path,
entity_layer="labeled_spans",
relation_layer="binary_relations",
)
serializer = BratSerializer(path=path, layers=["labeled_spans", "binary_relations"])

# list of empty documents
with pytest.raises(Exception) as e:
Expand All @@ -281,11 +258,10 @@ def test_write_with_exceptions_and_warnings(tmp_path, caplog, document, dummy_do

# List of documents with type unexpected Document type
with pytest.raises(TypeError) as type_error:
serializer(documents=[dummy_document])
serializer(documents=[Document()])
assert str(type_error.value) == (
"Document tmp_1 has unexpected type: <class "
"'pie_modules.documents.TokenDocumentWithLabeledSpansAndBinaryRelations'>. BratSerializer "
"can only serialize TextBasedDocuments."
"Document doc_0 has unexpected type: <class 'pytorch_ie.core.document.Document'>. "
"BratSerializer can only serialize TextBasedDocuments."
)

# Warning when metadata file already exists
Expand All @@ -304,7 +280,7 @@ def test_write_with_exceptions_and_warnings(tmp_path, caplog, document, dummy_do
def test_write_with_split(tmp_path, document, split):
path = str(tmp_path)
serializer = BratSerializer(
path=path, entity_layer="labeled_spans", relation_layer="binary_relations", split=split
path=path, layers=["labeled_spans", "binary_relations"], split=split
)

metadata = serializer(documents=[document])
Expand Down

0 comments on commit 0afff63

Please sign in to comment.