diff --git a/tests/taskmodules/test_pointer_network_for_end2end_re.py b/tests/taskmodules/test_pointer_network_for_end2end_re.py index 702da989b..d7869e369 100644 --- a/tests/taskmodules/test_pointer_network_for_end2end_re.py +++ b/tests/taskmodules/test_pointer_network_for_end2end_re.py @@ -256,23 +256,168 @@ def task_encoding_without_target(taskmodule, document): return taskmodule.encode_input(document)[0] -def test_input_encoding(task_encoding_without_target, taskmodule): - assert task_encoding_without_target is not None - tokens = taskmodule.tokenizer.convert_ids_to_tokens( - task_encoding_without_target.inputs.input_ids +@pytest.fixture(params=[False, True]) +def taskmodule_with_reversed_relations(document, request) -> PointerNetworkTaskModuleForEnd2EndRE: + is_about_is_symmetric = request.param + taskmodule = PointerNetworkTaskModuleForEnd2EndRE( + tokenizer_name_or_path="facebook/bart-base", + relation_layer_name="relations", + exclude_labels_per_layer={"relations": ["no_relation"]}, + annotation_field_mapping={ + "entities": "labeled_spans", + "relations": "binary_relations", + }, + create_constraints=True, + tokenizer_kwargs={"strict_span_conversion": False}, + add_reversed_relations=True, + symmetric_relations=["is_about"] if is_about_is_symmetric else None, ) - if taskmodule.partition_layer_name is None: - assert asdict(task_encoding_without_target.inputs) == { - "input_ids": [0, 713, 16, 10, 34759, 2788, 59, 1085, 4, 3101, 162, 4, 2], - "attention_mask": [1] * 13, + + taskmodule.prepare(documents=[document]) + assert taskmodule.is_prepared + if is_about_is_symmetric: + assert taskmodule.prepared_attributes == { + "labels_per_layer": { + "entities": ["content", "person", "topic"], + "relations": ["is_about"], + } } - elif taskmodule.partition_layer_name == "sentences": - assert asdict(task_encoding_without_target.inputs) == { - "input_ids": [0, 713, 16, 10, 34759, 2788, 59, 1085, 4, 2], - "attention_mask": [1] * 10, + else: + assert taskmodule.prepared_attributes == { + "labels_per_layer": { + "entities": ["content", "person", "topic"], + "relations": ["is_about", "is_about_reversed"], + } + } + + return taskmodule + + +def test_encode_with_add_reversed_relations(taskmodule_with_reversed_relations, document): + task_encodings = taskmodule_with_reversed_relations.encode(document, encode_target=True) + assert len(task_encodings) == 1 + task_encoding = task_encodings[0] + assert task_encoding is not None + assert asdict(task_encoding.inputs) == { + "input_ids": [0, 713, 16, 10, 34759, 2788, 59, 1085, 4, 3101, 162, 4, 2], + "attention_mask": [1] * 13, + } + tokens = taskmodule_with_reversed_relations.tokenizer.convert_ids_to_tokens( + task_encoding.inputs.input_ids + ) + assert tokens == [ + "", + "This", + "Ġis", + "Ġa", + "Ġdummy", + "Ġtext", + "Ġabout", + "Ġnothing", + ".", + "ĠTrust", + "Ġme", + ".", + "", + ] + if not taskmodule_with_reversed_relations.symmetric_relations: + assert task_encoding.targets.labels == [ + 15, + 15, + 5, + 12, + 13, + 3, + 6, + 12, + 13, + 3, + 15, + 15, + 5, + 7, + 18, + 18, + 4, + 2, + 2, + 2, + 2, + 1, + ] + decoded_annotations, statistics = taskmodule_with_reversed_relations.decode_annotations( + task_encoding.targets + ) + assert decoded_annotations == { + "entities": [ + LabeledSpan(start=4, end=6, label="content", score=1.0), + LabeledSpan(start=7, end=8, label="topic", score=1.0), + LabeledSpan(start=10, end=11, label="person", score=1.0), + ], + "relations": [ + BinaryRelation( + head=LabeledSpan(start=4, end=6, label="content", score=1.0), + tail=LabeledSpan(start=7, end=8, label="topic", score=1.0), + label="is_about", + score=1.0, + ), + BinaryRelation( + head=LabeledSpan(start=7, end=8, label="topic", score=1.0), + tail=LabeledSpan(start=4, end=6, label="content", score=1.0), + label="is_about_reversed", + score=1.0, + ), + ], } else: - raise Exception(f"unknown partition_layer_name: {taskmodule.partition_layer_name}") + assert task_encoding.targets.labels == [ + 14, + 14, + 5, + 11, + 12, + 3, + 6, + 11, + 12, + 3, + 14, + 14, + 5, + 6, + 17, + 17, + 4, + 2, + 2, + 2, + 2, + 1, + ] + decoded_annotations, statistics = taskmodule_with_reversed_relations.decode_annotations( + task_encoding.targets + ) + assert decoded_annotations == { + "entities": [ + LabeledSpan(start=4, end=6, label="content", score=1.0), + LabeledSpan(start=7, end=8, label="topic", score=1.0), + LabeledSpan(start=10, end=11, label="person", score=1.0), + ], + "relations": [ + BinaryRelation( + head=LabeledSpan(start=4, end=6, label="content", score=1.0), + tail=LabeledSpan(start=7, end=8, label="topic", score=1.0), + label="is_about", + score=1.0, + ), + BinaryRelation( + head=LabeledSpan(start=7, end=8, label="topic", score=1.0), + tail=LabeledSpan(start=4, end=6, label="content", score=1.0), + label="is_about", + score=1.0, + ), + ], + } @pytest.fixture()