
Commit

add test for gold and prediction only annotations
Bhuvanesh-Verma committed Feb 6, 2024
1 parent bdbc3ad commit fc439e1
Showing 1 changed file with 220 additions and 61 deletions.
281 changes: 220 additions & 61 deletions tests/unit/serializer/test_brat.py
@@ -1,6 +1,11 @@
import os
from collections import defaultdict as dd
from dataclasses import dataclass
from typing import TypeVar

from pie_datasets.builders.brat import BratDocument
from pie_modules.document.processing import SpansViaRelationMerger
from pytorch_ie import AnnotationLayer
from pytorch_ie.core import Document

from src.utils import get_pylogger
@@ -9,102 +14,256 @@

D = TypeVar("D", bound=Document)
import pytest
from pytorch_ie.annotations import BinaryRelation, LabeledMultiSpan, LabeledSpan
from pytorch_ie.core import annotation_field
from pytorch_ie.documents import TextBasedDocument, TextDocumentWithLabeledSpansAndBinaryRelations

from src.serializer import BratSerializer


def get_location(location_string):
    parts = location_string.split(" ")
    assert (
        len(parts) == 2
    ), f"Wrong number of entries in location string. Expected 2, but found: {parts}"
    return {"start": int(parts[0]), "end": int(parts[1])}


def get_span_annotation(annotation_line):
    """example input:
    T1	Organization 0 4	Sony
    """

    _id, remaining, text = annotation_line.split("\t", maxsplit=2)
    _type, locations = remaining.split(" ", maxsplit=1)
    return {
        "id": _id,
        "text": text,
        "type": _type,
        "locations": [get_location(loc) for loc in locations.split(";")],
    }
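
# The "locations" entry is a list because BRAT allows discontinuous spans: the offsets
# part of a T-line may contain several ";"-separated "start end" fragments, each of
# which is parsed by get_location above.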


def get_relation_annotation(annotation_line):
    """example input:
    R1	Origin Arg1:T3 Arg2:T4
    """

    _id, remaining = annotation_line.strip().split("\t")
    _type, remaining = remaining.split(" ", maxsplit=1)
    args = [dict(zip(["type", "target"], a.split(":"))) for a in remaining.split(" ")]
    return {"id": _id, "type": _type, "arguments": args}


def read_annotation_file(filename):
    res = {
        "spans": [],
        "relations": [],
    }
    with open(filename, encoding="utf-8") as file:
        for i, line in enumerate(file):
            if len(line.strip()) == 0:
                continue
            ann_type = line[0]
            # strip away the new line character
            if line.endswith("\n"):
                line = line[:-1]
            if ann_type == "T":
                res["spans"].append(get_span_annotation(line))
            elif ann_type == "R":
                res["relations"].append(get_relation_annotation(line))
            else:
                raise ValueError(
                    f'unknown BRAT annotation id type: "{line}" (from file {filename} @line {i}). '
                    f"Annotation ids have to start with T (spans), E (events), R (relations), "
                    f"A (attributions), or N (normalizations). See "
                    f"https://brat.nlplab.org/standoff.html for the BRAT annotation file "
                    f"specification."
                )
    return res
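

# A small sanity check for the parsing helpers above, using the example lines from their
# docstrings (added for illustration; the expected dicts follow directly from the string
# splitting done in get_span_annotation and get_relation_annotation).
def test_parse_annotation_helpers():
    assert get_span_annotation("T1\tOrganization 0 4\tSony") == {
        "id": "T1",
        "text": "Sony",
        "type": "Organization",
        "locations": [{"start": 0, "end": 4}],
    }
    assert get_relation_annotation("R1\tOrigin Arg1:T3 Arg2:T4") == {
        "id": "R1",
        "type": "Origin",
        "arguments": [{"type": "Arg1", "target": "T3"}, {"type": "Arg2", "target": "T4"}],
    }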


@dataclass
class TextDocumentWithLabeledMultiSpansAndBinaryRelations(TextBasedDocument):
    labeled_spans: AnnotationLayer[LabeledMultiSpan] = annotation_field(target="text")
    binary_relations: AnnotationLayer[BinaryRelation] = annotation_field(target="labeled_spans")


@pytest.fixture
def document_processor():
    dp = SpansViaRelationMerger(
        relation_layer="binary_relations",
        link_relation_label="parts_of_same",
        result_document_type=BratDocument,
        result_field_mapping={"labeled_spans": "spans", "binary_relations": "relations"},
    )

    return dp
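
# As configured above, SpansViaRelationMerger should merge labeled_spans that are connected
# via "parts_of_same" binary_relations into multi-span annotations and write the resulting
# layers into the "spans"/"relations" fields of a BratDocument (inferred from the constructor
# arguments; see pie_modules.document.processing for the actual behaviour).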


@pytest.fixture
def document_with_gold_only():
    document = TextDocumentWithLabeledSpansAndBinaryRelations(
        text="Harry lives in Berlin. He works at DFKI.", id="tmp"
    )
    labeled_spans = [
        LabeledSpan(start=0, end=5, label="PERSON"),
        LabeledSpan(start=15, end=21, label="LOCATION"),
        LabeledSpan(start=35, end=39, label="ORGANIZATION"),
    ]
    for ent in labeled_spans:
        document.labeled_spans.append(ent)

    binary_relations = [
        BinaryRelation(head=labeled_spans[0], tail=labeled_spans[1], label="lives_in"),
        BinaryRelation(head=labeled_spans[0], tail=labeled_spans[2], label="works_at"),
    ]

    for rel in binary_relations:
        document.binary_relations.append(rel)

    return document
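
# Note: the fixture above attaches spans and relations to the gold layers
# (document.labeled_spans / document.binary_relations), whereas document_with_prediction_only
# further below adds them to the corresponding .predictions sub-layers.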


def test_save_gold_only(tmp_path, document_with_gold_only, document_processor):
    path = str(tmp_path)
    serializer = BratSerializer(
        path=path, document_processor=document_processor, gold_label_prefix="GOLD"
    )

    metadata = serializer(documents=[document_with_gold_only])

    path = metadata["path"]
    res = read_annotation_file(os.path.join(path, f"{document_with_gold_only.id}.ann"))
    """
    res in the following format:
    {'spans':
        [
            {'id': 'T0', 'text': 'Harry', 'type': 'GOLD-PERSON', 'locations': [{'start': 0, 'end': 5}]},
            {'id': 'T1', 'text': 'DFKI', 'type': 'GOLD-ORGANIZATION', 'locations': [{'start': 35, 'end': 39}]},
            {'id': 'T2', 'text': 'Berlin', 'type': 'GOLD-LOCATION', 'locations': [{'start': 15, 'end': 21}]}
        ],
    'relations':
        [
            {'id': 'R0', 'type': 'lives_in', 'arguments': [{'type': 'Arg1', 'target': 'T0'},
                {'type': 'Arg2', 'target': 'T2'}]},
            {'id': 'R1', 'type': 'works_at', 'arguments': [{'type': 'Arg1', 'target': 'T0'},
                {'type': 'Arg2', 'target': 'T1'}]}
        ]
    }
    """
    spans = res["spans"]
    original_spans = document_with_gold_only.labeled_spans
    assert len(spans) == len(original_spans)

    sorted_spans = sorted(spans, key=lambda x: x["locations"][0]["start"])
    sorted_original_spans = sorted(original_spans, key=lambda x: x.start)
    span2spanid = dd()
    for span, original_span in zip(sorted_spans, sorted_original_spans):
        assert span["locations"][0]["start"] == original_span.start
        assert span["locations"][0]["end"] == original_span.end
        assert span["type"] == f"GOLD-{original_span.label}"
        assert (
            span["text"] == document_with_gold_only.text[original_span.start : original_span.end]
        )

        span2spanid[original_span] = span["id"]

    relations = res["relations"]
    original_relations = document_with_gold_only.binary_relations
    assert len(relations) == len(original_relations)

    sorted_relations = sorted(relations, key=lambda x: x["type"])
    sorted_original_relations = sorted(original_relations, key=lambda x: x.label)

    for relation, original_relation in zip(sorted_relations, sorted_original_relations):
        assert relation["type"] == original_relation.label
        assert relation["arguments"][0]["target"] == span2spanid[original_relation.head]
        assert relation["arguments"][1]["target"] == span2spanid[original_relation.tail]


@pytest.fixture
def document_with_prediction_only():
    document = TextDocumentWithLabeledSpansAndBinaryRelations(
        text="Harry lives in Berlin. He works at DFKI.", id="tmp"
    )
    labeled_spans = [
        LabeledSpan(start=0, end=5, label="PERSON"),
        LabeledSpan(start=15, end=21, label="LOCATION"),
        LabeledSpan(start=35, end=39, label="ORGANIZATION"),
    ]

    # add entities as predictions
    for ent in labeled_spans:
        document.labeled_spans.predictions.append(ent)

    binary_relations = [
        BinaryRelation(head=labeled_spans[0], tail=labeled_spans[1], label="lives_in"),
        BinaryRelation(head=labeled_spans[0], tail=labeled_spans[2], label="works_at"),
    ]

    # add relations as predictions
    for rel in binary_relations:
        document.binary_relations.predictions.append(rel)

    return document


def test_save_prediction_only(tmp_path, document_with_prediction_only, document_processor):
    path = str(tmp_path)
    serializer = BratSerializer(
        path=path, document_processor=document_processor, prediction_label_prefix="PRED"
    )

    metadata = serializer(documents=[document_with_prediction_only])

    path = metadata["path"]
    res = read_annotation_file(os.path.join(path, f"{document_with_prediction_only.id}.ann"))
    """
    res in the following format:
    {'spans':
        [
            {'id': 'T0', 'text': 'Harry', 'type': 'PRED-PERSON', 'locations': [{'start': 0, 'end': 5}]},
            {'id': 'T1', 'text': 'DFKI', 'type': 'PRED-ORGANIZATION', 'locations': [{'start': 35, 'end': 39}]},
            {'id': 'T2', 'text': 'Berlin', 'type': 'PRED-LOCATION', 'locations': [{'start': 15, 'end': 21}]}
        ],
    'relations':
        [
            {'id': 'R0', 'type': 'lives_in', 'arguments': [{'type': 'Arg1', 'target': 'T0'},
                {'type': 'Arg2', 'target': 'T2'}]},
            {'id': 'R1', 'type': 'works_at', 'arguments': [{'type': 'Arg1', 'target': 'T0'},
                {'type': 'Arg2', 'target': 'T1'}]}
        ]
    }
    """
    spans = res["spans"]
    original_spans = document_with_prediction_only.labeled_spans.predictions
    assert len(spans) == len(original_spans)

    sorted_spans = sorted(spans, key=lambda x: x["locations"][0]["start"])
    sorted_original_spans = sorted(original_spans, key=lambda x: x.start)
    span2spanid = dd()
    for span, original_span in zip(sorted_spans, sorted_original_spans):
        assert span["locations"][0]["start"] == original_span.start
        assert span["locations"][0]["end"] == original_span.end
        assert span["type"] == f"PRED-{original_span.label}"
        assert (
            span["text"]
            == document_with_prediction_only.text[original_span.start : original_span.end]
        )

        span2spanid[original_span] = span["id"]

    relations = res["relations"]
    original_relations = document_with_prediction_only.binary_relations.predictions
    assert len(relations) == len(original_relations)

    sorted_relations = sorted(relations, key=lambda x: x["type"])
    sorted_original_relations = sorted(original_relations, key=lambda x: x.label)

    for relation, original_relation in zip(sorted_relations, sorted_original_relations):
        assert relation["type"] == original_relation.label
        assert relation["arguments"][0]["target"] == span2spanid[original_relation.head]
        assert relation["arguments"][1]["target"] == span2spanid[original_relation.tail]
