diff --git a/tests/unit/serializer/test_brat.py b/tests/unit/serializer/test_brat.py index 4a26eb6..a50dc2f 100644 --- a/tests/unit/serializer/test_brat.py +++ b/tests/unit/serializer/test_brat.py @@ -104,166 +104,363 @@ def document_processor(): @pytest.fixture -def document_with_gold_only(): +def document(): document = TextDocumentWithLabeledSpansAndBinaryRelations( - text="Harry lives in Berlin. He works at DFKI.", id="tmp" + text="Harry lives in Berlin. He works at DFKI.", id="tmp_1" ) + predicted_labeled_spans = [ + LabeledSpan(start=0, end=5, label="PERSON"), # Harry + LabeledSpan(start=15, end=21, label="LOCATION"), # Berlin + ] + document.labeled_spans.predictions.extend(predicted_labeled_spans) + labeled_spans = [ - LabeledSpan(start=0, end=5, label="PERSON"), - LabeledSpan(start=15, end=21, label="LOCATION"), - LabeledSpan(start=35, end=39, label="ORGANIZATION"), + LabeledSpan(start=0, end=5, label="PERSON"), # Harry + LabeledSpan(start=15, end=21, label="LOCATION"), # Berlin + LabeledSpan(start=35, end=39, label="ORGANIZATION"), # DFKI + ] + + document.labeled_spans.extend(labeled_spans) + + predicted_binary_relations = [ + BinaryRelation(head=labeled_spans[0], tail=labeled_spans[1], label="lives_in"), ] - for ent in labeled_spans: - document.labeled_spans.append(ent) + document.binary_relations.predictions.extend(predicted_binary_relations) binary_relations = [ BinaryRelation(head=labeled_spans[0], tail=labeled_spans[1], label="lives_in"), BinaryRelation(head=labeled_spans[0], tail=labeled_spans[2], label="works_at"), ] - for rel in binary_relations: - document.binary_relations.append(rel) + document.binary_relations.extend(binary_relations) return document -def test_save_gold_only(tmp_path, document_with_gold_only, document_processor): +def test_save(tmp_path, document, document_processor): path = str(tmp_path) serializer = BratSerializer( - path=path, document_processor=document_processor, gold_label_prefix="GOLD" + path=path, + document_processor=document_processor, + entity_layer="spans", + relation_layer="relations", ) - metadata = serializer(documents=[document_with_gold_only]) + metadata = serializer(documents=[document]) path = metadata["path"] - res = read_annotation_file(os.path.join(path, f"{document_with_gold_only.id}.ann")) - """ - res in the following format: - {'spans': - [ - {'id': 'T0', 'text': 'Harry', 'type': 'GOLD-PERSON', 'locations': [{'start': 0, 'end': 5}]}, - {'id': 'T1', 'text': 'DFKI', 'type': 'GOLD-ORGANIZATION', 'locations': [{'start': 35, 'end': 39}]}, - {'id': 'T2', 'text': 'Berlin', 'type': 'GOLD-LOCATION', 'locations': [{'start': 15, 'end': 21}]} - ], - 'relations': - [ - {'id': 'R0', 'type': 'lives_in', 'arguments': [{'type': 'Arg1', 'target': 'T0'}, - {'type': 'Arg2', 'target': 'T2'}]}, - {'id': 'R1', 'type': 'works_at', 'arguments': [{'type': 'Arg1', 'target': 'T0'}, - {'type': 'Arg2', 'target': 'T1'}]} - ] - } + res = read_annotation_file(os.path.join(path, f"{document.id}.ann")) """ + res in the following format: + {'spans': + [ + {'id': 'T0', 'text': 'Harry', 'type': 'PERSON', 'locations': [{'start': 0, 'end': 5}]}, + {'id': 'T1', 'text': 'DFKI', 'type': 'ORGANIZATION', 'locations': [{'start': 35, 'end': 39}]}, + {'id': 'T2', 'text': 'Berlin', 'type': 'LOCATION', 'locations': [{'start': 15, 'end': 21}]} + ], + 'relations': + [ + {'id': 'R0', 'type': 'lives_in', 'arguments': [{'type': 'Arg1', 'target': 'T0'}, + {'type': 'Arg2', 'target': 'T2'}]}, + {'id': 'R1', 'type': 'works_at', 'arguments': [{'type': 'Arg1', 'target': 'T0'}, + {'type': 'Arg2', 'target': 'T1'}]} + ] + } + """ spans = res["spans"] - original_spans = document_with_gold_only.labeled_spans + original_spans = document.labeled_spans.predictions assert len(spans) == len(original_spans) - sorted_spans = sorted(spans, key=lambda x: x["locations"][0]["start"]) - sorted_original_spans = sorted(original_spans, key=lambda x: x.start) - span2spanid = dd() + sorted_spans = sorted( + spans, key=lambda x: x["locations"][0]["start"] + ) # sort by start index of first span + sorted_original_spans = sorted(original_spans, key=lambda x: x.start) # sort by start index + spanid2span = dd() # map span_id (T0,T1,..) to original span for span, original_span in zip(sorted_spans, sorted_original_spans): assert span["locations"][0]["start"] == original_span.start assert span["locations"][0]["end"] == original_span.end - assert span["type"] == f"GOLD-{original_span.label}" - assert ( - span["text"] == document_with_gold_only.text[original_span.start : original_span.end] - ) + assert span["type"] == original_span.label + assert span["text"] == document.text[original_span.start : original_span.end] - span2spanid[original_span] = span["id"] + spanid2span[span["id"]] = original_span relations = res["relations"] - original_relations = document_with_gold_only.binary_relations + original_relations = document.binary_relations.predictions assert len(relations) == len(original_relations) - sorted_relations = sorted(relations, key=lambda x: x["type"]) + sorted_relations = sorted(relations, key=lambda x: x["type"]) # sort by relation label sorted_original_relations = sorted(original_relations, key=lambda x: x.label) for relation, original_relation in zip(sorted_relations, sorted_original_relations): assert relation["type"] == original_relation.label - assert relation["arguments"][0]["target"] == span2spanid[original_relation.head] - assert relation["arguments"][1]["target"] == span2spanid[original_relation.tail] + assert spanid2span[relation["arguments"][0]["target"]] == original_relation.head + assert spanid2span[relation["arguments"][1]["target"]] == original_relation.tail + + +def test_save_gold_annotation_with_prefix(tmp_path, document, document_processor): + path = str(tmp_path) + serializer = BratSerializer( + path=path, + document_processor=document_processor, + entity_layer="spans", + relation_layer="relations", + gold_label_prefix="GOLD", + ) + + metadata = serializer(documents=[document]) + + path = metadata["path"] + res = read_annotation_file(os.path.join(path, f"{document.id}.ann")) + """ + res in the following format: + { + 'spans': + [ + {'id': 'T0', 'text': 'Berlin', 'type': 'GOLD-LOCATION', 'locations': [{'start': 15, 'end': 21}]}, + {'id': 'T1', 'text': 'Harry', 'type': 'GOLD-PERSON', 'locations': [{'start': 0, 'end': 5}]}, + {'id': 'T2', 'text': 'DFKI', 'type': 'GOLD-ORGANIZATION', 'locations': [{'start': 35, 'end': 39}]}, + {'id': 'T3', 'text': 'Berlin', 'type': 'LOCATION', 'locations': [{'start': 15, 'end': 21}]}, + {'id': 'T4', 'text': 'Harry', 'type': 'PERSON', 'locations': [{'start': 0, 'end': 5}]} + ], + 'relations': + [ + {'id': 'R0', 'type': 'GOLD-lives_in', 'arguments': [{'type': 'Arg1', 'target': 'T1'}, + {'type': 'Arg2', 'target': 'T0'}]}, + {'id': 'R1', 'type': 'GOLD-works_at', 'arguments': [{'type': 'Arg1', 'target': 'T1'}, + {'type': 'Arg2', 'target': 'T2'}]}, + {'id': 'R3', 'type': 'lives_in', 'arguments': [{'type': 'Arg1', 'target': 'T4'}, + {'type': 'Arg2', 'target': 'T3'}]} + ] + } + + """ + spans = res["spans"] + sorted_spans = sorted(spans, key=lambda x: x["type"])[ + :3 + ] # only first three are gold annotations + original_spans = document.labeled_spans + sorted_original_spans = sorted(original_spans, key=lambda x: x.label) + assert len(sorted_spans) == len(sorted_original_spans) + + spanid2span = dd() + for span, original_span in zip(sorted_spans, sorted_original_spans): + assert span["locations"][0]["start"] == original_span.start + assert span["locations"][0]["end"] == original_span.end + assert span["type"] == f"GOLD-{original_span.label}" + assert span["text"] == document.text[original_span.start : original_span.end] + + spanid2span[span["id"]] = original_span + + relations = res["relations"] + sorted_relations = sorted(relations, key=lambda x: x["type"])[ + :2 + ] # only first two are gold annotations + original_relations = document.binary_relations + sorted_original_relations = sorted(original_relations, key=lambda x: x.label) + assert len(sorted_relations) == len(sorted_original_relations) + + for relation, original_relation in zip(sorted_relations, sorted_original_relations): + assert relation["type"] == f"GOLD-{original_relation.label}" + assert spanid2span[relation["arguments"][0]["target"]] == original_relation.head + assert spanid2span[relation["arguments"][1]["target"]] == original_relation.tail + + +def test_save_gold_and_predicted_annotation_with_prefix(tmp_path, document, document_processor): + path = str(tmp_path) + serializer = BratSerializer( + path=path, + document_processor=document_processor, + entity_layer="spans", + relation_layer="relations", + gold_label_prefix="GOLD", + prediction_label_prefix="PRED", + ) + + metadata = serializer(documents=[document]) + + path = metadata["path"] + res = read_annotation_file(os.path.join(path, f"{document.id}.ann")) + """ + res in the following format: + { + 'spans': + [ + {'id': 'T0', 'text': 'Berlin', 'type': 'GOLD-LOCATION', 'locations': [{'start': 15, 'end': 21}]}, + {'id': 'T1', 'text': 'Harry', 'type': 'GOLD-PERSON', 'locations': [{'start': 0, 'end': 5}]}, + {'id': 'T2', 'text': 'DFKI', 'type': 'GOLD-ORGANIZATION', 'locations': [{'start': 35, 'end': 39}]}, + {'id': 'T3', 'text': 'Berlin', 'type': 'PRED-LOCATION', 'locations': [{'start': 15, 'end': 21}]}, + {'id': 'T4', 'text': 'Harry', 'type': 'PRED-PERSON', 'locations': [{'start': 0, 'end': 5}]} + ], + 'relations': + [ + {'id': 'R0', 'type': 'GOLD-lives_in', 'arguments': [{'type': 'Arg1', 'target': 'T1'}, + {'type': 'Arg2', 'target': 'T0'}]}, + {'id': 'R1', 'type': 'GOLD-works_at', 'arguments': [{'type': 'Arg1', 'target': 'T1'}, + {'type': 'Arg2', 'target': 'T2'}]}, + {'id': 'R3', 'type': 'PRED-lives_in', 'arguments': [{'type': 'Arg1', 'target': 'T4'}, + {'type': 'Arg2', 'target': 'T3'}]} + ] + } + + """ + spans = res["spans"] + sorted_spans = sorted(spans, key=lambda x: x["type"]) + + gold_spans = sorted_spans[:3] # only first three are gold annotations + original_gold_spans = document.labeled_spans + sorted_original_gold_spans = sorted(original_gold_spans, key=lambda x: x.label) + assert len(gold_spans) == len(sorted_original_gold_spans) + + spanid2span = dd() + for span, original_span in zip(gold_spans, sorted_original_gold_spans): + assert span["locations"][0]["start"] == original_span.start + assert span["locations"][0]["end"] == original_span.end + assert span["type"] == f"GOLD-{original_span.label}" + assert span["text"] == document.text[original_span.start : original_span.end] + + spanid2span[span["id"]] = original_span + + predicted_spans = sorted_spans[3:] # last two are predicted annotations + original_predicted_spans = document.labeled_spans.predictions + sorted_original_predicted_spans = sorted(original_predicted_spans, key=lambda x: x.label) + assert len(predicted_spans) == len(sorted_original_predicted_spans) + + for span, original_span in zip(predicted_spans, sorted_original_predicted_spans): + assert span["locations"][0]["start"] == original_span.start + assert span["locations"][0]["end"] == original_span.end + assert span["type"] == f"PRED-{original_span.label}" + assert span["text"] == document.text[original_span.start : original_span.end] + + spanid2span[span["id"]] = original_span + + relations = res["relations"] + sorted_relations = sorted(relations, key=lambda x: x["type"]) + + gold_relations = sorted_relations[:2] # only first two are gold annotations + original_gold_relations = document.binary_relations + sorted_original_relations = sorted(original_gold_relations, key=lambda x: x.label) + assert len(gold_relations) == len(sorted_original_relations) + + for relation, original_relation in zip(gold_relations, sorted_original_relations): + assert relation["type"] == f"GOLD-{original_relation.label}" + assert spanid2span[relation["arguments"][0]["target"]] == original_relation.head + assert spanid2span[relation["arguments"][1]["target"]] == original_relation.tail + + predicted_relations = sorted_relations[2:] # only last annotation is predicted annotation + original_predicted_relations = document.binary_relations.predictions + sorted_original_relations = sorted(original_predicted_relations, key=lambda x: x.label) + assert len(predicted_relations) == len(sorted_original_relations) + + for relation, original_relation in zip(predicted_relations, sorted_original_relations): + assert relation["type"] == f"PRED-{original_relation.label}" + assert spanid2span[relation["arguments"][0]["target"]] == original_relation.head + assert spanid2span[relation["arguments"][1]["target"]] == original_relation.tail @pytest.fixture -def document_with_prediction_only(): +def document_with_multispan(): document = TextDocumentWithLabeledSpansAndBinaryRelations( - text="Harry lives in Berlin. He works at DFKI.", id="tmp" + text="Harry lives in Berlin, Germany. He works at DFKI.", id="tmp" ) - labeled_spans = [ + entities = [ LabeledSpan(start=0, end=5, label="PERSON"), LabeledSpan(start=15, end=21, label="LOCATION"), - LabeledSpan(start=35, end=39, label="ORGANIZATION"), + LabeledSpan(start=23, end=30, label="LOCATION"), + LabeledSpan(start=44, end=48, label="ORGANIZATION"), ] - - # add entities as predictions - for ent in labeled_spans: + for ent in entities: document.labeled_spans.predictions.append(ent) - binary_relations = [ - BinaryRelation(head=labeled_spans[0], tail=labeled_spans[1], label="lives_in"), - BinaryRelation(head=labeled_spans[0], tail=labeled_spans[2], label="works_at"), + relations = [ + BinaryRelation(head=entities[0], tail=entities[1], label="lives_in"), + BinaryRelation( + head=entities[1], tail=entities[2], label="parts_of_same" + ), # should be removed + BinaryRelation( + head=entities[1], tail=entities[3], label="parts_of_same" + ), # should be removed + BinaryRelation(head=entities[0], tail=entities[3], label="works_at"), + BinaryRelation( + head=entities[3], tail=entities[1], label="located_in" + ), # tail should be a new merged entity + BinaryRelation( + head=entities[3], tail=entities[2], label="located_in" + ), # tail should be a new merged entity ] # add relations as predictions - for rel in binary_relations: + for rel in relations: document.binary_relations.predictions.append(rel) return document -def test_save_prediction_only(tmp_path, document_with_prediction_only, document_processor): +def test_save_multispan(tmp_path, document_with_multispan, document_processor): path = str(tmp_path) serializer = BratSerializer( - path=path, document_processor=document_processor, prediction_label_prefix="PRED" + path=path, + document_processor=document_processor, + entity_layer="spans", + relation_layer="relations", ) - metadata = serializer(documents=[document_with_prediction_only]) + metadata = serializer(documents=[document_with_multispan]) path = metadata["path"] - res = read_annotation_file(os.path.join(path, f"{document_with_prediction_only.id}.ann")) + res = read_annotation_file(os.path.join(path, f"{document_with_multispan.id}.ann")) + """ - res in the following format: - {'spans': + { + 'spans': [ - {'id': 'T0', 'text': 'Harry', 'type': 'PRED-PERSON', 'locations': [{'start': 0, 'end': 5}]}, - {'id': 'T1', 'text': 'DFKI', 'type': 'PRED-ORGANIZATION', 'locations': [{'start': 35, 'end': 39}]}, - {'id': 'T2', 'text': 'Berlin', 'type': 'PRED-LOCATION', 'locations': [{'start': 15, 'end': 21}]} + {'id': 'T0', 'text': 'DFKI', 'type': 'ORGANIZATION', 'locations': [{'start': 44, 'end': 48}]}, + {'id': 'T1', 'text': 'Harry', 'type': 'PERSON', 'locations': [{'start': 0, 'end': 5}]}, + {'id': 'T2', 'text': 'Berlin Germany', 'type': 'LOCATION', + 'locations': [{'start': 15, 'end': 21}, {'start': 23, 'end': 30}]} ], 'relations': [ - {'id': 'R0', 'type': 'lives_in', 'arguments': [{'type': 'Arg1', 'target': 'T0'}, + {'id': 'R0', 'type': 'lives_in', 'arguments': [{'type': 'Arg1', 'target': 'T1'}, {'type': 'Arg2', 'target': 'T2'}]}, - {'id': 'R1', 'type': 'works_at', 'arguments': [{'type': 'Arg1', 'target': 'T0'}, - {'type': 'Arg2', 'target': 'T1'}]} + {'id': 'R1', 'type': 'located_in', 'arguments': [{'type': 'Arg1', 'target': 'T0'}, + {'type': 'Arg2', 'target': 'T2'}]}, + {'id': 'R2', 'type': 'works_at', 'arguments': [{'type': 'Arg1', 'target': 'T1'}, + {'type': 'Arg2', 'target': 'T0'}]} ] - } - """ + } + + """ spans = res["spans"] - original_spans = document_with_prediction_only.labeled_spans.predictions - assert len(spans) == len(original_spans) + assert len(spans) == 3 sorted_spans = sorted(spans, key=lambda x: x["locations"][0]["start"]) - sorted_original_spans = sorted(original_spans, key=lambda x: x.start) - span2spanid = dd() - for span, original_span in zip(sorted_spans, sorted_original_spans): - assert span["locations"][0]["start"] == original_span.start - assert span["locations"][0]["end"] == original_span.end - assert span["type"] == f"PRED-{original_span.label}" - assert ( - span["text"] - == document_with_prediction_only.text[original_span.start : original_span.end] - ) - span2spanid[original_span] = span["id"] + spanid2span = dd() + + span = sorted_spans[0] # verify first span + assert span["locations"][0]["start"] == 0 + assert span["locations"][0]["end"] == 5 + assert span["type"] == "PERSON" + assert span["text"] == "Harry" + spanid2span[span["id"]] = span + + span = sorted_spans[1] # verify second span (multispan) + assert span["locations"][0]["start"] == 15 + assert span["locations"][0]["end"] == 21 + assert span["locations"][1]["start"] == 23 + assert span["locations"][1]["end"] == 30 + assert span["type"] == "LOCATION" + assert span["text"] == "Berlin Germany" + spanid2span[span["id"]] = span relations = res["relations"] - original_relations = document_with_prediction_only.binary_relations.predictions - assert len(relations) == len(original_relations) + assert len(relations) == 3 sorted_relations = sorted(relations, key=lambda x: x["type"]) - sorted_original_relations = sorted(original_relations, key=lambda x: x.label) - for relation, original_relation in zip(sorted_relations, sorted_original_relations): - assert relation["type"] == original_relation.label - assert relation["arguments"][0]["target"] == span2spanid[original_relation.head] - assert relation["arguments"][1]["target"] == span2spanid[original_relation.tail] + relation = sorted_relations[0] # verify relation between first and second span + relation_type = relation["type"] + assert relation_type == "lives_in" + arg1 = spanid2span[relation["arguments"][0]["target"]]["text"] + arg2 = spanid2span[relation["arguments"][1]["target"]]["text"] + assert f"{arg1} {relation_type} {arg2}" == "Harry lives_in Berlin Germany"