Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

brat serializer: allow using annotation ids from document metadata #172

Merged
merged 1 commit into from
Jul 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
98 changes: 81 additions & 17 deletions src/serializer/brat.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,20 +113,29 @@ def serialize_annotations(
indices: Dict[str, int],
annotation2id: Dict[Annotation, str],
label_prefix: Optional[str] = None,
annotation_ids: Optional[List[str]] = None,
) -> Tuple[List[str], Dict[Annotation, str]]:
serialized_annotations = []
new_annotation2id: Dict[Annotation, str] = {}
for annotation in annotations:
for idx, annotation in enumerate(annotations):
annotation_type, serialized_annotation = serialize_annotation(
annotation=annotation,
annotation2id=annotation2id,
label_prefix=label_prefix,
)
idx = indices[annotation_type]
annotation_id = f"{annotation_type}{idx}"
if annotation_ids is not None:
if indices.get(annotation_type, 0) > 0:
raise ValueError(
"Cannot specify annotation IDs for the same type (e.g. T or R) if there are "
"other annotations of the same type without an ID."
)
annotation_id = annotation_ids[idx]
else:
index = indices[annotation_type]
annotation_id = f"{annotation_type}{index}"
indices[annotation_type] += 1
serialized_annotations.append(f"{annotation_id}\t{serialized_annotation}")
new_annotation2id[annotation] = annotation_id
indices[annotation_type] += 1

return serialized_annotations, new_annotation2id

Expand All @@ -135,6 +144,8 @@ def serialize_annotation_layers(
layers: List[Tuple[AnnotationLayer, str]],
gold_label_prefix: Optional[str] = None,
prediction_label_prefix: Optional[str] = None,
gold_annotation_ids: Optional[List[Optional[List[str]]]] = None,
prediction_annotation_ids: Optional[List[Optional[List[str]]]] = None,
) -> List[str]:
"""Serialize annotations from given annotation layers into a list of strings.

Expand All @@ -145,15 +156,20 @@ def serialize_annotation_layers(
Defaults to None.
prediction_label_prefix (Optional[str], optional): Prefix to be added to prediction labels.
Defaults to None.
gold_annotation_ids (Optional[List[Optional[List[str]]]], optional): Gold annotation IDs per layer.
If provided, it should contain one entry per layer; each entry is either None or a list with one
ID per gold annotation of that layer. Defaults to None.
prediction_annotation_ids (Optional[List[Optional[List[str]]]], optional): Prediction annotation
IDs per layer. If provided, it should contain one entry per layer; each entry is either None or a
list with one ID per predicted annotation of that layer. Defaults to None.

Returns:
List[str]: List of serialized annotations.
"""

all_serialized_annotations = []
gold_annotation2id: Dict[Annotation, str] = {}
prediction_annotation2id: Dict[Annotation, str] = {}
indices: Dict[str, int] = defaultdict(int)
for layer, what in layers:
for idx, (layer, what) in enumerate(layers):
if what not in ["gold", "prediction", "both"]:
raise ValueError(
f'Invalid value for what to serialize: "{what}". Expected "gold", "prediction", or "both".'
Expand All @@ -171,23 +187,54 @@ def serialize_annotation_layers(
)
serialized_annotations = []
if what in ["gold", "both"]:
if gold_annotation_ids is not None:
if len(gold_annotation_ids) <= idx:
raise ValueError(
"gold_annotation_ids should have the same length as the number of layers."
)
current_gold_annotation_ids = gold_annotation_ids[idx]
if current_gold_annotation_ids is not None and len(
current_gold_annotation_ids
) != len(layer):
raise ValueError(
"gold_annotation_ids should have the same length as the number of gold annotations."
)
else:
current_gold_annotation_ids = None

serialized_gold_annotations, new_gold_ann2id = serialize_annotations(
annotations=layer,
indices=indices,
# gold annotations can only reference other gold annotations
annotation2id=gold_annotation2id,
label_prefix=gold_label_prefix,
annotation_ids=current_gold_annotation_ids,
)
serialized_annotations.extend(serialized_gold_annotations)
gold_annotation2id.update(new_gold_ann2id)
if what in ["prediction", "both"]:
if prediction_annotation_ids is not None:
if len(prediction_annotation_ids) <= idx:
raise ValueError(
"prediction_annotation_ids should have the same length as the number of layers."
)
current_prediction_annotation_ids = prediction_annotation_ids[idx]
if current_prediction_annotation_ids is not None and len(
current_prediction_annotation_ids
) != len(layer.predictions):
raise ValueError(
"prediction_annotation_ids should have the same length as the number of prediction annotations."
)
else:
current_prediction_annotation_ids = None
serialized_predicted_annotations, new_pred_ann2id = serialize_annotations(
annotations=layer.predictions,
indices=indices,
# Predicted annotations can reference both gold and predicted annotations.
# Note that predictions take precedence over gold annotations.
annotation2id={**gold_annotation2id, **prediction_annotation2id},
label_prefix=prediction_label_prefix,
annotation_ids=current_prediction_annotation_ids,
)
prediction_annotation2id.update(new_pred_ann2id)
serialized_annotations.extend(serialized_predicted_annotations)
Expand All @@ -200,10 +247,6 @@ class BratSerializer(DocumentSerializer):
specify the annotation layers to serialize. For now, it supports layers containing LabeledSpan,
LabeledMultiSpan, and BinaryRelation annotations.

If a gold_label_prefix is provided, the gold annotations are serialized with the given prefix.
Otherwise, only the predicted annotations are serialized. A document_processor can be provided
to process documents before serialization.

Attributes:
layers: A mapping from annotation layer names that should be serialized to what should be
serialized, i.e. "gold", "prediction", or "both".
Expand All @@ -212,21 +255,20 @@ class BratSerializer(DocumentSerializer):
with the given string. Otherwise, only predicted annotations are serialized.
prediction_label_prefix: If provided, labels of predicted annotations are prefixed with the
given string.
default_kwargs: Additional keyword arguments to be used as defaults during serialization.
metadata_gold_id_keys: A dictionary mapping layer names to metadata keys that contain the
gold annotation IDs.
metadata_prediction_id_keys: A dictionary mapping layer names to metadata keys that contain
the prediction annotation IDs.
"""

def __init__(
self,
layers: Dict[str, str],
document_processor=None,
prediction_label_prefix=None,
gold_label_prefix=None,
**kwargs,
):
self.document_processor = document_processor
self.layers = layers
self.prediction_label_prefix = prediction_label_prefix
self.gold_label_prefix = gold_label_prefix
self.default_kwargs = kwargs

def __call__(self, documents: Sequence[Document], **kwargs) -> Dict[str, str]:
Expand All @@ -235,8 +277,6 @@ def __call__(self, documents: Sequence[Document], **kwargs) -> Dict[str, str]:
return self.write_with_defaults(
documents=documents,
layers=self.layers,
prediction_label_prefix=self.prediction_label_prefix,
gold_label_prefix=self.gold_label_prefix,
**kwargs,
)

Expand All @@ -254,6 +294,8 @@ def write(
split: Optional[str] = None,
gold_label_prefix: Optional[str] = None,
prediction_label_prefix: Optional[str] = None,
metadata_gold_id_keys: Optional[Dict[str, str]] = None,
metadata_prediction_id_keys: Optional[Dict[str, str]] = None,
) -> Dict[str, str]:

realpath = os.path.realpath(path)
Expand All @@ -280,10 +322,32 @@ def write(
file_name = f"{doc_id}.ann"
metadata_text[f"{file_name}"] = doc.text
ann_path = os.path.join(realpath, file_name)
layer_names = list(layers)
if metadata_gold_id_keys is not None:
gold_annotation_ids = [
doc.metadata[metadata_gold_id_keys[layer_name]]
if layer_name in metadata_gold_id_keys
else None
for layer_name in layer_names
]
else:
gold_annotation_ids = None

if metadata_prediction_id_keys is not None:
prediction_annotation_ids = [
doc.metadata[metadata_prediction_id_keys[layer_name]]
if layer_name in metadata_prediction_id_keys
else None
for layer_name in layer_names
]
else:
prediction_annotation_ids = None
serialized_annotations = serialize_annotation_layers(
layers=[(doc[layer_name], what) for layer_name, what in layers.items()],
layers=[(doc[layer_name], layers[layer_name]) for layer_name in layer_names],
gold_label_prefix=gold_label_prefix,
prediction_label_prefix=prediction_label_prefix,
gold_annotation_ids=gold_annotation_ids,
prediction_annotation_ids=prediction_annotation_ids,
)
with open(ann_path, "w+") as f:
f.writelines(serialized_annotations)
Expand Down
86 changes: 83 additions & 3 deletions tests/unit/serializer/test_brat.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,14 +100,22 @@ class TextDocumentWithLabeledMultiSpansAndBinaryRelations(TextBasedDocument):
@pytest.fixture
def document():
document = TextDocumentWithLabeledSpansAndBinaryRelations(
text="Harry lives in Berlin, Germany. He works at DFKI.", id="tmp"
text="Harry lives in Berlin, Germany. He works at DFKI.",
id="tmp",
metadata={
"span_ids": [],
"relation_ids": [],
"prediction_span_ids": [],
"prediction_relation_ids": [],
},
)
document.labeled_spans.predictions.extend(
[
LabeledSpan(start=0, end=5, label="PERSON"),
LabeledSpan(start=44, end=48, label="ORGANIZATION"),
]
)
document.metadata["prediction_span_ids"].extend(["T200", "T201"])

assert str(document.labeled_spans.predictions[0]) == "Harry"
assert str(document.labeled_spans.predictions[1]) == "DFKI"
Expand All @@ -119,6 +127,8 @@ def document():
LabeledSpan(start=44, end=48, label="ORGANIZATION"),
]
)
document.metadata["span_ids"].extend(["T100", "T101", "T102"])

assert str(document.labeled_spans[0]) == "Harry"
assert str(document.labeled_spans[1]) == "Berlin, Germany"
assert str(document.labeled_spans[2]) == "DFKI"
Expand All @@ -132,6 +142,7 @@ def document():
),
]
)
document.metadata["prediction_relation_ids"].extend(["R200"])

document.binary_relations.extend(
[
Expand All @@ -147,6 +158,7 @@ def document():
),
]
)
document.metadata["relation_ids"].extend(["R100", "R101"])

return document

Expand Down Expand Up @@ -192,6 +204,51 @@ def test_serialize_annotations(document, what):
raise ValueError(f"Unexpected value for what: {what}")


@pytest.mark.parametrize(
    "what",
    ["gold", "prediction", "both"],
)
def test_serialize_annotations_with_annotation_ids(document, what):
    """Check that explicitly passed annotation IDs show up in the serialized output.

    The gold and prediction IDs are taken from the metadata of the ``document``
    fixture and passed per layer (spans first, relations second).
    """
    metadata = document.metadata
    gold_ids = [metadata["span_ids"], metadata["relation_ids"]]
    predicted_ids = [
        metadata["prediction_span_ids"],
        metadata["prediction_relation_ids"],
    ]
    # A prediction prefix is only passed when gold and predictions are serialized together.
    prediction_prefix = "PRED" if what == "both" else None

    result = serialize_annotation_layers(
        layers=[(document.labeled_spans, what), (document.binary_relations, what)],
        gold_label_prefix="GOLD",
        prediction_label_prefix=prediction_prefix,
        gold_annotation_ids=gold_ids,
        prediction_annotation_ids=predicted_ids,
    )

    # Expected output for each supported value of `what`.
    expected_by_what = {
        "both": [
            "T100\tGOLD-PERSON 0 5\tHarry\n",
            "T101\tGOLD-LOCATION 15 30\tBerlin, Germany\n",
            "T102\tGOLD-ORGANIZATION 44 48\tDFKI\n",
            "T200\tPRED-PERSON 0 5\tHarry\n",
            "T201\tPRED-ORGANIZATION 44 48\tDFKI\n",
            "R100\tGOLD-lives_in Arg1:T100 Arg2:T101\n",
            "R101\tGOLD-works_at Arg1:T100 Arg2:T102\n",
            "R200\tPRED-works_at Arg1:T200 Arg2:T201\n",
        ],
        "gold": [
            "T100\tGOLD-PERSON 0 5\tHarry\n",
            "T101\tGOLD-LOCATION 15 30\tBerlin, Germany\n",
            "T102\tGOLD-ORGANIZATION 44 48\tDFKI\n",
            "R100\tGOLD-lives_in Arg1:T100 Arg2:T101\n",
            "R101\tGOLD-works_at Arg1:T100 Arg2:T102\n",
        ],
        "prediction": [
            "T200\tPERSON 0 5\tHarry\n",
            "T201\tORGANIZATION 44 48\tDFKI\n",
            "R200\tworks_at Arg1:T200 Arg2:T201\n",
        ],
    }
    if what not in expected_by_what:
        raise ValueError(f"Unexpected value for what: {what}")

    assert result == expected_by_what[what]


def test_serialize_annotations_unknown_what(document):
with pytest.raises(ValueError) as e:
serialize_annotation_layers(
Expand All @@ -215,7 +272,7 @@ def test_serialize_annotations_missing_prefix(document):
)


def document_processor(document) -> TextBasedDocument:
def append_empty_span_to_labeled_spans(document) -> TextBasedDocument:
doc = document.copy()
doc["labeled_spans"].append(LabeledSpan(start=0, end=0, label="empty"))
return doc
Expand All @@ -225,7 +282,7 @@ def test_write(tmp_path, document):
path = str(tmp_path)
serializer = BratSerializer(
path=path,
document_processor=document_processor,
document_processor=append_empty_span_to_labeled_spans,
layers={"labeled_spans": "prediction", "binary_relations": "prediction"},
)

Expand All @@ -243,6 +300,29 @@ def test_write(tmp_path, document):
]


def test_write_with_annotation_ids(tmp_path, document):
    """Check that gold span IDs configured via metadata keys end up in the .ann file.

    Only the ``labeled_spans`` layer is mapped to a metadata key (``span_ids``);
    the expected output below shows the predicted relation with the ID ``R0``.
    """
    serializer = BratSerializer(
        path=str(tmp_path),
        layers={"labeled_spans": "gold", "binary_relations": "prediction"},
        metadata_gold_id_keys={"labeled_spans": "span_ids"},
    )

    result = serializer(documents=[document])
    # The serializer reports the output directory it wrote to.
    ann_file = os.path.join(result["path"], f"{document.id}.ann")

    with open(ann_file, "r") as ann_handle:
        written_lines = list(ann_handle)

    expected_lines = [
        "T100\tPERSON 0 5\tHarry\n",
        "T101\tLOCATION 15 30\tBerlin, Germany\n",
        "T102\tORGANIZATION 44 48\tDFKI\n",
        "R0\tworks_at Arg1:T100 Arg2:T102\n",
    ]
    assert written_lines == expected_lines


def test_write_with_exceptions_and_warnings(tmp_path, caplog, document):
path = str(tmp_path)
serializer = BratSerializer(
Expand Down
Loading