From 97d266fc180db8a49a015b5bfccdab1a699e50cb Mon Sep 17 00:00:00 2001 From: Arne Binder Date: Sat, 27 Jul 2024 20:58:08 +0200 Subject: [PATCH] add metadata_gold_id_keys and metadata_prediction_id_keys to BratSerializer --- src/serializer/brat.py | 98 ++++++++++++++++++++++++------ tests/unit/serializer/test_brat.py | 86 +++++++++++++++++++++++++- 2 files changed, 164 insertions(+), 20 deletions(-) diff --git a/src/serializer/brat.py b/src/serializer/brat.py index 5fde0f7..03ae797 100644 --- a/src/serializer/brat.py +++ b/src/serializer/brat.py @@ -113,20 +113,29 @@ def serialize_annotations( indices: Dict[str, int], annotation2id: Dict[Annotation, str], label_prefix: Optional[str] = None, + annotation_ids: Optional[List[str]] = None, ) -> Tuple[List[str], Dict[Annotation, str]]: serialized_annotations = [] new_annotation2id: Dict[Annotation, str] = {} - for annotation in annotations: + for idx, annotation in enumerate(annotations): annotation_type, serialized_annotation = serialize_annotation( annotation=annotation, annotation2id=annotation2id, label_prefix=label_prefix, ) - idx = indices[annotation_type] - annotation_id = f"{annotation_type}{idx}" + if annotation_ids is not None: + if indices.get(annotation_type, 0) > 0: + raise ValueError( + "Cannot specify annotation IDs for the same type (e.g. T or R) if there are " + "other annotations of the same type without an ID." + ) + annotation_id = annotation_ids[idx] + else: + index = indices[annotation_type] + annotation_id = f"{annotation_type}{index}" + indices[annotation_type] += 1 serialized_annotations.append(f"{annotation_id}\t{serialized_annotation}") new_annotation2id[annotation] = annotation_id - indices[annotation_type] += 1 return serialized_annotations, new_annotation2id @@ -135,6 +144,8 @@ def serialize_annotation_layers( layers: List[Tuple[AnnotationLayer, str]], gold_label_prefix: Optional[str] = None, prediction_label_prefix: Optional[str] = None, + gold_annotation_ids: Optional[List[Optional[List[str]]]] = None, + prediction_annotation_ids: Optional[List[Optional[List[str]]]] = None, ) -> List[str]: """Serialize annotations from given annotation layers into a list of strings. @@ -145,15 +156,20 @@ def serialize_annotation_layers( Defaults to None. prediction_label_prefix (Optional[str], optional): Prefix to be added to prediction labels. Defaults to None. + gold_annotation_ids (Optional[List[Optional[str]]], optional): List of gold annotation IDs. + If provided, the length should match the number of layers. Defaults to None. + prediction_annotation_ids (Optional[List[Optional[str]]], optional): List of prediction + annotation IDs. If provided, the length should match the number of layers. Defaults to None. Returns: List[str]: List of serialized annotations. """ + all_serialized_annotations = [] gold_annotation2id: Dict[Annotation, str] = {} prediction_annotation2id: Dict[Annotation, str] = {} indices: Dict[str, int] = defaultdict(int) - for layer, what in layers: + for idx, (layer, what) in enumerate(layers): if what not in ["gold", "prediction", "both"]: raise ValueError( f'Invalid value for what to serialize: "{what}". Expected "gold", "prediction", or "both".' @@ -171,16 +187,46 @@ def serialize_annotation_layers( ) serialized_annotations = [] if what in ["gold", "both"]: + if gold_annotation_ids is not None: + if len(gold_annotation_ids) <= idx: + raise ValueError( + "gold_annotation_ids should have the same length as the number of layers." + ) + current_gold_annotation_ids = gold_annotation_ids[idx] + if current_gold_annotation_ids is not None and len( + current_gold_annotation_ids + ) != len(layer): + raise ValueError( + "gold_annotation_ids should have the same length as the number of gold annotations." + ) + else: + current_gold_annotation_ids = None + serialized_gold_annotations, new_gold_ann2id = serialize_annotations( annotations=layer, indices=indices, # gold annotations can only reference other gold annotations annotation2id=gold_annotation2id, label_prefix=gold_label_prefix, + annotation_ids=current_gold_annotation_ids, ) serialized_annotations.extend(serialized_gold_annotations) gold_annotation2id.update(new_gold_ann2id) if what in ["prediction", "both"]: + if prediction_annotation_ids is not None: + if len(prediction_annotation_ids) <= idx: + raise ValueError( + "prediction_annotation_ids should have the same length as the number of layers." + ) + current_prediction_annotation_ids = prediction_annotation_ids[idx] + if current_prediction_annotation_ids is not None and len( + current_prediction_annotation_ids + ) != len(layer.predictions): + raise ValueError( + "prediction_annotation_ids should have the same length as the number of prediction annotations." + ) + else: + current_prediction_annotation_ids = None serialized_predicted_annotations, new_pred_ann2id = serialize_annotations( annotations=layer.predictions, indices=indices, @@ -188,6 +234,7 @@ def serialize_annotation_layers( # Note that predictions take precedence over gold annotations. annotation2id={**gold_annotation2id, **prediction_annotation2id}, label_prefix=prediction_label_prefix, + annotation_ids=current_prediction_annotation_ids, ) prediction_annotation2id.update(new_pred_ann2id) serialized_annotations.extend(serialized_predicted_annotations) @@ -200,10 +247,6 @@ class BratSerializer(DocumentSerializer): specify the annotation layers to serialize. For now, it supports layers containing LabeledSpan, LabeledMultiSpan, and BinaryRelation annotations. - If a gold_label_prefix is provided, the gold annotations are serialized with the given prefix. - Otherwise, only the predicted annotations are serialized. A document_processor can be provided - to process documents before serialization. - Attributes: layers: A mapping from annotation layer names that should be serialized to what should be serialized, i.e. "gold", "prediction", or "both". @@ -212,21 +255,20 @@ class BratSerializer(DocumentSerializer): with the given string. Otherwise, only predicted annotations are serialized. prediction_label_prefix: If provided, labels of predicted annotations are prefixed with the given string. - default_kwargs: Additional keyword arguments to be used as defaults during serialization. + metadata_gold_id_keys: A dictionary mapping layer names to metadata keys that contain the + gold annotation IDs. + metadata_prediction_id_keys: A dictionary mapping layer names to metadata keys that contain + the prediction annotation IDs. """ def __init__( self, layers: Dict[str, str], document_processor=None, - prediction_label_prefix=None, - gold_label_prefix=None, **kwargs, ): self.document_processor = document_processor self.layers = layers - self.prediction_label_prefix = prediction_label_prefix - self.gold_label_prefix = gold_label_prefix self.default_kwargs = kwargs def __call__(self, documents: Sequence[Document], **kwargs) -> Dict[str, str]: @@ -235,8 +277,6 @@ def __call__(self, documents: Sequence[Document], **kwargs) -> Dict[str, str]: return self.write_with_defaults( documents=documents, layers=self.layers, - prediction_label_prefix=self.prediction_label_prefix, - gold_label_prefix=self.gold_label_prefix, **kwargs, ) @@ -254,6 +294,8 @@ def write( split: Optional[str] = None, gold_label_prefix: Optional[str] = None, prediction_label_prefix: Optional[str] = None, + metadata_gold_id_keys: Optional[Dict[str, str]] = None, + metadata_prediction_id_keys: Optional[Dict[str, str]] = None, ) -> Dict[str, str]: realpath = os.path.realpath(path) @@ -280,10 +322,32 @@ def write( file_name = f"{doc_id}.ann" metadata_text[f"{file_name}"] = doc.text ann_path = os.path.join(realpath, file_name) + layer_names = list(layers) + if metadata_gold_id_keys is not None: + gold_annotation_ids = [ + doc.metadata[metadata_gold_id_keys[layer_name]] + if layer_name in metadata_gold_id_keys + else None + for layer_name in layer_names + ] + else: + gold_annotation_ids = None + + if metadata_prediction_id_keys is not None: + prediction_annotation_ids = [ + doc.metadata[metadata_prediction_id_keys[layer_name]] + if layer_name in metadata_prediction_id_keys + else None + for layer_name in layer_names + ] + else: + prediction_annotation_ids = None serialized_annotations = serialize_annotation_layers( - layers=[(doc[layer_name], what) for layer_name, what in layers.items()], + layers=[(doc[layer_name], layers[layer_name]) for layer_name in layer_names], gold_label_prefix=gold_label_prefix, prediction_label_prefix=prediction_label_prefix, + gold_annotation_ids=gold_annotation_ids, + prediction_annotation_ids=prediction_annotation_ids, ) with open(ann_path, "w+") as f: f.writelines(serialized_annotations) diff --git a/tests/unit/serializer/test_brat.py b/tests/unit/serializer/test_brat.py index ae6e66b..5918a35 100644 --- a/tests/unit/serializer/test_brat.py +++ b/tests/unit/serializer/test_brat.py @@ -100,7 +100,14 @@ class TextDocumentWithLabeledMultiSpansAndBinaryRelations(TextBasedDocument): @pytest.fixture def document(): document = TextDocumentWithLabeledSpansAndBinaryRelations( - text="Harry lives in Berlin, Germany. He works at DFKI.", id="tmp" + text="Harry lives in Berlin, Germany. He works at DFKI.", + id="tmp", + metadata={ + "span_ids": [], + "relation_ids": [], + "prediction_span_ids": [], + "prediction_relation_ids": [], + }, ) document.labeled_spans.predictions.extend( [ @@ -108,6 +115,7 @@ def document(): LabeledSpan(start=44, end=48, label="ORGANIZATION"), ] ) + document.metadata["prediction_span_ids"].extend(["T200", "T201"]) assert str(document.labeled_spans.predictions[0]) == "Harry" assert str(document.labeled_spans.predictions[1]) == "DFKI" @@ -119,6 +127,8 @@ def document(): LabeledSpan(start=44, end=48, label="ORGANIZATION"), ] ) + document.metadata["span_ids"].extend(["T100", "T101", "T102"]) + assert str(document.labeled_spans[0]) == "Harry" assert str(document.labeled_spans[1]) == "Berlin, Germany" assert str(document.labeled_spans[2]) == "DFKI" @@ -132,6 +142,7 @@ def document(): ), ] ) + document.metadata["prediction_relation_ids"].extend(["R200"]) document.binary_relations.extend( [ @@ -147,6 +158,7 @@ def document(): ), ] ) + document.metadata["relation_ids"].extend(["R100", "R101"]) return document @@ -192,6 +204,51 @@ def test_serialize_annotations(document, what): raise ValueError(f"Unexpected value for what: {what}") +@pytest.mark.parametrize( + "what", + ["gold", "prediction", "both"], +) +def test_serialize_annotations_with_annotation_ids(document, what): + serialized_annotations = serialize_annotation_layers( + layers=[(document.labeled_spans, what), (document.binary_relations, what)], + gold_label_prefix="GOLD", + prediction_label_prefix="PRED" if what == "both" else None, + gold_annotation_ids=[document.metadata["span_ids"], document.metadata["relation_ids"]], + prediction_annotation_ids=[ + document.metadata["prediction_span_ids"], + document.metadata["prediction_relation_ids"], + ], + ) + + if what == "both": + assert serialized_annotations == [ + "T100\tGOLD-PERSON 0 5\tHarry\n", + "T101\tGOLD-LOCATION 15 30\tBerlin, Germany\n", + "T102\tGOLD-ORGANIZATION 44 48\tDFKI\n", + "T200\tPRED-PERSON 0 5\tHarry\n", + "T201\tPRED-ORGANIZATION 44 48\tDFKI\n", + "R100\tGOLD-lives_in Arg1:T100 Arg2:T101\n", + "R101\tGOLD-works_at Arg1:T100 Arg2:T102\n", + "R200\tPRED-works_at Arg1:T200 Arg2:T201\n", + ] + elif what == "gold": + assert serialized_annotations == [ + "T100\tGOLD-PERSON 0 5\tHarry\n", + "T101\tGOLD-LOCATION 15 30\tBerlin, Germany\n", + "T102\tGOLD-ORGANIZATION 44 48\tDFKI\n", + "R100\tGOLD-lives_in Arg1:T100 Arg2:T101\n", + "R101\tGOLD-works_at Arg1:T100 Arg2:T102\n", + ] + elif what == "prediction": + assert serialized_annotations == [ + "T200\tPERSON 0 5\tHarry\n", + "T201\tORGANIZATION 44 48\tDFKI\n", + "R200\tworks_at Arg1:T200 Arg2:T201\n", + ] + else: + raise ValueError(f"Unexpected value for what: {what}") + + def test_serialize_annotations_unknown_what(document): with pytest.raises(ValueError) as e: serialize_annotation_layers( @@ -215,7 +272,7 @@ def test_serialize_annotations_missing_prefix(document): ) -def document_processor(document) -> TextBasedDocument: +def append_empty_span_to_labeled_spans(document) -> TextBasedDocument: doc = document.copy() doc["labeled_spans"].append(LabeledSpan(start=0, end=0, label="empty")) return doc @@ -225,7 +282,7 @@ def test_write(tmp_path, document): path = str(tmp_path) serializer = BratSerializer( path=path, - document_processor=document_processor, + document_processor=append_empty_span_to_labeled_spans, layers={"labeled_spans": "prediction", "binary_relations": "prediction"}, ) @@ -243,6 +300,29 @@ def test_write(tmp_path, document): ] +def test_write_with_annotation_ids(tmp_path, document): + path = str(tmp_path) + serializer = BratSerializer( + path=path, + layers={"labeled_spans": "gold", "binary_relations": "prediction"}, + metadata_gold_id_keys={"labeled_spans": "span_ids"}, + ) + + metadata = serializer(documents=[document]) + path = metadata["path"] + ann_file = os.path.join(path, f"{document.id}.ann") + + with open(ann_file, "r") as file: + lines = file.readlines() + + assert lines == [ + "T100\tPERSON 0 5\tHarry\n", + "T101\tLOCATION 15 30\tBerlin, Germany\n", + "T102\tORGANIZATION 44 48\tDFKI\n", + "R0\tworks_at Arg1:T100 Arg2:T102\n", + ] + + def test_write_with_exceptions_and_warnings(tmp_path, caplog, document): path = str(tmp_path) serializer = BratSerializer(