Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

brat serializer: make layers parameter a dict #170

Merged
merged 2 commits into from
Jul 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 36 additions & 18 deletions src/serializer/brat.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,14 +132,15 @@ def serialize_annotations(


def serialize_annotation_layers(
layers: List[AnnotationLayer],
layers: List[Tuple[AnnotationLayer, str]],
gold_label_prefix: Optional[str] = None,
prediction_label_prefix: Optional[str] = None,
) -> List[str]:
"""Serialize annotations from given annotation layers into a list of strings.

Args:
layers (List[AnnotationLayer]): Annotation layers to be serialized.
layers (List[Tuple[AnnotationLayer, str]]): Annotation layers to be serialized and what
should be serialized, i.e. "gold", "prediction", or "both".
gold_label_prefix (Optional[str], optional): Prefix to be added to gold labels.
Defaults to None.
prediction_label_prefix (Optional[str], optional): Prefix to be added to prediction labels.
Expand All @@ -152,9 +153,24 @@ def serialize_annotation_layers(
gold_annotation2id: Dict[Annotation, str] = {}
prediction_annotation2id: Dict[Annotation, str] = {}
indices: Dict[str, int] = defaultdict(int)
for layer in layers:
for layer, what in layers:
if what not in ["gold", "prediction", "both"]:
raise ValueError(
f'Invalid value for what to serialize: "{what}". Expected "gold", "prediction", or "both".'
)
if (
what == "both"
and gold_label_prefix is None
and prediction_label_prefix is None
and len(layer) > 0
and len(layer.predictions) > 0
):
raise ValueError(
"Cannot serialize both gold and prediction annotations without a label prefix for "
"at least one of them. Consider setting gold_label_prefix or prediction_label_prefix."
)
serialized_annotations = []
if gold_label_prefix is not None:
if what in ["gold", "both"]:
serialized_gold_annotations, new_gold_ann2id = serialize_annotations(
annotations=layer,
indices=indices,
Expand All @@ -164,16 +180,17 @@ def serialize_annotation_layers(
)
serialized_annotations.extend(serialized_gold_annotations)
gold_annotation2id.update(new_gold_ann2id)
serialized_predicted_annotations, new_pred_ann2id = serialize_annotations(
annotations=layer.predictions,
indices=indices,
# Predicted annotations can reference both gold and predicted annotations.
# Note that predictions take precedence over gold annotations.
annotation2id={**gold_annotation2id, **prediction_annotation2id},
label_prefix=prediction_label_prefix,
)
prediction_annotation2id.update(new_pred_ann2id)
serialized_annotations.extend(serialized_predicted_annotations)
if what in ["prediction", "both"]:
serialized_predicted_annotations, new_pred_ann2id = serialize_annotations(
annotations=layer.predictions,
indices=indices,
# Predicted annotations can reference both gold and predicted annotations.
# Note that predictions take precedence over gold annotations.
annotation2id={**gold_annotation2id, **prediction_annotation2id},
label_prefix=prediction_label_prefix,
)
prediction_annotation2id.update(new_pred_ann2id)
serialized_annotations.extend(serialized_predicted_annotations)
all_serialized_annotations.extend(serialized_annotations)
return all_serialized_annotations

Expand All @@ -188,7 +205,8 @@ class BratSerializer(DocumentSerializer):
to process documents before serialization.

Attributes:
layers: The names of the annotation layers to serialize.
layers: A mapping from annotation layer names that should be serialized to what should be
serialized, i.e. "gold", "prediction", or "both".
document_processor: A function or callable object to process documents before serialization.
gold_label_prefix: If provided, gold annotations are serialized and its labels are prefixed
with the given string. Otherwise, only predicted annotations are serialized.
Expand All @@ -199,7 +217,7 @@ class BratSerializer(DocumentSerializer):

def __init__(
self,
layers: List[str],
layers: Dict[str, str],
document_processor=None,
prediction_label_prefix=None,
gold_label_prefix=None,
Expand Down Expand Up @@ -230,7 +248,7 @@ def write_with_defaults(self, **kwargs) -> Dict[str, str]:
def write(
cls,
documents: Sequence[Document],
layers: List[str],
layers: Dict[str, str],
path: str,
metadata_file_name: str = METADATA_FILE_NAME,
split: Optional[str] = None,
Expand Down Expand Up @@ -263,7 +281,7 @@ def write(
metadata_text[f"{file_name}"] = doc.text
ann_path = os.path.join(realpath, file_name)
serialized_annotations = serialize_annotation_layers(
layers=[doc[layer] for layer in layers],
layers=[(doc[layer_name], what) for layer_name, what in layers.items()],
gold_label_prefix=gold_label_prefix,
prediction_label_prefix=prediction_label_prefix,
)
Expand Down
184 changes: 142 additions & 42 deletions tests/unit/serializer/test_brat.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,32 +151,19 @@ def document():
return document


@pytest.fixture
def serialized_annotations(
document,
gold_label_prefix=None,
prediction_label_prefix=None,
):
return serialize_annotation_layers(
layers=[document.labeled_spans, document.binary_relations],
gold_label_prefix=gold_label_prefix,
prediction_label_prefix=prediction_label_prefix,
)


@pytest.mark.parametrize(
"gold_label_prefix, prediction_label_prefix",
[(None, None), ("GOLD", None), (None, "PRED"), ("GOLD", "PRED")],
"what",
["gold", "prediction", "both"],
)
def test_serialize_annotations(document, gold_label_prefix, prediction_label_prefix):
def test_serialize_annotations(document, what):

serialized_annotations = serialize_annotation_layers(
layers=[document.labeled_spans, document.binary_relations],
gold_label_prefix=gold_label_prefix,
prediction_label_prefix=prediction_label_prefix,
layers=[(document.labeled_spans, what), (document.binary_relations, what)],
gold_label_prefix="GOLD",
prediction_label_prefix="PRED" if what == "both" else None,
)

if gold_label_prefix == "GOLD" and prediction_label_prefix == "PRED":
assert len(serialized_annotations) == 8
if what == "both":
assert serialized_annotations == [
"T0\tGOLD-PERSON 0 5\tHarry\n",
"T1\tGOLD-LOCATION 15 30\tBerlin, Germany\n",
Expand All @@ -187,32 +174,45 @@ def test_serialize_annotations(document, gold_label_prefix, prediction_label_pre
"R1\tGOLD-works_at Arg1:T0 Arg2:T2\n",
"R2\tPRED-works_at Arg1:T3 Arg2:T4\n",
]
elif gold_label_prefix == "GOLD" and prediction_label_prefix is None:
assert len(serialized_annotations) == 8
elif what == "gold":
assert serialized_annotations == [
"T0\tGOLD-PERSON 0 5\tHarry\n",
"T1\tGOLD-LOCATION 15 30\tBerlin, Germany\n",
"T2\tGOLD-ORGANIZATION 44 48\tDFKI\n",
"T3\tPERSON 0 5\tHarry\n",
"T4\tORGANIZATION 44 48\tDFKI\n",
"R0\tGOLD-lives_in Arg1:T0 Arg2:T1\n",
"R1\tGOLD-works_at Arg1:T0 Arg2:T2\n",
"R2\tworks_at Arg1:T3 Arg2:T4\n",
]
elif gold_label_prefix is None and prediction_label_prefix == "PRED":
assert len(serialized_annotations) == 3
assert serialized_annotations == [
"T0\tPRED-PERSON 0 5\tHarry\n",
"T1\tPRED-ORGANIZATION 44 48\tDFKI\n",
"R0\tPRED-works_at Arg1:T0 Arg2:T1\n",
]
else:
assert len(serialized_annotations) == 3
elif what == "prediction":
assert serialized_annotations == [
"T0\tPERSON 0 5\tHarry\n",
"T1\tORGANIZATION 44 48\tDFKI\n",
"R0\tworks_at Arg1:T0 Arg2:T1\n",
]
else:
raise ValueError(f"Unexpected value for what: {what}")


def test_serialize_annotations_unknown_what(document):
with pytest.raises(ValueError) as e:
serialize_annotation_layers(
layers=[(document.labeled_spans, "dummy"), (document.binary_relations, "dummy")],
)
assert (
str(e.value)
== 'Invalid value for what to serialize: "dummy". Expected "gold", "prediction", or "both".'
)


def test_serialize_annotations_missing_prefix(document):

with pytest.raises(ValueError) as e:
serialize_annotation_layers(
layers=[(document.labeled_spans, "both")],
)
assert str(e.value) == (
"Cannot serialize both gold and prediction annotations without a label prefix "
"for at least one of them. Consider setting gold_label_prefix or prediction_label_prefix."
)


def document_processor(document) -> TextBasedDocument:
Expand All @@ -221,27 +221,33 @@ def document_processor(document) -> TextBasedDocument:
return doc


def test_write(tmp_path, document, serialized_annotations):
def test_write(tmp_path, document):
path = str(tmp_path)
serializer = BratSerializer(
path=path,
document_processor=document_processor,
layers=["labeled_spans", "binary_relations"],
layers={"labeled_spans": "prediction", "binary_relations": "prediction"},
)

metadata = serializer(documents=[document])
path = metadata["path"]
ann_file = os.path.join(path, f"{document.id}.ann")

with open(ann_file, "r") as file:
for i, line in enumerate(file.readlines()):
assert line == serialized_annotations[i]
file.close()
lines = file.readlines()

assert lines == [
"T0\tPERSON 0 5\tHarry\n",
"T1\tORGANIZATION 44 48\tDFKI\n",
"R0\tworks_at Arg1:T0 Arg2:T1\n",
]


def test_write_with_exceptions_and_warnings(tmp_path, caplog, document):
path = str(tmp_path)
serializer = BratSerializer(path=path, layers=["labeled_spans", "binary_relations"])
serializer = BratSerializer(
path=path, layers={"labeled_spans": "prediction", "binary_relations": "prediction"}
)

# list of empty documents
with pytest.raises(Exception) as e:
Expand Down Expand Up @@ -272,7 +278,9 @@ def test_write_with_exceptions_and_warnings(tmp_path, caplog, document):
def test_write_with_split(tmp_path, document, split):
path = str(tmp_path)
serializer = BratSerializer(
path=path, layers=["labeled_spans", "binary_relations"], split=split
path=path,
layers={"labeled_spans": "prediction", "binary_relations": "prediction"},
split=split,
)

metadata = serializer(documents=[document])
Expand All @@ -281,3 +289,95 @@ def test_write_with_split(tmp_path, document, split):
assert real_path == os.path.join(path)
elif split is not None:
assert real_path == os.path.join(path, split)


@pytest.fixture
def document_only_gold_spans_both_relations():
document = TextDocumentWithLabeledSpansAndBinaryRelations(
text="Harry lives in Berlin, Germany. He works at DFKI.",
id="tmp",
metadata={
"span_ids": [],
"relation_ids": [],
"prediction_span_ids": [],
"prediction_relation_ids": [],
},
)

document.labeled_spans.extend(
[
LabeledSpan(start=0, end=5, label="PERSON"),
LabeledSpan(start=15, end=30, label="LOCATION"),
LabeledSpan(start=44, end=48, label="ORGANIZATION"),
]
)
document.metadata["span_ids"].extend(["T100", "T101", "T102"])

assert str(document.labeled_spans[0]) == "Harry"
assert str(document.labeled_spans[1]) == "Berlin, Germany"
assert str(document.labeled_spans[2]) == "DFKI"

document.binary_relations.predictions.extend(
[
BinaryRelation(
head=document.labeled_spans[0],
tail=document.labeled_spans[2],
label="works_at",
),
]
)
document.metadata["prediction_relation_ids"].extend(["R200"])

document.binary_relations.extend(
[
BinaryRelation(
head=document.labeled_spans[0],
tail=document.labeled_spans[1],
label="lives_in",
),
BinaryRelation(
head=document.labeled_spans[0],
tail=document.labeled_spans[2],
label="works_at",
),
]
)
document.metadata["relation_ids"].extend(["R100", "R101"])

return document


@pytest.mark.parametrize(
"what",
[("gold", "prediction"), ("both", "prediction"), ("gold", "both"), ("both", "both")],
)
def test_serialize_annotations_only_gold_spans_both_relations(
document_only_gold_spans_both_relations, what
):

serialized_annotations = serialize_annotation_layers(
layers=[
(document_only_gold_spans_both_relations.labeled_spans, what[0]),
(document_only_gold_spans_both_relations.binary_relations, what[1]),
],
gold_label_prefix="GOLD",
prediction_label_prefix="PRED",
)
if what in [("gold", "prediction"), ("both", "prediction")]:
assert serialized_annotations == [
"T0\tGOLD-PERSON 0 5\tHarry\n",
"T1\tGOLD-LOCATION 15 30\tBerlin, Germany\n",
"T2\tGOLD-ORGANIZATION 44 48\tDFKI\n",
"R0\tPRED-works_at Arg1:T0 Arg2:T2\n",
]
elif what in [("gold", "both"), ("both", "both")]:
assert serialized_annotations == [
"T0\tGOLD-PERSON 0 5\tHarry\n",
"T1\tGOLD-LOCATION 15 30\tBerlin, Germany\n",
"T2\tGOLD-ORGANIZATION 44 48\tDFKI\n",
"R0\tGOLD-lives_in Arg1:T0 Arg2:T1\n",
"R1\tGOLD-works_at Arg1:T0 Arg2:T2\n",
"R2\tPRED-works_at Arg1:T0 Arg2:T2\n",
]
else:
raise ValueError(f"Unexpected value for what: {what}")
Loading