Skip to content

Commit

Permalink
make layers parameter a dict from layer names to what to serialize, i…
Browse files Browse the repository at this point in the history
….e. "gold", "prediction", >"both"
  • Loading branch information
ArneBinder committed Jul 27, 2024
1 parent 43eeb8c commit 0917177
Show file tree
Hide file tree
Showing 2 changed files with 178 additions and 60 deletions.
54 changes: 36 additions & 18 deletions src/serializer/brat.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,14 +132,15 @@ def serialize_annotations(


def serialize_annotation_layers(
layers: List[AnnotationLayer],
layers: List[Tuple[AnnotationLayer, str]],
gold_label_prefix: Optional[str] = None,
prediction_label_prefix: Optional[str] = None,
) -> List[str]:
"""Serialize annotations from given annotation layers into a list of strings.
Args:
layers (List[AnnotationLayer]): Annotation layers to be serialized.
layers (List[Tuple[AnnotationLayer, str]]): Annotation layers to be serialized and what
should be serialized, i.e. "gold", "prediction", or "both".
gold_label_prefix (Optional[str], optional): Prefix to be added to gold labels.
Defaults to None.
prediction_label_prefix (Optional[str], optional): Prefix to be added to prediction labels.
Expand All @@ -152,9 +153,24 @@ def serialize_annotation_layers(
gold_annotation2id: Dict[Annotation, str] = {}
prediction_annotation2id: Dict[Annotation, str] = {}
indices: Dict[str, int] = defaultdict(int)
for layer in layers:
for layer, what in layers:
if what not in ["gold", "prediction", "both"]:
raise ValueError(
f'Invalid value for what to serialize: "{what}". Expected "gold", "prediction", or "both".'
)
if (
what == "both"
and gold_label_prefix is None
and prediction_label_prefix is None
and len(layer) > 0
and len(layer.predictions) > 0
):
raise ValueError(
"Cannot serialize both gold and prediction annotations without a label prefix for "
"at least one of them. Consider setting gold_label_prefix or prediction_label_prefix."
)
serialized_annotations = []
if gold_label_prefix is not None:
if what in ["gold", "both"]:
serialized_gold_annotations, new_gold_ann2id = serialize_annotations(
annotations=layer,
indices=indices,
Expand All @@ -164,16 +180,17 @@ def serialize_annotation_layers(
)
serialized_annotations.extend(serialized_gold_annotations)
gold_annotation2id.update(new_gold_ann2id)
serialized_predicted_annotations, new_pred_ann2id = serialize_annotations(
annotations=layer.predictions,
indices=indices,
# Predicted annotations can reference both gold and predicted annotations.
# Note that predictions take precedence over gold annotations.
annotation2id={**gold_annotation2id, **prediction_annotation2id},
label_prefix=prediction_label_prefix,
)
prediction_annotation2id.update(new_pred_ann2id)
serialized_annotations.extend(serialized_predicted_annotations)
if what in ["prediction", "both"]:
serialized_predicted_annotations, new_pred_ann2id = serialize_annotations(
annotations=layer.predictions,
indices=indices,
# Predicted annotations can reference both gold and predicted annotations.
# Note that predictions take precedence over gold annotations.
annotation2id={**gold_annotation2id, **prediction_annotation2id},
label_prefix=prediction_label_prefix,
)
prediction_annotation2id.update(new_pred_ann2id)
serialized_annotations.extend(serialized_predicted_annotations)
all_serialized_annotations.extend(serialized_annotations)
return all_serialized_annotations

Expand All @@ -188,7 +205,8 @@ class BratSerializer(DocumentSerializer):
to process documents before serialization.
Attributes:
layers: The names of the annotation layers to serialize.
layers: A mapping from annotation layer names that should be serialized to what should be
serialized, i.e. "gold", "prediction", or "both".
document_processor: A function or callable object to process documents before serialization.
gold_label_prefix: If provided, gold annotations are serialized and its labels are prefixed
with the given string. Otherwise, only predicted annotations are serialized.
Expand All @@ -199,7 +217,7 @@ class BratSerializer(DocumentSerializer):

def __init__(
self,
layers: List[str],
layers: Dict[str, str],
document_processor=None,
prediction_label_prefix=None,
gold_label_prefix=None,
Expand Down Expand Up @@ -230,7 +248,7 @@ def write_with_defaults(self, **kwargs) -> Dict[str, str]:
def write(
cls,
documents: Sequence[Document],
layers: List[str],
layers: Dict[str, str],
path: str,
metadata_file_name: str = METADATA_FILE_NAME,
split: Optional[str] = None,
Expand Down Expand Up @@ -263,7 +281,7 @@ def write(
metadata_text[f"{file_name}"] = doc.text
ann_path = os.path.join(realpath, file_name)
serialized_annotations = serialize_annotation_layers(
layers=[doc[layer] for layer in layers],
layers=[(doc[layer_name], what) for layer_name, what in layers.items()],
gold_label_prefix=gold_label_prefix,
prediction_label_prefix=prediction_label_prefix,
)
Expand Down
184 changes: 142 additions & 42 deletions tests/unit/serializer/test_brat.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,32 +151,19 @@ def document():
return document


@pytest.fixture
def serialized_annotations(
document,
gold_label_prefix=None,
prediction_label_prefix=None,
):
return serialize_annotation_layers(
layers=[document.labeled_spans, document.binary_relations],
gold_label_prefix=gold_label_prefix,
prediction_label_prefix=prediction_label_prefix,
)


@pytest.mark.parametrize(
"gold_label_prefix, prediction_label_prefix",
[(None, None), ("GOLD", None), (None, "PRED"), ("GOLD", "PRED")],
"what",
["gold", "prediction", "both"],
)
def test_serialize_annotations(document, gold_label_prefix, prediction_label_prefix):
def test_serialize_annotations(document, what):

serialized_annotations = serialize_annotation_layers(
layers=[document.labeled_spans, document.binary_relations],
gold_label_prefix=gold_label_prefix,
prediction_label_prefix=prediction_label_prefix,
layers=[(document.labeled_spans, what), (document.binary_relations, what)],
gold_label_prefix="GOLD",
prediction_label_prefix="PRED" if what == "both" else None,
)

if gold_label_prefix is not None and prediction_label_prefix is not None:
assert len(serialized_annotations) == 8
if what == "both":
assert serialized_annotations == [
"T0\tGOLD-PERSON 0 5\tHarry\n",
"T1\tGOLD-LOCATION 15 30\tBerlin, Germany\n",
Expand All @@ -187,32 +174,45 @@ def test_serialize_annotations(document, gold_label_prefix, prediction_label_pre
"R1\tGOLD-works_at Arg1:T0 Arg2:T2\n",
"R2\tPRED-works_at Arg1:T3 Arg2:T4\n",
]
elif gold_label_prefix is not None and prediction_label_prefix is None:
assert len(serialized_annotations) == 8
elif what == "gold":
assert serialized_annotations == [
"T0\tGOLD-PERSON 0 5\tHarry\n",
"T1\tGOLD-LOCATION 15 30\tBerlin, Germany\n",
"T2\tGOLD-ORGANIZATION 44 48\tDFKI\n",
"T3\tPERSON 0 5\tHarry\n",
"T4\tORGANIZATION 44 48\tDFKI\n",
"R0\tGOLD-lives_in Arg1:T0 Arg2:T1\n",
"R1\tGOLD-works_at Arg1:T0 Arg2:T2\n",
"R2\tworks_at Arg1:T3 Arg2:T4\n",
]
elif gold_label_prefix is None and prediction_label_prefix is not None:
assert len(serialized_annotations) == 3
assert serialized_annotations == [
"T0\tPRED-PERSON 0 5\tHarry\n",
"T1\tPRED-ORGANIZATION 44 48\tDFKI\n",
"R0\tPRED-works_at Arg1:T0 Arg2:T1\n",
]
else:
assert len(serialized_annotations) == 3
elif what == "prediction":
assert serialized_annotations == [
"T0\tPERSON 0 5\tHarry\n",
"T1\tORGANIZATION 44 48\tDFKI\n",
"R0\tworks_at Arg1:T0 Arg2:T1\n",
]
else:
raise ValueError(f"Unexpected value for what: {what}")


def test_serialize_annotations_unknown_what(document):
with pytest.raises(ValueError) as e:
serialize_annotation_layers(
layers=[(document.labeled_spans, "dummy"), (document.binary_relations, "dummy")],
)
assert (
str(e.value)
== 'Invalid value for what to serialize: "dummy". Expected "gold", "prediction", or "both".'
)


def test_serialize_annotations_missing_prefix(document):

with pytest.raises(ValueError) as e:
serialize_annotation_layers(
layers=[(document.labeled_spans, "both")],
)
assert str(e.value) == (
"Cannot serialize both gold and prediction annotations without a label prefix "
"for at least one of them. Consider setting gold_label_prefix or prediction_label_prefix."
)


def document_processor(document) -> TextBasedDocument:
Expand All @@ -221,27 +221,33 @@ def document_processor(document) -> TextBasedDocument:
return doc


def test_write(tmp_path, document, serialized_annotations):
def test_write(tmp_path, document):
path = str(tmp_path)
serializer = BratSerializer(
path=path,
document_processor=document_processor,
layers=["labeled_spans", "binary_relations"],
layers={"labeled_spans": "prediction", "binary_relations": "prediction"},
)

metadata = serializer(documents=[document])
path = metadata["path"]
ann_file = os.path.join(path, f"{document.id}.ann")

with open(ann_file, "r") as file:
for i, line in enumerate(file.readlines()):
assert line == serialized_annotations[i]
file.close()
lines = file.readlines()

assert lines == [
"T0\tPERSON 0 5\tHarry\n",
"T1\tORGANIZATION 44 48\tDFKI\n",
"R0\tworks_at Arg1:T0 Arg2:T1\n",
]


def test_write_with_exceptions_and_warnings(tmp_path, caplog, document):
path = str(tmp_path)
serializer = BratSerializer(path=path, layers=["labeled_spans", "binary_relations"])
serializer = BratSerializer(
path=path, layers={"labeled_spans": "prediction", "binary_relations": "prediction"}
)

# list of empty documents
with pytest.raises(Exception) as e:
Expand Down Expand Up @@ -272,7 +278,9 @@ def test_write_with_exceptions_and_warnings(tmp_path, caplog, document):
def test_write_with_split(tmp_path, document, split):
path = str(tmp_path)
serializer = BratSerializer(
path=path, layers=["labeled_spans", "binary_relations"], split=split
path=path,
layers={"labeled_spans": "prediction", "binary_relations": "prediction"},
split=split,
)

metadata = serializer(documents=[document])
Expand All @@ -281,3 +289,95 @@ def test_write_with_split(tmp_path, document, split):
assert real_path == os.path.join(path)
elif split is not None:
assert real_path == os.path.join(path, split)


@pytest.fixture
def document_only_gold_spans_both_relations():
document = TextDocumentWithLabeledSpansAndBinaryRelations(
text="Harry lives in Berlin, Germany. He works at DFKI.",
id="tmp",
metadata={
"span_ids": [],
"relation_ids": [],
"prediction_span_ids": [],
"prediction_relation_ids": [],
},
)

document.labeled_spans.extend(
[
LabeledSpan(start=0, end=5, label="PERSON"),
LabeledSpan(start=15, end=30, label="LOCATION"),
LabeledSpan(start=44, end=48, label="ORGANIZATION"),
]
)
document.metadata["span_ids"].extend(["T100", "T101", "T102"])

assert str(document.labeled_spans[0]) == "Harry"
assert str(document.labeled_spans[1]) == "Berlin, Germany"
assert str(document.labeled_spans[2]) == "DFKI"

document.binary_relations.predictions.extend(
[
BinaryRelation(
head=document.labeled_spans[0],
tail=document.labeled_spans[2],
label="works_at",
),
]
)
document.metadata["prediction_relation_ids"].extend(["R200"])

document.binary_relations.extend(
[
BinaryRelation(
head=document.labeled_spans[0],
tail=document.labeled_spans[1],
label="lives_in",
),
BinaryRelation(
head=document.labeled_spans[0],
tail=document.labeled_spans[2],
label="works_at",
),
]
)
document.metadata["relation_ids"].extend(["R100", "R101"])

return document


@pytest.mark.parametrize(
"what",
[("gold", "prediction"), ("both", "prediction"), ("gold", "both"), ("both", "both")],
)
def test_serialize_annotations_only_gold_spans_both_relations(
document_only_gold_spans_both_relations, what
):

serialized_annotations = serialize_annotation_layers(
layers=[
(document_only_gold_spans_both_relations.labeled_spans, what[0]),
(document_only_gold_spans_both_relations.binary_relations, what[1]),
],
gold_label_prefix="GOLD",
prediction_label_prefix="PRED",
)
if what in [("gold", "prediction"), ("both", "prediction")]:
assert serialized_annotations == [
"T0\tGOLD-PERSON 0 5\tHarry\n",
"T1\tGOLD-LOCATION 15 30\tBerlin, Germany\n",
"T2\tGOLD-ORGANIZATION 44 48\tDFKI\n",
"R0\tPRED-works_at Arg1:T0 Arg2:T2\n",
]
elif what in [("gold", "both"), ("both", "both")]:
assert serialized_annotations == [
"T0\tGOLD-PERSON 0 5\tHarry\n",
"T1\tGOLD-LOCATION 15 30\tBerlin, Germany\n",
"T2\tGOLD-ORGANIZATION 44 48\tDFKI\n",
"R0\tGOLD-lives_in Arg1:T0 Arg2:T1\n",
"R1\tGOLD-works_at Arg1:T0 Arg2:T2\n",
"R2\tPRED-works_at Arg1:T0 Arg2:T2\n",
]
else:
raise ValueError(f"Unexpected value for what: {what}")

0 comments on commit 0917177

Please sign in to comment.