diff --git a/requirements.txt b/requirements.txt
index 6994afa..34ab880 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,7 +1,7 @@
 # --------- pytorch-ie --------- #
 pytorch-ie>=0.28.0,<0.30.0
 pie-datasets>=0.8.1,<0.9.0
-pie-modules>=0.9.0,<0.10.0
+pie-modules>=0.10.6,<0.11.0
 
 # --------- hydra --------- #
 hydra-core>=1.3.0
diff --git a/src/serializer/__init__.py b/src/serializer/__init__.py
index 159fea1..2dc3e4b 100644
--- a/src/serializer/__init__.py
+++ b/src/serializer/__init__.py
@@ -1 +1,2 @@
+from .brat import BratSerializer
 from .json import JsonSerializer
diff --git a/src/serializer/brat.py b/src/serializer/brat.py
new file mode 100644
index 0000000..01d6a59
--- /dev/null
+++ b/src/serializer/brat.py
@@ -0,0 +1,342 @@
+import json
+import os
+from collections import defaultdict
+from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple, TypeVar, Union
+
+from pie_datasets.core.dataset_dict import METADATA_FILE_NAME
+from pie_modules.annotations import BinaryRelation, LabeledMultiSpan, LabeledSpan
+from pie_modules.documents import TextBasedDocument
+from pytorch_ie.core import Annotation, AnnotationLayer, Document
+from pytorch_ie.utils.hydra import serialize_document_type
+
+from src.serializer.interface import DocumentSerializer
+from src.utils import get_pylogger
+
+log = get_pylogger(__name__)
+
+D = TypeVar("D", bound=Document)
+
+
+def serialize_labeled_span(
+    annotation: LabeledSpan,
+    label_prefix: Optional[str] = None,
+) -> Tuple[str, str]:
+    """Serialize a labeled span into a string representation.
+
+    Args:
+        annotation (LabeledSpan): The labeled span object to serialize.
+        label_prefix (Optional[str], optional): A prefix to be added to the label. Defaults to None.
+
+    Returns:
+        Tuple[str, str]: The annotation type ("T") and the serialized labeled span.
+    """
+    label = annotation.label if label_prefix is None else f"{label_prefix}-{annotation.label}"
+    start_idx = annotation.start
+    end_idx = annotation.end
+    entity_text = annotation.target[start_idx:end_idx]
+    serialized_labeled_span = f"{label} {start_idx} {end_idx}\t{entity_text}\n"
+    return "T", serialized_labeled_span
+
+
+def serialize_labeled_multi_span(
+    annotation: LabeledMultiSpan,
+    label_prefix: Optional[str] = None,
+) -> Tuple[str, str]:
+    """Serialize a labeled multi span into a string representation.
+
+    Args:
+        annotation (LabeledMultiSpan): The labeled multi span object to serialize.
+        label_prefix (Optional[str], optional): A prefix to be added to the label. Defaults to None.
+
+    Returns:
+        Tuple[str, str]: The annotation type ("T") and the serialized labeled multi span.
+    """
+    label = annotation.label if label_prefix is None else f"{label_prefix}-{annotation.label}"
+
+    locations = []
+    texts = []
+    for start_idx, end_idx in annotation.slices:
+        texts.append(annotation.target[start_idx:end_idx])
+        locations.append(f"{start_idx} {end_idx}")
+    location = ";".join(locations)
+    text = " ".join(texts)
+    serialized_labeled_span = f"{label} {location}\t{text}\n"
+    return "T", serialized_labeled_span
+
+
+def serialize_binary_relation(
+    annotation: BinaryRelation,
+    annotation2id: Dict[Annotation, str],
+    label_prefix: Optional[str] = None,
+) -> Tuple[str, str]:
+    """Serialize a binary relation into a string representation.
+
+    Args:
+        annotation (BinaryRelation): The binary relation object to serialize. Its head and tail
+            can be either LabeledMultiSpans or LabeledSpans.
+        annotation2id (Dict[Annotation, str]): A dictionary mapping span annotations to their IDs.
+        label_prefix (Optional[str], optional): A prefix to be added to the label.
+            Defaults to None.
+
+    Returns:
+        Tuple[str, str]: The annotation type ("R") and the serialized binary relation.
+
+    Raises:
+        KeyError: If the head or the tail of the relation is not contained in annotation2id.
+    """
+
+    arg1 = annotation2id[annotation.head]
+    arg2 = annotation2id[annotation.tail]
+    label = annotation.label if label_prefix is None else f"{label_prefix}-{annotation.label}"
+    serialized_binary_relation = f"{label} Arg1:{arg1} Arg2:{arg2}\n"
+    return "R", serialized_binary_relation
+
+
+def serialize_annotation(
+    annotation: Annotation,
+    annotation2id: Dict[Annotation, str],
+    label_prefix: Optional[str] = None,
+) -> Tuple[str, str]:
+    """Serialize a single annotation by dispatching on its type.
+
+    Args:
+        annotation (Annotation): The annotation to serialize. Supported types are
+            LabeledMultiSpan, LabeledSpan, and BinaryRelation.
+        annotation2id (Dict[Annotation, str]): A dictionary mapping already serialized
+            annotations to their IDs. Only used to resolve the arguments of relations.
+        label_prefix (Optional[str], optional): A prefix to be added to the label.
+            Defaults to None.
+
+    Returns:
+        Tuple[str, str]: The annotation type ("T" or "R") and the serialized annotation.
+
+    Raises:
+        Warning: If the annotation has an unsupported type.
+    """
+    if isinstance(annotation, LabeledMultiSpan):
+        return serialize_labeled_multi_span(annotation=annotation, label_prefix=label_prefix)
+    elif isinstance(annotation, LabeledSpan):
+        return serialize_labeled_span(annotation=annotation, label_prefix=label_prefix)
+    elif isinstance(annotation, BinaryRelation):
+        return serialize_binary_relation(
+            annotation=annotation, label_prefix=label_prefix, annotation2id=annotation2id
+        )
+    else:
+        raise Warning(f"annotation has unknown type: {type(annotation)}")
+
+
+def serialize_annotations(
+    annotations: Iterable[Annotation],
+    indices: Dict[str, int],
+    annotation2id: Dict[Annotation, str],
+    label_prefix: Optional[str] = None,
+) -> Tuple[List[str], Dict[Annotation, str]]:
+    """Serialize a sequence of annotations and assign a Brat ID to each of them.
+
+    Args:
+        annotations (Iterable[Annotation]): The annotations to serialize.
+        indices (Dict[str, int]): The next free index per annotation type ("T", "R").
+            This mapping is updated in place; missing types start at 0.
+        annotation2id (Dict[Annotation, str]): A dictionary mapping already serialized
+            annotations to their IDs, used to resolve the arguments of relations.
+        label_prefix (Optional[str], optional): A prefix to be added to the labels.
+            Defaults to None.
+
+    Returns:
+        Tuple[List[str], Dict[Annotation, str]]: The serialized annotations (one line each)
+            and a mapping from the annotations serialized by this call to their new IDs.
+    """
+    serialized_annotations: List[str] = []
+    new_annotation2id: Dict[Annotation, str] = {}
+    for annotation in annotations:
+        annotation_type, serialized_annotation = serialize_annotation(
+            annotation=annotation,
+            annotation2id=annotation2id,
+            label_prefix=label_prefix,
+        )
+        idx = indices.get(annotation_type, 0)
+        annotation_id = f"{annotation_type}{idx}"
+        serialized_annotations.append(f"{annotation_id}\t{serialized_annotation}")
+        new_annotation2id[annotation] = annotation_id
+        indices[annotation_type] = idx + 1
+
+    return serialized_annotations, new_annotation2id
+
+
+def serialize_annotation_layers(
+    layers: List[AnnotationLayer],
+    gold_label_prefix: Optional[str] = None,
+    prediction_label_prefix: Optional[str] = None,
+) -> List[str]:
+    """Serialize annotations from given annotation layers into a list of strings.
+
+    Args:
+        layers (List[AnnotationLayer]): Annotation layers to be serialized.
+        gold_label_prefix (Optional[str], optional): Prefix to be added to gold labels.
+            If None, gold annotations are not serialized at all. Defaults to None.
+        prediction_label_prefix (Optional[str], optional): Prefix to be added to prediction labels.
+            Defaults to None.
+
+    Returns:
+        List[str]: List of serialized annotations.
+    """
+    all_serialized_annotations: List[str] = []
+    gold_annotation2id: Dict[Annotation, str] = {}
+    prediction_annotation2id: Dict[Annotation, str] = {}
+    indices: Dict[str, int] = defaultdict(int)
+    for layer in layers:
+        serialized_annotations = []
+        if gold_label_prefix is not None:
+            serialized_gold_annotations, new_gold_ann2id = serialize_annotations(
+                annotations=layer,
+                indices=indices,
+                # gold annotations can only reference other gold annotations
+                annotation2id=gold_annotation2id,
+                label_prefix=gold_label_prefix,
+            )
+            serialized_annotations.extend(serialized_gold_annotations)
+            gold_annotation2id.update(new_gold_ann2id)
+        serialized_predicted_annotations, new_pred_ann2id = serialize_annotations(
+            annotations=layer.predictions,
+            indices=indices,
+            # Predicted annotations can reference both gold and predicted annotations.
+            # Note that predictions take precedence over gold annotations.
+            annotation2id={**gold_annotation2id, **prediction_annotation2id},
+            label_prefix=prediction_label_prefix,
+        )
+        prediction_annotation2id.update(new_pred_ann2id)
+        serialized_annotations.extend(serialized_predicted_annotations)
+        all_serialized_annotations.extend(serialized_annotations)
+    return all_serialized_annotations
+
+
+class BratSerializer(DocumentSerializer):
+    """BratSerializer serializes documents into the Brat format. It requires a "layers"
+    parameter to specify the annotation layers to serialize. For now, it supports layers
+    containing LabeledSpan, LabeledMultiSpan, and BinaryRelation annotations.
+
+    If a gold_label_prefix is provided, the gold annotations are serialized with the given prefix.
+    Otherwise, only the predicted annotations are serialized. A document_processor can be provided
+    to process documents before serialization.
+
+    Attributes:
+        layers: The names of the annotation layers to serialize.
+        document_processor: A function or callable object to process documents before serialization.
+        gold_label_prefix: If provided, gold annotations are serialized and their labels are
+            prefixed with the given string. Otherwise, only predicted annotations are serialized.
+        prediction_label_prefix: If provided, labels of predicted annotations are prefixed with the
+            given string.
+        default_kwargs: Additional keyword arguments to be used as defaults during serialization.
+    """
+
+    def __init__(
+        self,
+        layers: List[str],
+        document_processor=None,
+        prediction_label_prefix=None,
+        gold_label_prefix=None,
+        **kwargs,
+    ):
+        self.document_processor = document_processor
+        self.layers = layers
+        self.prediction_label_prefix = prediction_label_prefix
+        self.gold_label_prefix = gold_label_prefix
+        self.default_kwargs = kwargs
+
+    def __call__(self, documents: Sequence[Document], **kwargs) -> Dict[str, str]:
+        """Apply the optional document_processor to the documents and serialize them."""
+        if self.document_processor is not None:
+            documents = list(map(self.document_processor, documents))
+        return self.write_with_defaults(
+            documents=documents,
+            layers=self.layers,
+            prediction_label_prefix=self.prediction_label_prefix,
+            gold_label_prefix=self.gold_label_prefix,
+            **kwargs,
+        )
+
+    def write_with_defaults(self, **kwargs) -> Dict[str, str]:
+        """Call write() with the default kwargs, which are overridden by the given kwargs."""
+        all_kwargs = {**self.default_kwargs, **kwargs}
+        return self.write(**all_kwargs)
+
+    @classmethod
+    def write(
+        cls,
+        documents: Sequence[Document],
+        layers: List[str],
+        path: str,
+        metadata_file_name: str = METADATA_FILE_NAME,
+        split: Optional[str] = None,
+        gold_label_prefix: Optional[str] = None,
+        prediction_label_prefix: Optional[str] = None,
+    ) -> Dict[str, str]:
+        """Write one Brat ".ann" file per document and a metadata file to the given path.
+
+        The document texts are not written as separate files, but stored in the metadata
+        file under the key "text" (keyed by the name of the respective ".ann" file).
+
+        Args:
+            documents: The documents to serialize. Must be non-empty and contain only
+                TextBasedDocuments.
+            layers: The names of the annotation layers to serialize.
+            path: The target directory. It is created if it does not exist.
+            metadata_file_name: The name of the metadata file, located directly in path.
+            split: If provided, the ".ann" files are written to this sub-directory of path.
+            gold_label_prefix: If provided, gold annotations are serialized and their labels
+                are prefixed with the given string.
+            prediction_label_prefix: If provided, labels of predicted annotations are prefixed
+                with the given string.
+
+        Returns:
+            A dictionary with the directory the ".ann" files were written to ("path") and the
+            name of the metadata file ("metadata_file_name").
+
+        Raises:
+            ValueError: If documents is empty.
+            TypeError: If a document is not a TextBasedDocument.
+        """
+        # validate the input before touching the file system
+        if len(documents) == 0:
+            raise ValueError("cannot serialize empty list of documents")
+
+        realpath = os.path.realpath(path)
+        log.info(f'serialize documents to "{realpath}" ...')
+        os.makedirs(realpath, exist_ok=True)
+
+        document_type = type(documents[0])
+        metadata: Dict[str, Any] = {"document_type": serialize_document_type(document_type)}
+        full_metadata_file_name = os.path.join(realpath, metadata_file_name)
+
+        if split is not None:
+            realpath = os.path.join(realpath, split)
+            os.makedirs(realpath, exist_ok=True)
+        metadata_text: Dict[str, str] = {}
+        for i, doc in enumerate(documents):
+            doc_id = getattr(doc, "id", None) or f"doc_{i}"
+            if not isinstance(doc, TextBasedDocument):
+                raise TypeError(
+                    f"Document {doc_id} has unexpected type: {type(doc)}. "
+                    "BratSerializer can only serialize TextBasedDocuments."
+                )
+            file_name = f"{doc_id}.ann"
+            metadata_text[file_name] = doc.text
+            ann_path = os.path.join(realpath, file_name)
+            serialized_annotations = serialize_annotation_layers(
+                layers=[doc[layer] for layer in layers],
+                gold_label_prefix=gold_label_prefix,
+                prediction_label_prefix=prediction_label_prefix,
+            )
+            with open(ann_path, "w", encoding="utf-8") as f:
+                f.writelines(serialized_annotations)
+
+        metadata["text"] = metadata_text
+
+        if os.path.exists(full_metadata_file_name):
+            log.warning(
+                f"metadata file {full_metadata_file_name} already exists, "
+                "it will be overwritten!"
+            )
+        with open(full_metadata_file_name, "w", encoding="utf-8") as f:
+            json.dump(metadata, f, indent=2)
+        return {"path": realpath, "metadata_file_name": metadata_file_name}
diff --git a/tests/unit/serializer/test_brat.py b/tests/unit/serializer/test_brat.py
new file mode 100644
index 0000000..401c5b9
--- /dev/null
+++ b/tests/unit/serializer/test_brat.py
@@ -0,0 +1,281 @@
+import dataclasses
+import os
+from dataclasses import dataclass
+
+import pytest
+from pie_modules.annotations import BinaryRelation, LabeledMultiSpan, LabeledSpan
+from pie_modules.documents import TextBasedDocument
+from pytorch_ie import Annotation, AnnotationLayer, Document
+from pytorch_ie.core import annotation_field
+
+from src.serializer import BratSerializer
+from src.serializer.brat import (
+    serialize_annotation,
+    serialize_annotation_layers,
+    serialize_binary_relation,
+)
+
+
+@dataclasses.dataclass
+class TextDocumentWithLabeledSpansAndBinaryRelations(TextBasedDocument):
+    labeled_spans: AnnotationLayer[LabeledSpan] = annotation_field(target="text")
+    binary_relations: AnnotationLayer[BinaryRelation] = annotation_field(target="labeled_spans")
+
+
+def test_serialize_labeled_span():
+
+    document = TextDocumentWithLabeledSpansAndBinaryRelations(
+        text="Harry lives in Berlin, Germany. He works at DFKI.", id="tmp"
+    )
+    document.labeled_spans.extend(
+        [
+            LabeledSpan(start=15, end=30, label="LOCATION"),
+        ]
+    )
+    labeled_span = document.labeled_spans[0]
+    annotation_type, serialized_annotation = serialize_annotation(
+        annotation=labeled_span,
+        annotation2id={},
+    )
+    assert annotation_type == "T"
+    assert serialized_annotation == "LOCATION 15 30\tBerlin, Germany\n"
+
+
+def test_serialize_labeled_multi_span():
+    @dataclasses.dataclass
+    class TextDocumentWithLabeledMultiSpansAndBinaryRelations(TextBasedDocument):
+        labeled_multi_spans: AnnotationLayer[LabeledMultiSpan] = annotation_field(target="text")
+        binary_relations: AnnotationLayer[BinaryRelation] = annotation_field(
+            target="labeled_multi_spans"
+        )
+
+    document = TextDocumentWithLabeledMultiSpansAndBinaryRelations(
+        text="Harry lives in Berlin, Germany. He works at DFKI.", id="tmp"
+    )
+    document.labeled_multi_spans.extend(
+        [
+            LabeledMultiSpan(slices=((15, 21), (23, 30)), label="LOCATION"),
+        ]
+    )
+    labeled_multi_span = document.labeled_multi_spans[0]
+    annotation_type, serialized_annotation = serialize_annotation(
+        annotation=labeled_multi_span,
+        annotation2id={},
+    )
+    assert annotation_type == "T"
+    assert serialized_annotation == "LOCATION 15 21;23 30\tBerlin Germany\n"
+
+
+def test_serialize_binary_relation():
+    binary_relation = BinaryRelation(
+        head=LabeledSpan(start=0, end=5, label="PERSON"),
+        tail=LabeledSpan(start=15, end=30, label="LOCATION"),
+        label="lives_in",
+    )
+    span2id = {binary_relation.head: "T1", binary_relation.tail: "T2"}
+    annotation_type, serialized_binary_relation = serialize_binary_relation(
+        annotation=binary_relation,
+        annotation2id=span2id,
+    )
+    assert annotation_type == "R"
+    assert serialized_binary_relation == "lives_in Arg1:T1 Arg2:T2\n"
+
+
+def test_serialize_unknown_annotation():
+
+    with pytest.raises(Warning) as w:
+        serialize_annotation(annotation=Annotation(), annotation2id={})
+    assert (
+        str(w.value)
+        == f"annotation has unknown type: {Annotation}"
+    )
+
+
+@dataclass
+class TextDocumentWithLabeledMultiSpansAndBinaryRelations(TextBasedDocument):
+    labeled_spans: AnnotationLayer[LabeledMultiSpan] = annotation_field(target="text")
+    binary_relations: AnnotationLayer[BinaryRelation] = annotation_field(target="labeled_spans")
+
+
+@pytest.fixture
+def document():
+    document = TextDocumentWithLabeledSpansAndBinaryRelations(
+        text="Harry lives in Berlin, Germany. He works at DFKI.", id="tmp"
+    )
+    document.labeled_spans.predictions.extend(
+        [
+            LabeledSpan(start=0, end=5, label="PERSON"),
+            LabeledSpan(start=44, end=48, label="ORGANIZATION"),
+        ]
+    )
+
+    assert str(document.labeled_spans.predictions[0]) == "Harry"
+    assert str(document.labeled_spans.predictions[1]) == "DFKI"
+
+    document.labeled_spans.extend(
+        [
+            LabeledSpan(start=0, end=5, label="PERSON"),
+            LabeledSpan(start=15, end=30, label="LOCATION"),
+            LabeledSpan(start=44, end=48, label="ORGANIZATION"),
+        ]
+    )
+    assert str(document.labeled_spans[0]) == "Harry"
+    assert str(document.labeled_spans[1]) == "Berlin, Germany"
+    assert str(document.labeled_spans[2]) == "DFKI"
+
+    document.binary_relations.predictions.extend(
+        [
+            BinaryRelation(
+                head=document.labeled_spans[0],
+                tail=document.labeled_spans[2],
+                label="works_at",
+            ),
+        ]
+    )
+
+    document.binary_relations.extend(
+        [
+            BinaryRelation(
+                head=document.labeled_spans[0],
+                tail=document.labeled_spans[1],
+                label="lives_in",
+            ),
+            BinaryRelation(
+                head=document.labeled_spans[0],
+                tail=document.labeled_spans[2],
+                label="works_at",
+            ),
+        ]
+    )
+
+    return document
+
+
+@pytest.fixture
+def serialized_annotations(
+    document,
+    gold_label_prefix=None,
+    prediction_label_prefix=None,
+):
+    return serialize_annotation_layers(
+        layers=[document.labeled_spans, document.binary_relations],
+        gold_label_prefix=gold_label_prefix,
+        prediction_label_prefix=prediction_label_prefix,
+    )
+
+
+@pytest.mark.parametrize(
+    "gold_label_prefix, prediction_label_prefix",
+    [(None, None), ("GOLD", None), (None, "PRED"), ("GOLD", "PRED")],
+)
+def test_serialize_annotations(document, gold_label_prefix, prediction_label_prefix):
+    serialized_annotations = serialize_annotation_layers(
+        layers=[document.labeled_spans, document.binary_relations],
+        gold_label_prefix=gold_label_prefix,
+        prediction_label_prefix=prediction_label_prefix,
+    )
+
+    if gold_label_prefix == "GOLD" and prediction_label_prefix == "PRED":
+        assert len(serialized_annotations) == 8
+        assert serialized_annotations == [
+            "T0\tGOLD-PERSON 0 5\tHarry\n",
+            "T1\tGOLD-LOCATION 15 30\tBerlin, Germany\n",
+            "T2\tGOLD-ORGANIZATION 44 48\tDFKI\n",
+            "T3\tPRED-PERSON 0 5\tHarry\n",
+            "T4\tPRED-ORGANIZATION 44 48\tDFKI\n",
+            "R0\tGOLD-lives_in Arg1:T0 Arg2:T1\n",
+            "R1\tGOLD-works_at Arg1:T0 Arg2:T2\n",
+            "R2\tPRED-works_at Arg1:T3 Arg2:T4\n",
+        ]
+    elif gold_label_prefix == "GOLD" and prediction_label_prefix is None:
+        assert len(serialized_annotations) == 8
+        assert serialized_annotations == [
+            "T0\tGOLD-PERSON 0 5\tHarry\n",
+            "T1\tGOLD-LOCATION 15 30\tBerlin, Germany\n",
+            "T2\tGOLD-ORGANIZATION 44 48\tDFKI\n",
+            "T3\tPERSON 0 5\tHarry\n",
+            "T4\tORGANIZATION 44 48\tDFKI\n",
+            "R0\tGOLD-lives_in Arg1:T0 Arg2:T1\n",
+            "R1\tGOLD-works_at Arg1:T0 Arg2:T2\n",
+            "R2\tworks_at Arg1:T3 Arg2:T4\n",
+        ]
+    elif gold_label_prefix is None and prediction_label_prefix == "PRED":
+        assert len(serialized_annotations) == 3
+        assert serialized_annotations == [
+            "T0\tPRED-PERSON 0 5\tHarry\n",
+            "T1\tPRED-ORGANIZATION 44 48\tDFKI\n",
+            "R0\tPRED-works_at Arg1:T0 Arg2:T1\n",
+        ]
+    else:
+        assert len(serialized_annotations) == 3
+        assert serialized_annotations == [
+            "T0\tPERSON 0 5\tHarry\n",
+            "T1\tORGANIZATION 44 48\tDFKI\n",
+            "R0\tworks_at Arg1:T0 Arg2:T1\n",
+        ]
+
+
+def document_processor(document) -> TextBasedDocument:
+    doc = document.copy()
+    doc["labeled_spans"].append(LabeledSpan(start=0, end=0, label="empty"))
+    return doc
+
+
+def test_write(tmp_path, document, serialized_annotations):
+    path = str(tmp_path)
+    serializer = BratSerializer(
+        path=path,
+        document_processor=document_processor,
+        layers=["labeled_spans", "binary_relations"],
+    )
+
+    metadata = serializer(documents=[document])
+    path = metadata["path"]
+    ann_file = os.path.join(path, f"{document.id}.ann")
+
+    with open(ann_file, "r", encoding="utf-8") as file:
+        assert file.readlines() == serialized_annotations
+
+
+def test_write_with_exceptions_and_warnings(tmp_path, caplog, document):
+    path = str(tmp_path)
+    serializer = BratSerializer(path=path, layers=["labeled_spans", "binary_relations"])
+
+    # list of empty documents
+    with pytest.raises(ValueError) as e:
+        serializer(documents=[])
+    assert str(e.value) == "cannot serialize empty list of documents"
+
+    # List of documents with type unexpected Document type
+    with pytest.raises(TypeError) as type_error:
+        serializer(documents=[Document()])
+    assert str(type_error.value) == (
+        f"Document doc_0 has unexpected type: {Document}. "
+        "BratSerializer can only serialize TextBasedDocuments."
+    )
+
+    # Warning when metadata file already exists
+    metadata = serializer(documents=[document])
+    full_metadata_file_name = os.path.join(metadata["path"], metadata["metadata_file_name"])
+    serializer(documents=[document])
+
+    assert caplog.records[0].levelname == "WARNING"
+    assert (
+        f"metadata file {full_metadata_file_name} already exists, it will be overwritten!\n"
+        in caplog.text
+    )
+
+
+@pytest.mark.parametrize("split", [None, "test"])
+def test_write_with_split(tmp_path, document, split):
+    path = str(tmp_path)
+    serializer = BratSerializer(
+        path=path, layers=["labeled_spans", "binary_relations"], split=split
+    )
+
+    metadata = serializer(documents=[document])
+    real_path = metadata["path"]
+    if split is None:
+        assert real_path == os.path.join(path)
+    elif split is not None:
+        assert real_path == os.path.join(path, split)