diff --git a/src/serializer/brat.py b/src/serializer/brat.py index 01d6a59..5fde0f7 100644 --- a/src/serializer/brat.py +++ b/src/serializer/brat.py @@ -132,14 +132,15 @@ def serialize_annotations( def serialize_annotation_layers( - layers: List[AnnotationLayer], + layers: List[Tuple[AnnotationLayer, str]], gold_label_prefix: Optional[str] = None, prediction_label_prefix: Optional[str] = None, ) -> List[str]: """Serialize annotations from given annotation layers into a list of strings. Args: - layers (List[AnnotationLayer]): Annotation layers to be serialized. + layers (List[Tuple[AnnotationLayer, str]]): Annotation layers to be serialized and what + should be serialized, i.e. "gold", "prediction", or "both". gold_label_prefix (Optional[str], optional): Prefix to be added to gold labels. Defaults to None. prediction_label_prefix (Optional[str], optional): Prefix to be added to prediction labels. @@ -152,9 +153,24 @@ def serialize_annotation_layers( gold_annotation2id: Dict[Annotation, str] = {} prediction_annotation2id: Dict[Annotation, str] = {} indices: Dict[str, int] = defaultdict(int) - for layer in layers: + for layer, what in layers: + if what not in ["gold", "prediction", "both"]: + raise ValueError( + f'Invalid value for what to serialize: "{what}". Expected "gold", "prediction", or "both".' + ) + if ( + what == "both" + and gold_label_prefix is None + and prediction_label_prefix is None + and len(layer) > 0 + and len(layer.predictions) > 0 + ): + raise ValueError( + "Cannot serialize both gold and prediction annotations without a label prefix for " + "at least one of them. Consider setting gold_label_prefix or prediction_label_prefix." + ) serialized_annotations = [] - if gold_label_prefix is not None: + if what in ["gold", "both"]: serialized_gold_annotations, new_gold_ann2id = serialize_annotations( annotations=layer, indices=indices, @@ -164,16 +180,17 @@ def serialize_annotation_layers( ) serialized_annotations.extend(serialized_gold_annotations) gold_annotation2id.update(new_gold_ann2id) - serialized_predicted_annotations, new_pred_ann2id = serialize_annotations( - annotations=layer.predictions, - indices=indices, - # Predicted annotations can reference both gold and predicted annotations. - # Note that predictions take precedence over gold annotations. - annotation2id={**gold_annotation2id, **prediction_annotation2id}, - label_prefix=prediction_label_prefix, - ) - prediction_annotation2id.update(new_pred_ann2id) - serialized_annotations.extend(serialized_predicted_annotations) + if what in ["prediction", "both"]: + serialized_predicted_annotations, new_pred_ann2id = serialize_annotations( + annotations=layer.predictions, + indices=indices, + # Predicted annotations can reference both gold and predicted annotations. + # Note that predictions take precedence over gold annotations. + annotation2id={**gold_annotation2id, **prediction_annotation2id}, + label_prefix=prediction_label_prefix, + ) + prediction_annotation2id.update(new_pred_ann2id) + serialized_annotations.extend(serialized_predicted_annotations) all_serialized_annotations.extend(serialized_annotations) return all_serialized_annotations @@ -188,7 +205,8 @@ class BratSerializer(DocumentSerializer): to process documents before serialization. Attributes: - layers: The names of the annotation layers to serialize. + layers: A mapping from annotation layer names that should be serialized to what should be + serialized, i.e. "gold", "prediction", or "both". document_processor: A function or callable object to process documents before serialization. gold_label_prefix: If provided, gold annotations are serialized and its labels are prefixed with the given string. Otherwise, only predicted annotations are serialized. @@ -199,7 +217,7 @@ class BratSerializer(DocumentSerializer): def __init__( self, - layers: List[str], + layers: Dict[str, str], document_processor=None, prediction_label_prefix=None, gold_label_prefix=None, @@ -230,7 +248,7 @@ def write_with_defaults(self, **kwargs) -> Dict[str, str]: def write( cls, documents: Sequence[Document], - layers: List[str], + layers: Dict[str, str], path: str, metadata_file_name: str = METADATA_FILE_NAME, split: Optional[str] = None, @@ -263,7 +281,7 @@ def write( metadata_text[f"{file_name}"] = doc.text ann_path = os.path.join(realpath, file_name) serialized_annotations = serialize_annotation_layers( - layers=[doc[layer] for layer in layers], + layers=[(doc[layer_name], what) for layer_name, what in layers.items()], gold_label_prefix=gold_label_prefix, prediction_label_prefix=prediction_label_prefix, ) diff --git a/tests/unit/serializer/test_brat.py b/tests/unit/serializer/test_brat.py index 401c5b9..ae6e66b 100644 --- a/tests/unit/serializer/test_brat.py +++ b/tests/unit/serializer/test_brat.py @@ -151,32 +151,19 @@ def document(): return document -@pytest.fixture -def serialized_annotations( - document, - gold_label_prefix=None, - prediction_label_prefix=None, -): - return serialize_annotation_layers( - layers=[document.labeled_spans, document.binary_relations], - gold_label_prefix=gold_label_prefix, - prediction_label_prefix=prediction_label_prefix, - ) - - @pytest.mark.parametrize( - "gold_label_prefix, prediction_label_prefix", - [(None, None), ("GOLD", None), (None, "PRED"), ("GOLD", "PRED")], + "what", + ["gold", "prediction", "both"], ) -def test_serialize_annotations(document, gold_label_prefix, prediction_label_prefix): +def test_serialize_annotations(document, what): + serialized_annotations = serialize_annotation_layers( - layers=[document.labeled_spans, document.binary_relations], - gold_label_prefix=gold_label_prefix, - prediction_label_prefix=prediction_label_prefix, + layers=[(document.labeled_spans, what), (document.binary_relations, what)], + gold_label_prefix="GOLD", + prediction_label_prefix="PRED" if what == "both" else None, ) - if gold_label_prefix == "GOLD" and prediction_label_prefix == "PRED": - assert len(serialized_annotations) == 8 + if what == "both": assert serialized_annotations == [ "T0\tGOLD-PERSON 0 5\tHarry\n", "T1\tGOLD-LOCATION 15 30\tBerlin, Germany\n", @@ -187,32 +174,45 @@ def test_serialize_annotations(document, gold_label_prefix, prediction_label_pre "R1\tGOLD-works_at Arg1:T0 Arg2:T2\n", "R2\tPRED-works_at Arg1:T3 Arg2:T4\n", ] - elif gold_label_prefix == "GOLD" and prediction_label_prefix is None: - assert len(serialized_annotations) == 8 + elif what == "gold": assert serialized_annotations == [ "T0\tGOLD-PERSON 0 5\tHarry\n", "T1\tGOLD-LOCATION 15 30\tBerlin, Germany\n", "T2\tGOLD-ORGANIZATION 44 48\tDFKI\n", - "T3\tPERSON 0 5\tHarry\n", - "T4\tORGANIZATION 44 48\tDFKI\n", "R0\tGOLD-lives_in Arg1:T0 Arg2:T1\n", "R1\tGOLD-works_at Arg1:T0 Arg2:T2\n", - "R2\tworks_at Arg1:T3 Arg2:T4\n", ] - elif gold_label_prefix is None and prediction_label_prefix == "PRED": - assert len(serialized_annotations) == 3 - assert serialized_annotations == [ - "T0\tPRED-PERSON 0 5\tHarry\n", - "T1\tPRED-ORGANIZATION 44 48\tDFKI\n", - "R0\tPRED-works_at Arg1:T0 Arg2:T1\n", - ] - else: - assert len(serialized_annotations) == 3 + elif what == "prediction": assert serialized_annotations == [ "T0\tPERSON 0 5\tHarry\n", "T1\tORGANIZATION 44 48\tDFKI\n", "R0\tworks_at Arg1:T0 Arg2:T1\n", ] + else: + raise ValueError(f"Unexpected value for what: {what}") + + +def test_serialize_annotations_unknown_what(document): + with pytest.raises(ValueError) as e: + serialize_annotation_layers( + layers=[(document.labeled_spans, "dummy"), (document.binary_relations, "dummy")], + ) + assert ( + str(e.value) + == 'Invalid value for what to serialize: "dummy". Expected "gold", "prediction", or "both".' + ) + + +def test_serialize_annotations_missing_prefix(document): + + with pytest.raises(ValueError) as e: + serialize_annotation_layers( + layers=[(document.labeled_spans, "both")], + ) + assert str(e.value) == ( + "Cannot serialize both gold and prediction annotations without a label prefix " + "for at least one of them. Consider setting gold_label_prefix or prediction_label_prefix." + ) def document_processor(document) -> TextBasedDocument: @@ -221,12 +221,12 @@ def document_processor(document) -> TextBasedDocument: return doc -def test_write(tmp_path, document, serialized_annotations): +def test_write(tmp_path, document): path = str(tmp_path) serializer = BratSerializer( path=path, document_processor=document_processor, - layers=["labeled_spans", "binary_relations"], + layers={"labeled_spans": "prediction", "binary_relations": "prediction"}, ) metadata = serializer(documents=[document]) @@ -234,14 +234,20 @@ def test_write(tmp_path, document, serialized_annotations): ann_file = os.path.join(path, f"{document.id}.ann") with open(ann_file, "r") as file: - for i, line in enumerate(file.readlines()): - assert line == serialized_annotations[i] - file.close() + lines = file.readlines() + + assert lines == [ + "T0\tPERSON 0 5\tHarry\n", + "T1\tORGANIZATION 44 48\tDFKI\n", + "R0\tworks_at Arg1:T0 Arg2:T1\n", + ] def test_write_with_exceptions_and_warnings(tmp_path, caplog, document): path = str(tmp_path) - serializer = BratSerializer(path=path, layers=["labeled_spans", "binary_relations"]) + serializer = BratSerializer( + path=path, layers={"labeled_spans": "prediction", "binary_relations": "prediction"} + ) # list of empty documents with pytest.raises(Exception) as e: @@ -272,7 +278,9 @@ def test_write_with_exceptions_and_warnings(tmp_path, caplog, document): def test_write_with_split(tmp_path, document, split): path = str(tmp_path) serializer = BratSerializer( - path=path, layers=["labeled_spans", "binary_relations"], split=split + path=path, + layers={"labeled_spans": "prediction", "binary_relations": "prediction"}, + split=split, ) metadata = serializer(documents=[document]) @@ -281,3 +289,95 @@ def test_write_with_split(tmp_path, document, split): assert real_path == os.path.join(path) elif split is not None: assert real_path == os.path.join(path, split) + + +@pytest.fixture +def document_only_gold_spans_both_relations(): + document = TextDocumentWithLabeledSpansAndBinaryRelations( + text="Harry lives in Berlin, Germany. He works at DFKI.", + id="tmp", + metadata={ + "span_ids": [], + "relation_ids": [], + "prediction_span_ids": [], + "prediction_relation_ids": [], + }, + ) + + document.labeled_spans.extend( + [ + LabeledSpan(start=0, end=5, label="PERSON"), + LabeledSpan(start=15, end=30, label="LOCATION"), + LabeledSpan(start=44, end=48, label="ORGANIZATION"), + ] + ) + document.metadata["span_ids"].extend(["T100", "T101", "T102"]) + + assert str(document.labeled_spans[0]) == "Harry" + assert str(document.labeled_spans[1]) == "Berlin, Germany" + assert str(document.labeled_spans[2]) == "DFKI" + + document.binary_relations.predictions.extend( + [ + BinaryRelation( + head=document.labeled_spans[0], + tail=document.labeled_spans[2], + label="works_at", + ), + ] + ) + document.metadata["prediction_relation_ids"].extend(["R200"]) + + document.binary_relations.extend( + [ + BinaryRelation( + head=document.labeled_spans[0], + tail=document.labeled_spans[1], + label="lives_in", + ), + BinaryRelation( + head=document.labeled_spans[0], + tail=document.labeled_spans[2], + label="works_at", + ), + ] + ) + document.metadata["relation_ids"].extend(["R100", "R101"]) + + return document + + +@pytest.mark.parametrize( + "what", + [("gold", "prediction"), ("both", "prediction"), ("gold", "both"), ("both", "both")], +) +def test_serialize_annotations_only_gold_spans_both_relations( + document_only_gold_spans_both_relations, what +): + + serialized_annotations = serialize_annotation_layers( + layers=[ + (document_only_gold_spans_both_relations.labeled_spans, what[0]), + (document_only_gold_spans_both_relations.binary_relations, what[1]), + ], + gold_label_prefix="GOLD", + prediction_label_prefix="PRED", + ) + if what in [("gold", "prediction"), ("both", "prediction")]: + assert serialized_annotations == [ + "T0\tGOLD-PERSON 0 5\tHarry\n", + "T1\tGOLD-LOCATION 15 30\tBerlin, Germany\n", + "T2\tGOLD-ORGANIZATION 44 48\tDFKI\n", + "R0\tPRED-works_at Arg1:T0 Arg2:T2\n", + ] + elif what in [("gold", "both"), ("both", "both")]: + assert serialized_annotations == [ + "T0\tGOLD-PERSON 0 5\tHarry\n", + "T1\tGOLD-LOCATION 15 30\tBerlin, Germany\n", + "T2\tGOLD-ORGANIZATION 44 48\tDFKI\n", + "R0\tGOLD-lives_in Arg1:T0 Arg2:T1\n", + "R1\tGOLD-works_at Arg1:T0 Arg2:T2\n", + "R2\tPRED-works_at Arg1:T0 Arg2:T2\n", + ] + else: + raise ValueError(f"Unexpected value for what: {what}")