From 26970978775100f62dfb1be5a490604d0c21af5e Mon Sep 17 00:00:00 2001 From: Arne Binder Date: Wed, 13 Nov 2024 18:47:43 +0100 Subject: [PATCH 1/4] add test for duplicated / conflicting relations --- .../test_pointer_network_for_end2end_re.py | 138 +++++++++++++++++- 1 file changed, 132 insertions(+), 6 deletions(-) diff --git a/tests/taskmodules/test_pointer_network_for_end2end_re.py b/tests/taskmodules/test_pointer_network_for_end2end_re.py index 65a217240..7b952e48d 100644 --- a/tests/taskmodules/test_pointer_network_for_end2end_re.py +++ b/tests/taskmodules/test_pointer_network_for_end2end_re.py @@ -42,14 +42,15 @@ def config(config_str): return CONFIG_DICT[config_str] +@dataclass +class ExampleDocument(TextBasedDocument): + entities: AnnotationList[LabeledSpan] = annotation_field(target="text") + relations: AnnotationList[BinaryRelation] = annotation_field(target="entities") + sentences: AnnotationList[LabeledSpan] = annotation_field(target="text") + + @pytest.fixture(scope="module") def document(): - @dataclass - class ExampleDocument(TextBasedDocument): - entities: AnnotationList[LabeledSpan] = annotation_field(target="text") - relations: AnnotationList[BinaryRelation] = annotation_field(target="entities") - sentences: AnnotationList[LabeledSpan] = annotation_field(target="text") - doc = ExampleDocument(text="This is a dummy text about nothing. Trust me.") span1 = LabeledSpan(start=10, end=20, label="content") span2 = LabeledSpan(start=27, end=34, label="topic") @@ -317,6 +318,131 @@ def test_target_encoding(target_encoding, taskmodule): raise Exception(f"unknown partition_layer_name: {taskmodule.partition_layer_name}") +def test_task_encoding_with_deduplicated_relations(caplog): + doc = ExampleDocument(text="This is a dummy text about nothing. Trust me.") + doc.entities.append(LabeledSpan(start=10, end=20, label="content")) + doc.entities.append(LabeledSpan(start=27, end=34, label="topic")) + doc.entities.append(LabeledSpan(start=42, end=44, label="person")) + assert doc.entities.resolve() == [ + ("content", "dummy text"), + ("topic", "nothing"), + ("person", "me"), + ] + # add the same relation twice (just use a different score, but that should not matter) + doc.relations.append( + BinaryRelation(head=doc.entities[0], tail=doc.entities[1], label="is_about") + ) + doc.relations.append( + BinaryRelation(head=doc.entities[0], tail=doc.entities[1], label="is_about", score=0.9) + ) + assert doc.relations.resolve() == [ + ("is_about", (("content", "dummy text"), ("topic", "nothing"))), + ("is_about", (("content", "dummy text"), ("topic", "nothing"))), + ] + taskmodule = PointerNetworkTaskModuleForEnd2EndRE( + tokenizer_name_or_path="facebook/bart-base", + relation_layer_name="relations", + annotation_field_mapping={ + "entities": "labeled_spans", + "relations": "binary_relations", + }, + ) + taskmodule.prepare(documents=[doc]) + caplog.clear() + with caplog.at_level(logging.WARNING): + task_encodings = taskmodule.encode(doc, encode_target=True) + + assert caplog.messages == [ + ( + "encoding errors: {'correct': 2}, skipped annotations:\n" + "{\n" + ' "relations": [\n' + ' "BinaryRelation(' + "head=LabeledSpan(start=4, end=6, label='content', score=1.0), " + "tail=LabeledSpan(start=7, end=8, label='topic', score=1.0), " + "label='is_about', score=0.9" + ')"\n' + " ]\n" + "}" + ) + ] + + assert len(task_encodings) == 1 + decoded_annotations, statistics = taskmodule.decode_annotations(task_encodings[0].targets) + assert decoded_annotations == { + "entities": [ + LabeledSpan(start=4, end=6, label="content", score=1.0), + LabeledSpan(start=7, end=8, label="topic", score=1.0), + LabeledSpan(start=10, end=11, label="person", score=1.0), + ], + "relations": [ + BinaryRelation( + head=LabeledSpan(start=4, end=6, label="content", score=1.0), + tail=LabeledSpan(start=7, end=8, label="topic", score=1.0), + label="is_about", + score=1.0, + ) + ], + } + + +def test_task_encoding_with_conflicting_relations(caplog): + doc = ExampleDocument(text="This is a dummy text about nothing. Trust me.") + doc.entities.append(LabeledSpan(start=10, end=20, label="content")) + doc.entities.append(LabeledSpan(start=27, end=34, label="topic")) + doc.entities.append(LabeledSpan(start=42, end=44, label="person")) + assert doc.entities.resolve() == [ + ("content", "dummy text"), + ("topic", "nothing"), + ("person", "me"), + ] + # add two relations with the same head and tail, but different labels + doc.relations.append( + BinaryRelation(head=doc.entities[0], tail=doc.entities[1], label="is_about") + ) + doc.relations.append( + BinaryRelation(head=doc.entities[0], tail=doc.entities[1], label="wrong_relation") + ) + assert doc.relations.resolve() == [ + ("is_about", (("content", "dummy text"), ("topic", "nothing"))), + ("wrong_relation", (("content", "dummy text"), ("topic", "nothing"))), + ] + taskmodule = PointerNetworkTaskModuleForEnd2EndRE( + tokenizer_name_or_path="facebook/bart-base", + relation_layer_name="relations", + annotation_field_mapping={ + "entities": "labeled_spans", + "relations": "binary_relations", + }, + ) + taskmodule.prepare(documents=[doc]) + caplog.clear() + with caplog.at_level(logging.WARNING): + task_encodings = taskmodule.encode(doc, encode_target=True) + assert caplog.messages == [ + "relation ('Ġdummy', 'Ġtext') -> ('Ġnothing',) already exists, but has another label: is_about (previous label: wrong_relation). Skipping.", + "encoding errors: {'correct': 2}, skipped annotations:\n{\n \"relations\": [\n \"BinaryRelation(head=LabeledSpan(start=4, end=6, label='content', score=1.0), tail=LabeledSpan(start=7, end=8, label='topic', score=1.0), label='wrong_relation', score=1.0)\"\n ]\n}", + ] + + assert len(task_encodings) == 1 + decoded_annotations, statistics = taskmodule.decode_annotations(task_encodings[0].targets) + assert decoded_annotations == { + "entities": [ + LabeledSpan(start=4, end=6, label="content", score=1.0), + LabeledSpan(start=7, end=8, label="topic", score=1.0), + LabeledSpan(start=10, end=11, label="person", score=1.0), + ], + "relations": [ + BinaryRelation( + head=LabeledSpan(start=4, end=6, label="content", score=1.0), + tail=LabeledSpan(start=7, end=8, label="topic", score=1.0), + label="is_about", + score=1.0, + ) + ], + } + + @pytest.fixture() def task_encoding(task_encoding_without_target, target_encoding): task_encoding_without_target.targets = target_encoding From cf287c492b17dea33f468e9c9fc367f8eb1bcb11 Mon Sep 17 00:00:00 2001 From: Arne Binder Date: Wed, 13 Nov 2024 18:57:21 +0100 Subject: [PATCH 2/4] check messages at the end --- .../test_pointer_network_for_end2end_re.py | 41 ++++++++++--------- 1 file changed, 22 insertions(+), 19 deletions(-) diff --git a/tests/taskmodules/test_pointer_network_for_end2end_re.py b/tests/taskmodules/test_pointer_network_for_end2end_re.py index 7b952e48d..035ad0a6f 100644 --- a/tests/taskmodules/test_pointer_network_for_end2end_re.py +++ b/tests/taskmodules/test_pointer_network_for_end2end_re.py @@ -351,21 +351,7 @@ def test_task_encoding_with_deduplicated_relations(caplog): caplog.clear() with caplog.at_level(logging.WARNING): task_encodings = taskmodule.encode(doc, encode_target=True) - - assert caplog.messages == [ - ( - "encoding errors: {'correct': 2}, skipped annotations:\n" - "{\n" - ' "relations": [\n' - ' "BinaryRelation(' - "head=LabeledSpan(start=4, end=6, label='content', score=1.0), " - "tail=LabeledSpan(start=7, end=8, label='topic', score=1.0), " - "label='is_about', score=0.9" - ')"\n' - " ]\n" - "}" - ) - ] + messages = list(caplog.messages) assert len(task_encodings) == 1 decoded_annotations, statistics = taskmodule.decode_annotations(task_encodings[0].targets) @@ -385,6 +371,21 @@ def test_task_encoding_with_deduplicated_relations(caplog): ], } + assert messages == [ + ( + "encoding errors: {'correct': 2}, skipped annotations:\n" + "{\n" + ' "relations": [\n' + ' "BinaryRelation(' + "head=LabeledSpan(start=4, end=6, label='content', score=1.0), " + "tail=LabeledSpan(start=7, end=8, label='topic', score=1.0), " + "label='is_about', score=0.9" + ')"\n' + " ]\n" + "}" + ) + ] + def test_task_encoding_with_conflicting_relations(caplog): doc = ExampleDocument(text="This is a dummy text about nothing. Trust me.") @@ -419,10 +420,7 @@ def test_task_encoding_with_conflicting_relations(caplog): caplog.clear() with caplog.at_level(logging.WARNING): task_encodings = taskmodule.encode(doc, encode_target=True) - assert caplog.messages == [ - "relation ('Ġdummy', 'Ġtext') -> ('Ġnothing',) already exists, but has another label: is_about (previous label: wrong_relation). Skipping.", - "encoding errors: {'correct': 2}, skipped annotations:\n{\n \"relations\": [\n \"BinaryRelation(head=LabeledSpan(start=4, end=6, label='content', score=1.0), tail=LabeledSpan(start=7, end=8, label='topic', score=1.0), label='wrong_relation', score=1.0)\"\n ]\n}", - ] + messages = list(caplog.messages) assert len(task_encodings) == 1 decoded_annotations, statistics = taskmodule.decode_annotations(task_encodings[0].targets) @@ -442,6 +440,11 @@ def test_task_encoding_with_conflicting_relations(caplog): ], } + assert messages == [ + "relation ('Ġdummy', 'Ġtext') -> ('Ġnothing',) already exists, but has another label: is_about (previous label: wrong_relation). Skipping.", + "encoding errors: {'correct': 2}, skipped annotations:\n{\n \"relations\": [\n \"BinaryRelation(head=LabeledSpan(start=4, end=6, label='content', score=1.0), tail=LabeledSpan(start=7, end=8, label='topic', score=1.0), label='wrong_relation', score=1.0)\"\n ]\n}", + ] + @pytest.fixture() def task_encoding(task_encoding_without_target, target_encoding): From f9fcd9154200430df38246713aef7790853f272d Mon Sep 17 00:00:00 2001 From: Arne Binder Date: Wed, 13 Nov 2024 19:00:15 +0100 Subject: [PATCH 3/4] skip relations when another relation with same arguments was already encoded --- .../taskmodules/pointer_network_for_end2end_re.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/src/pie_modules/taskmodules/pointer_network_for_end2end_re.py b/src/pie_modules/taskmodules/pointer_network_for_end2end_re.py index 2042f0ec4..ced418a92 100644 --- a/src/pie_modules/taskmodules/pointer_network_for_end2end_re.py +++ b/src/pie_modules/taskmodules/pointer_network_for_end2end_re.py @@ -454,11 +454,20 @@ def encode_annotations( # encode relations all_relation_arguments = set() + relation_arguments2label: Dict[Tuple[Annotation, ...], str] = dict() relation_encodings = dict() for rel in layers[self.relation_layer_name]: if not isinstance(rel, BinaryRelation): raise Exception(f"expected BinaryRelation, but got: {rel}") if rel.label in self.labels_per_layer[self.relation_layer_name]: + if (rel.head, rel.tail) in relation_arguments2label: + previous_label = relation_arguments2label[(rel.head, rel.tail)] + if previous_label != rel.label: + logger.warning( + f"relation {rel.head} -> {rel.tail} already exists, but has another label: " + f"{previous_label} (previous label: {rel.label}). Skipping." + ) + continue encoded_relation = self.relation_encoder_decoder.encode( annotation=rel, metadata=metadata ) @@ -466,6 +475,7 @@ def encode_annotations( raise Exception(f"failed to encode relation: {rel}") relation_encodings[rel] = encoded_relation all_relation_arguments.update([rel.head, rel.tail]) + relation_arguments2label[(rel.head, rel.tail)] = rel.label # encode spans that are not arguments of any relation no_relation_spans = [ From 8d4b1e31142295b0b0e9210c221fb626461e5f17 Mon Sep 17 00:00:00 2001 From: Arne Binder Date: Wed, 13 Nov 2024 19:10:02 +0100 Subject: [PATCH 4/4] raise an exception to skip the whole task encoding --- .../pointer_network_for_end2end_re.py | 4 +-- .../test_pointer_network_for_end2end_re.py | 25 ++++--------------- 2 files changed, 7 insertions(+), 22 deletions(-) diff --git a/src/pie_modules/taskmodules/pointer_network_for_end2end_re.py b/src/pie_modules/taskmodules/pointer_network_for_end2end_re.py index ced418a92..7feb1604c 100644 --- a/src/pie_modules/taskmodules/pointer_network_for_end2end_re.py +++ b/src/pie_modules/taskmodules/pointer_network_for_end2end_re.py @@ -463,9 +463,9 @@ def encode_annotations( if (rel.head, rel.tail) in relation_arguments2label: previous_label = relation_arguments2label[(rel.head, rel.tail)] if previous_label != rel.label: - logger.warning( + raise ValueError( f"relation {rel.head} -> {rel.tail} already exists, but has another label: " - f"{previous_label} (previous label: {rel.label}). Skipping." + f"{previous_label} (current label: {rel.label})." ) continue encoded_relation = self.relation_encoder_decoder.encode( diff --git a/tests/taskmodules/test_pointer_network_for_end2end_re.py b/tests/taskmodules/test_pointer_network_for_end2end_re.py index 035ad0a6f..ea1195480 100644 --- a/tests/taskmodules/test_pointer_network_for_end2end_re.py +++ b/tests/taskmodules/test_pointer_network_for_end2end_re.py @@ -418,31 +418,16 @@ def test_task_encoding_with_conflicting_relations(caplog): ) taskmodule.prepare(documents=[doc]) caplog.clear() - with caplog.at_level(logging.WARNING): + with caplog.at_level(logging.ERROR): task_encodings = taskmodule.encode(doc, encode_target=True) messages = list(caplog.messages) - assert len(task_encodings) == 1 - decoded_annotations, statistics = taskmodule.decode_annotations(task_encodings[0].targets) - assert decoded_annotations == { - "entities": [ - LabeledSpan(start=4, end=6, label="content", score=1.0), - LabeledSpan(start=7, end=8, label="topic", score=1.0), - LabeledSpan(start=10, end=11, label="person", score=1.0), - ], - "relations": [ - BinaryRelation( - head=LabeledSpan(start=4, end=6, label="content", score=1.0), - tail=LabeledSpan(start=7, end=8, label="topic", score=1.0), - label="is_about", - score=1.0, - ) - ], - } + assert len(task_encodings) == 0 assert messages == [ - "relation ('Ġdummy', 'Ġtext') -> ('Ġnothing',) already exists, but has another label: is_about (previous label: wrong_relation). Skipping.", - "encoding errors: {'correct': 2}, skipped annotations:\n{\n \"relations\": [\n \"BinaryRelation(head=LabeledSpan(start=4, end=6, label='content', score=1.0), tail=LabeledSpan(start=7, end=8, label='topic', score=1.0), label='wrong_relation', score=1.0)\"\n ]\n}", + "failed to encode target, it will be skipped: " + "relation ('Ġdummy', 'Ġtext') -> ('Ġnothing',) already exists, but has " + "another label: is_about (current label: wrong_relation)." ]