diff --git a/tests/taskmodules/test_pointer_network_for_end2end_re.py b/tests/taskmodules/test_pointer_network_for_end2end_re.py index 083519383..f6e385134 100644 --- a/tests/taskmodules/test_pointer_network_for_end2end_re.py +++ b/tests/taskmodules/test_pointer_network_for_end2end_re.py @@ -470,13 +470,56 @@ def test_encode_with_add_reversed_relations(taskmodule_with_reversed_relations, } -def test_encode_with_add_reversed_relations_already_exists(taskmodule_with_reversed_relations): +def test_encode_with_add_reversed_relations_already_exists(caplog): doc = ExampleDocument(text="This is a dummy text about nothing. Trust me.") doc.entities.append(LabeledSpan(start=10, end=20, label="content")) doc.entities.append(LabeledSpan(start=27, end=34, label="topic")) - rel = BinaryRelation(head=doc.entities[0], tail=doc.entities[1], label="is_about") + doc.relations.append( + BinaryRelation(head=doc.entities[0], tail=doc.entities[1], label="is_about") + ) + doc.relations.append( + BinaryRelation(head=doc.entities[1], tail=doc.entities[0], label="is_about") + ) - task_encodings = taskmodule_with_reversed_relations.encode(document, encode_target=True) + taskmodule = PointerNetworkTaskModuleForEnd2EndRE( + tokenizer_name_or_path="facebook/bart-base", + relation_layer_name="relations", + annotation_field_mapping={ + "entities": "labeled_spans", + "relations": "binary_relations", + }, + add_reversed_relations=True, + symmetric_relations=["is_about"], + ) + taskmodule.prepare(documents=[doc]) + + with caplog.at_level(logging.WARNING): + task_encodings = taskmodule.encode(doc, encode_target=True) + assert len(caplog.messages) == 0 + assert len(task_encodings) == 1 + task_encoding = task_encodings[0] + + decoded_annotations, statistics = taskmodule.decode_annotations(task_encoding.targets) + assert decoded_annotations == { + "entities": [ + LabeledSpan(start=4, end=6, label="content", score=1.0), + LabeledSpan(start=7, end=8, label="topic", score=1.0), + ], + "relations": [ + BinaryRelation( + head=LabeledSpan(start=4, end=6, label="content", score=1.0), + tail=LabeledSpan(start=7, end=8, label="topic", score=1.0), + label="is_about", + score=1.0, + ), + BinaryRelation( + head=LabeledSpan(start=7, end=8, label="topic", score=1.0), + tail=LabeledSpan(start=4, end=6, label="content", score=1.0), + label="is_about", + score=1.0, + ), + ], + } @pytest.fixture()