diff --git a/tests/taskmodules/test_pointer_network_for_end2end_re.py b/tests/taskmodules/test_pointer_network_for_end2end_re.py index f6e385134..7ad6f15b0 100644 --- a/tests/taskmodules/test_pointer_network_for_end2end_re.py +++ b/tests/taskmodules/test_pointer_network_for_end2end_re.py @@ -522,6 +522,57 @@ def test_encode_with_add_reversed_relations_already_exists(caplog): } +def test_decode_with_add_reversed_relations(): + doc = ExampleDocument(text="This is a dummy text about nothing. Trust me.") + doc.entities.append(LabeledSpan(start=10, end=20, label="content")) + doc.entities.append(LabeledSpan(start=27, end=34, label="topic")) + doc.relations.append( + BinaryRelation(head=doc.entities[0], tail=doc.entities[1], label="is_about") + ) + + taskmodule = PointerNetworkTaskModuleForEnd2EndRE( + tokenizer_name_or_path="facebook/bart-base", + relation_layer_name="relations", + annotation_field_mapping={ + "entities": "labeled_spans", + "relations": "binary_relations", + }, + add_reversed_relations=True, + ) + taskmodule.prepare(documents=[doc]) + + task_encodings = taskmodule.encode(doc, encode_target=True) + assert len(task_encodings) == 1 + decoded_annotations, statistics = taskmodule.decode_annotations(task_encodings[0].targets) + assert decoded_annotations == { + "entities": [ + LabeledSpan(start=4, end=6, label="content", score=1.0), + LabeledSpan(start=7, end=8, label="topic", score=1.0), + ], + "relations": [ + BinaryRelation( + head=LabeledSpan(start=4, end=6, label="content", score=1.0), + tail=LabeledSpan(start=7, end=8, label="topic", score=1.0), + label="is_about", + score=1.0, + ), + BinaryRelation( + head=LabeledSpan(start=7, end=8, label="topic", score=1.0), + tail=LabeledSpan(start=4, end=6, label="content", score=1.0), + label="is_about_reversed", + score=1.0, + ), + ], + } + + task_outputs = [task_encoding.targets for task_encoding in task_encodings] + docs_with_predictions = taskmodule.decode(task_encodings, task_outputs) + assert len(docs_with_predictions) == 1 + doc_with_predictions: ExampleDocument = docs_with_predictions[0] + assert list(doc_with_predictions.entities.predictions) == list(doc_with_predictions.entities) + assert list(doc_with_predictions.relations.predictions) == list(doc_with_predictions.relations) + + @pytest.fixture() def target_encoding(taskmodule, task_encoding_without_target): return taskmodule.encode_target(task_encoding_without_target)