From 46017b371e5e6edad6d34ec6ed54ac6677225c7e Mon Sep 17 00:00:00 2001 From: Arne Binder Date: Wed, 13 Nov 2024 14:44:59 +0100 Subject: [PATCH 01/20] add parameters add_reversed_relations and symmetric_relations to PointerNetworkTaskModuleForEnd2EndRE --- .../pointer_network_for_end2end_re.py | 48 ++++++++++++++++++- 1 file changed, 47 insertions(+), 1 deletion(-) diff --git a/src/pie_modules/taskmodules/pointer_network_for_end2end_re.py b/src/pie_modules/taskmodules/pointer_network_for_end2end_re.py index 7feb1604c..dfccd8c80 100644 --- a/src/pie_modules/taskmodules/pointer_network_for_end2end_re.py +++ b/src/pie_modules/taskmodules/pointer_network_for_end2end_re.py @@ -106,6 +106,7 @@ class PointerNetworkTaskModuleForEnd2EndRE( ], ): PREPARED_ATTRIBUTES = ["labels_per_layer"] + REVERSED_RELATION_LABEL_SUFFIX = "_reversed" def __init__( self, @@ -114,6 +115,8 @@ def __init__( document_type: str = "pytorch_ie.documents.TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions", tokenized_document_type: str = "pie_modules.documents.TokenDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions", relation_layer_name: str = "binary_relations", + add_reversed_relations: bool = False, + symmetric_relations: Optional[List[str]] = None, none_label: str = "none", loop_dummy_relation_name: str = "loop", constrained_generation: bool = False, @@ -167,6 +170,8 @@ def __init__( self.span_layer_name = annotation_field_mapping_inv.get( relation_layer_target, relation_layer_target ) + self.add_reversed_relations = add_reversed_relations + self.symmetric_relations = set(symmetric_relations or []) self.none_label = none_label self.loop_dummy_relation_name = loop_dummy_relation_name self.constrained_generation = constrained_generation @@ -476,6 +481,27 @@ def encode_annotations( relation_encodings[rel] = encoded_relation all_relation_arguments.update([rel.head, rel.tail]) relation_arguments2label[(rel.head, rel.tail)] = rel.label + if self.add_reversed_relations: + reversed_label = rel.label + if reversed_label not in self.symmetric_relations: + reversed_label += self.REVERSED_RELATION_LABEL_SUFFIX + if (rel.tail, rel.head) in relation_arguments2label: + previous_label = relation_arguments2label[(rel.tail, rel.head)] + logger.warning( + f"reversed relation already exists: {rel.tail} -> {rel.head} with " + f"label {previous_label} (reversed label: {reversed_label}). Skipping." + ) + continue + reversed_rel = BinaryRelation( + head=rel.tail, + tail=rel.head, + label=rel.label + self.REVERSED_RELATION_LABEL_SUFFIX, + ) + encoded_reversed_rel = self.relation_encoder_decoder.encode( + annotation=reversed_rel, metadata=metadata + ) + if encoded_reversed_rel is not None: + relation_encodings[reversed_rel] = encoded_reversed_rel # encode spans that are not arguments of any relation no_relation_spans = [ @@ -856,4 +882,24 @@ def create_annotations_from_output( for layer_name in layers: annotations = self.get_mapped_layer(untokenized_document, layer_name=layer_name) for annotation in annotations: - yield layer_name, annotation.copy() + # handle relations that may be reversed + if layer_name == self.relation_layer_name and self.add_reversed_relations: + if not isinstance(annotation, BinaryRelation): + raise Exception( + f"expected BinaryRelation when handling reversed relations, but got: {annotation}" + ) + head, tail = annotation.head, annotation.tail + label = annotation.label + # if the relation is symmetric, we sort head and tail to ensure consistent order + if annotation.label in self.symmetric_relations: + head, tail = sorted( + [annotation.head, annotation.tail], key=lambda x: (x.start, x.end) + ) + # if the relation was reversed, we need to reconstruct the original label and swap head and tail + elif annotation.label.endswith(self.REVERSED_RELATION_LABEL_SUFFIX): + # reconstruct the original label and swap head and tail + label = annotation.label[: -len(self.REVERSED_RELATION_LABEL_SUFFIX)] + head, tail = tail, head + yield layer_name, annotation.copy(head=head, tail=tail, label=label) + else: + yield layer_name, annotation.copy() From 34be967874e5b4d65b1409a360b00c4c733ef1b6 Mon Sep 17 00:00:00 2001 From: Arne Binder Date: Wed, 13 Nov 2024 14:52:34 +0100 Subject: [PATCH 02/20] fix test_prepared_config --- tests/taskmodules/test_pointer_network_for_end2end_re.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/tests/taskmodules/test_pointer_network_for_end2end_re.py b/tests/taskmodules/test_pointer_network_for_end2end_re.py index ea1195480..ebf0fa446 100644 --- a/tests/taskmodules/test_pointer_network_for_end2end_re.py +++ b/tests/taskmodules/test_pointer_network_for_end2end_re.py @@ -194,6 +194,7 @@ def test_prepared_config(taskmodule, config): assert taskmodule._config() == { "taskmodule_type": "PointerNetworkTaskModuleForEnd2EndRE", "relation_layer_name": "relations", + "symmetric_relations": None, "none_label": "none", "loop_dummy_relation_name": "loop", "labels_per_layer": { @@ -208,6 +209,7 @@ def test_prepared_config(taskmodule, config): "tokenizer_init_kwargs": None, "tokenizer_kwargs": {"strict_span_conversion": False}, "partition_layer_name": None, + "add_reversed_relations": False, "annotation_field_mapping": { "entities": "labeled_spans", "relations": "binary_relations", @@ -221,6 +223,7 @@ def test_prepared_config(taskmodule, config): assert taskmodule._config() == { "taskmodule_type": "PointerNetworkTaskModuleForEnd2EndRE", "relation_layer_name": "relations", + "symmetric_relations": None, "none_label": "none", "loop_dummy_relation_name": "loop", "labels_per_layer": { @@ -235,6 +238,7 @@ def test_prepared_config(taskmodule, config): "tokenizer_init_kwargs": None, "tokenizer_kwargs": {"strict_span_conversion": False}, "partition_layer_name": "sentences", + "add_reversed_relations": False, "annotation_field_mapping": { "entities": "labeled_spans", "relations": "binary_relations", From 46a03712c68b6f4d30c74ae80d4e09ef8309fbcf Mon Sep 17 00:00:00 2001 From: Arne Binder Date: Wed, 13 Nov 2024 15:25:06 +0100 Subject: [PATCH 03/20] add reversed relation labels to labels_per_layer in _prepare() --- .../taskmodules/pointer_network_for_end2end_re.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/src/pie_modules/taskmodules/pointer_network_for_end2end_re.py b/src/pie_modules/taskmodules/pointer_network_for_end2end_re.py index dfccd8c80..8ef66993a 100644 --- a/src/pie_modules/taskmodules/pointer_network_for_end2end_re.py +++ b/src/pie_modules/taskmodules/pointer_network_for_end2end_re.py @@ -278,6 +278,17 @@ def _prepare(self, documents: Sequence[DocumentType]) -> None: ac.label for ac in doc[layer_name] if ac.label not in exclude_labels ) + if self.add_reversed_relations: + for rel_label in set(labels[self.relation_layer_name]): + reversed_label = rel_label + if rel_label not in self.symmetric_relations: + reversed_label += self.REVERSED_RELATION_LABEL_SUFFIX + if reversed_label in labels[self.relation_layer_name]: + raise ValueError( + f"reversed relation label {reversed_label} already exists in relation layer labels" + ) + labels[self.relation_layer_name].add(reversed_label) + self.labels_per_layer = { # sort labels to ensure deterministic order layer_name: sorted(labels) From d4930ce6b68393e72d22c5a29f58d821a4cf454e Mon Sep 17 00:00:00 2001 From: Arne Binder Date: Wed, 13 Nov 2024 15:41:20 +0100 Subject: [PATCH 04/20] fix _prepare --- .../taskmodules/pointer_network_for_end2end_re.py | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/src/pie_modules/taskmodules/pointer_network_for_end2end_re.py b/src/pie_modules/taskmodules/pointer_network_for_end2end_re.py index 8ef66993a..908f10ec4 100644 --- a/src/pie_modules/taskmodules/pointer_network_for_end2end_re.py +++ b/src/pie_modules/taskmodules/pointer_network_for_end2end_re.py @@ -280,14 +280,13 @@ def _prepare(self, documents: Sequence[DocumentType]) -> None: if self.add_reversed_relations: for rel_label in set(labels[self.relation_layer_name]): - reversed_label = rel_label if rel_label not in self.symmetric_relations: - reversed_label += self.REVERSED_RELATION_LABEL_SUFFIX - if reversed_label in labels[self.relation_layer_name]: - raise ValueError( - f"reversed relation label {reversed_label} already exists in relation layer labels" - ) - labels[self.relation_layer_name].add(reversed_label) + reversed_label = rel_label + self.REVERSED_RELATION_LABEL_SUFFIX + if reversed_label in labels[self.relation_layer_name]: + raise ValueError( + f"reversed relation label {reversed_label} already exists in relation layer labels" + ) + labels[self.relation_layer_name].add(reversed_label) self.labels_per_layer = { # sort labels to ensure deterministic order From b1358226a26c8d8b58ca4546576c76a0dfbde474 Mon Sep 17 00:00:00 2001 From: Arne Binder Date: Wed, 13 Nov 2024 15:41:28 +0100 Subject: [PATCH 05/20] fix encode_annotations --- src/pie_modules/taskmodules/pointer_network_for_end2end_re.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/pie_modules/taskmodules/pointer_network_for_end2end_re.py b/src/pie_modules/taskmodules/pointer_network_for_end2end_re.py index 908f10ec4..932b5dcf3 100644 --- a/src/pie_modules/taskmodules/pointer_network_for_end2end_re.py +++ b/src/pie_modules/taskmodules/pointer_network_for_end2end_re.py @@ -505,7 +505,7 @@ def encode_annotations( reversed_rel = BinaryRelation( head=rel.tail, tail=rel.head, - label=rel.label + self.REVERSED_RELATION_LABEL_SUFFIX, + label=reversed_label, ) encoded_reversed_rel = self.relation_encoder_decoder.encode( annotation=reversed_rel, metadata=metadata From bbd37e0cdbdbf112f36876f413a4b1b2a0fdb986 Mon Sep 17 00:00:00 2001 From: Arne Binder Date: Wed, 13 Nov 2024 15:41:36 +0100 Subject: [PATCH 06/20] add test_encode_with_add_reversed_relations --- .../test_pointer_network_for_end2end_re.py | 171 ++++++++++++++++-- 1 file changed, 158 insertions(+), 13 deletions(-) diff --git a/tests/taskmodules/test_pointer_network_for_end2end_re.py b/tests/taskmodules/test_pointer_network_for_end2end_re.py index ebf0fa446..09559ce7d 100644 --- a/tests/taskmodules/test_pointer_network_for_end2end_re.py +++ b/tests/taskmodules/test_pointer_network_for_end2end_re.py @@ -257,23 +257,168 @@ def task_encoding_without_target(taskmodule, document): return taskmodule.encode_input(document)[0] -def test_input_encoding(task_encoding_without_target, taskmodule): - assert task_encoding_without_target is not None - tokens = taskmodule.tokenizer.convert_ids_to_tokens( - task_encoding_without_target.inputs.input_ids +@pytest.fixture(params=[False, True]) +def taskmodule_with_reversed_relations(document, request) -> PointerNetworkTaskModuleForEnd2EndRE: + is_about_is_symmetric = request.param + taskmodule = PointerNetworkTaskModuleForEnd2EndRE( + tokenizer_name_or_path="facebook/bart-base", + relation_layer_name="relations", + exclude_labels_per_layer={"relations": ["no_relation"]}, + annotation_field_mapping={ + "entities": "labeled_spans", + "relations": "binary_relations", + }, + create_constraints=True, + tokenizer_kwargs={"strict_span_conversion": False}, + add_reversed_relations=True, + symmetric_relations=["is_about"] if is_about_is_symmetric else None, ) - if taskmodule.partition_layer_name is None: - assert asdict(task_encoding_without_target.inputs) == { - "input_ids": [0, 713, 16, 10, 34759, 2788, 59, 1085, 4, 3101, 162, 4, 2], - "attention_mask": [1] * 13, + + taskmodule.prepare(documents=[document]) + assert taskmodule.is_prepared + if is_about_is_symmetric: + assert taskmodule.prepared_attributes == { + "labels_per_layer": { + "entities": ["content", "person", "topic"], + "relations": ["is_about"], + } } - elif taskmodule.partition_layer_name == "sentences": - assert asdict(task_encoding_without_target.inputs) == { - "input_ids": [0, 713, 16, 10, 34759, 2788, 59, 1085, 4, 2], - "attention_mask": [1] * 10, + else: + assert taskmodule.prepared_attributes == { + "labels_per_layer": { + "entities": ["content", "person", "topic"], + "relations": ["is_about", "is_about_reversed"], + } + } + + return taskmodule + + +def test_encode_with_add_reversed_relations(taskmodule_with_reversed_relations, document): + task_encodings = taskmodule_with_reversed_relations.encode(document, encode_target=True) + assert len(task_encodings) == 1 + task_encoding = task_encodings[0] + assert task_encoding is not None + assert asdict(task_encoding.inputs) == { + "input_ids": [0, 713, 16, 10, 34759, 2788, 59, 1085, 4, 3101, 162, 4, 2], + "attention_mask": [1] * 13, + } + tokens = taskmodule_with_reversed_relations.tokenizer.convert_ids_to_tokens( + task_encoding.inputs.input_ids + ) + assert tokens == [ + "", + "This", + "Ġis", + "Ġa", + "Ġdummy", + "Ġtext", + "Ġabout", + "Ġnothing", + ".", + "ĠTrust", + "Ġme", + ".", + "", + ] + if not taskmodule_with_reversed_relations.symmetric_relations: + assert task_encoding.targets.labels == [ + 15, + 15, + 5, + 12, + 13, + 3, + 6, + 12, + 13, + 3, + 15, + 15, + 5, + 7, + 18, + 18, + 4, + 2, + 2, + 2, + 2, + 1, + ] + decoded_annotations, statistics = taskmodule_with_reversed_relations.decode_annotations( + task_encoding.targets + ) + assert decoded_annotations == { + "entities": [ + LabeledSpan(start=4, end=6, label="content", score=1.0), + LabeledSpan(start=7, end=8, label="topic", score=1.0), + LabeledSpan(start=10, end=11, label="person", score=1.0), + ], + "relations": [ + BinaryRelation( + head=LabeledSpan(start=4, end=6, label="content", score=1.0), + tail=LabeledSpan(start=7, end=8, label="topic", score=1.0), + label="is_about", + score=1.0, + ), + BinaryRelation( + head=LabeledSpan(start=7, end=8, label="topic", score=1.0), + tail=LabeledSpan(start=4, end=6, label="content", score=1.0), + label="is_about_reversed", + score=1.0, + ), + ], } else: - raise Exception(f"unknown partition_layer_name: {taskmodule.partition_layer_name}") + assert task_encoding.targets.labels == [ + 14, + 14, + 5, + 11, + 12, + 3, + 6, + 11, + 12, + 3, + 14, + 14, + 5, + 6, + 17, + 17, + 4, + 2, + 2, + 2, + 2, + 1, + ] + decoded_annotations, statistics = taskmodule_with_reversed_relations.decode_annotations( + task_encoding.targets + ) + assert decoded_annotations == { + "entities": [ + LabeledSpan(start=4, end=6, label="content", score=1.0), + LabeledSpan(start=7, end=8, label="topic", score=1.0), + LabeledSpan(start=10, end=11, label="person", score=1.0), + ], + "relations": [ + BinaryRelation( + head=LabeledSpan(start=4, end=6, label="content", score=1.0), + tail=LabeledSpan(start=7, end=8, label="topic", score=1.0), + label="is_about", + score=1.0, + ), + BinaryRelation( + head=LabeledSpan(start=7, end=8, label="topic", score=1.0), + tail=LabeledSpan(start=4, end=6, label="content", score=1.0), + label="is_about", + score=1.0, + ), + ], + } @pytest.fixture() From 9a07cb459f0869358eb6419b2090f1c745e95571 Mon Sep 17 00:00:00 2001 From: Arne Binder Date: Wed, 13 Nov 2024 17:14:59 +0100 Subject: [PATCH 07/20] outsource reverse_relation() --- .../pointer_network_for_end2end_re.py | 50 +++++++------ .../test_pointer_network_for_end2end_re.py | 70 ++++++------------- 2 files changed, 50 insertions(+), 70 deletions(-) diff --git a/src/pie_modules/taskmodules/pointer_network_for_end2end_re.py b/src/pie_modules/taskmodules/pointer_network_for_end2end_re.py index 932b5dcf3..e6393f6d2 100644 --- a/src/pie_modules/taskmodules/pointer_network_for_end2end_re.py +++ b/src/pie_modules/taskmodules/pointer_network_for_end2end_re.py @@ -458,6 +458,21 @@ def decode_relations( return encodings, dict(errors), current_encoding + def reverse_relation(self, relation: Annotation) -> BinaryRelation: + if isinstance(relation, BinaryRelation): + reversed_label = relation.label + if ( + reversed_label not in self.symmetric_relations + and reversed_label != self.none_label + ): + reversed_label += self.REVERSED_RELATION_LABEL_SUFFIX + reversed_rel = relation.copy( + head=relation.tail, tail=relation.head, label=reversed_label + ) + return reversed_rel + else: + raise Exception(f"reversing of relations of type {type(relation)} is not supported") + def encode_annotations( self, layers: Dict[str, List[Annotation]], metadata: Optional[Dict[str, Any]] = None ) -> TaskOutputType: @@ -467,11 +482,15 @@ def encode_annotations( if self.labels_per_layer is None: raise Exception("labels_per_layer is not defined. Call prepare() first or pass it in.") + relations = list(layers[self.relation_layer_name]) + if self.add_reversed_relations: + relations.extend(self.reverse_relation(rel) for rel in relations) + # encode relations all_relation_arguments = set() relation_arguments2label: Dict[Tuple[Annotation, ...], str] = dict() relation_encodings = dict() - for rel in layers[self.relation_layer_name]: + for rel in relations: if not isinstance(rel, BinaryRelation): raise Exception(f"expected BinaryRelation, but got: {rel}") if rel.label in self.labels_per_layer[self.relation_layer_name]: @@ -488,30 +507,17 @@ def encode_annotations( ) if encoded_relation is None: raise Exception(f"failed to encode relation: {rel}") + if (rel.head, rel.tail) in relation_arguments2label: + previous_label = relation_arguments2label[(rel.head, rel.tail)] + if previous_label != rel.label: + logger.warning( + f"relation {rel.head} -> {rel.tail} already exists, but has another label: " + f"{previous_label} (previous label: {rel.label}). Skipping." + ) + continue relation_encodings[rel] = encoded_relation all_relation_arguments.update([rel.head, rel.tail]) relation_arguments2label[(rel.head, rel.tail)] = rel.label - if self.add_reversed_relations: - reversed_label = rel.label - if reversed_label not in self.symmetric_relations: - reversed_label += self.REVERSED_RELATION_LABEL_SUFFIX - if (rel.tail, rel.head) in relation_arguments2label: - previous_label = relation_arguments2label[(rel.tail, rel.head)] - logger.warning( - f"reversed relation already exists: {rel.tail} -> {rel.head} with " - f"label {previous_label} (reversed label: {reversed_label}). Skipping." - ) - continue - reversed_rel = BinaryRelation( - head=rel.tail, - tail=rel.head, - label=reversed_label, - ) - encoded_reversed_rel = self.relation_encoder_decoder.encode( - annotation=reversed_rel, metadata=metadata - ) - if encoded_reversed_rel is not None: - relation_encodings[reversed_rel] = encoded_reversed_rel # encode spans that are not arguments of any relation no_relation_spans = [ diff --git a/tests/taskmodules/test_pointer_network_for_end2end_re.py b/tests/taskmodules/test_pointer_network_for_end2end_re.py index 09559ce7d..de0f1a659 100644 --- a/tests/taskmodules/test_pointer_network_for_end2end_re.py +++ b/tests/taskmodules/test_pointer_network_for_end2end_re.py @@ -257,6 +257,19 @@ def task_encoding_without_target(taskmodule, document): return taskmodule.encode_input(document)[0] +def test_reverse_relation(taskmodule, document): + assert document.relations[0].resolve() == ( + "is_about", + (("content", "dummy text"), ("topic", "nothing")), + ) + + reversed_relation = taskmodule.reverse_relation(relation=document.relations[0]) + assert reversed_relation.resolve() == ( + "is_about_reversed", + (("topic", "nothing"), ("content", "dummy text")), + ) + + @pytest.fixture(params=[False, True]) def taskmodule_with_reversed_relations(document, request) -> PointerNetworkTaskModuleForEnd2EndRE: is_about_is_symmetric = request.param @@ -322,30 +335,6 @@ def test_encode_with_add_reversed_relations(taskmodule_with_reversed_relations, "", ] if not taskmodule_with_reversed_relations.symmetric_relations: - assert task_encoding.targets.labels == [ - 15, - 15, - 5, - 12, - 13, - 3, - 6, - 12, - 13, - 3, - 15, - 15, - 5, - 7, - 18, - 18, - 4, - 2, - 2, - 2, - 2, - 1, - ] decoded_annotations, statistics = taskmodule_with_reversed_relations.decode_annotations( task_encoding.targets ) @@ -371,30 +360,6 @@ def test_encode_with_add_reversed_relations(taskmodule_with_reversed_relations, ], } else: - assert task_encoding.targets.labels == [ - 14, - 14, - 5, - 11, - 12, - 3, - 6, - 11, - 12, - 3, - 14, - 14, - 5, - 6, - 17, - 17, - 4, - 2, - 2, - 2, - 2, - 1, - ] decoded_annotations, statistics = taskmodule_with_reversed_relations.decode_annotations( task_encoding.targets ) @@ -421,6 +386,15 @@ def test_encode_with_add_reversed_relations(taskmodule_with_reversed_relations, } +def test_encode_with_add_reversed_relations_already_exists(taskmodule_with_reversed_relations): + doc = ExampleDocument(text="This is a dummy text about nothing. Trust me.") + doc.entities.append(LabeledSpan(start=10, end=20, label="content")) + doc.entities.append(LabeledSpan(start=27, end=34, label="topic")) + rel = BinaryRelation(head=doc.entities[0], tail=doc.entities[1], label="is_about") + + task_encodings = taskmodule_with_reversed_relations.encode(document, encode_target=True) + + @pytest.fixture() def target_encoding(taskmodule, task_encoding_without_target): return taskmodule.encode_target(task_encoding_without_target) From e6a4d9c819db03eaf8ef0f3152cc39554ee22141 Mon Sep 17 00:00:00 2001 From: Arne Binder Date: Wed, 13 Nov 2024 17:43:48 +0100 Subject: [PATCH 08/20] fix infinite extension in encode_annotations --- src/pie_modules/taskmodules/pointer_network_for_end2end_re.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/pie_modules/taskmodules/pointer_network_for_end2end_re.py b/src/pie_modules/taskmodules/pointer_network_for_end2end_re.py index e6393f6d2..f25755ab0 100644 --- a/src/pie_modules/taskmodules/pointer_network_for_end2end_re.py +++ b/src/pie_modules/taskmodules/pointer_network_for_end2end_re.py @@ -484,7 +484,8 @@ def encode_annotations( relations = list(layers[self.relation_layer_name]) if self.add_reversed_relations: - relations.extend(self.reverse_relation(rel) for rel in relations) + reversed_relations = [self.reverse_relation(rel) for rel in relations] + relations.extend(reversed_relations) # encode relations all_relation_arguments = set() From aa7d04d57896cf19b14fbd551c9e84b3e3f9576a Mon Sep 17 00:00:00 2001 From: Arne Binder Date: Wed, 13 Nov 2024 17:44:19 +0100 Subject: [PATCH 09/20] outsource unreverse_relation() --- .../pointer_network_for_end2end_re.py | 34 +++++++++---------- 1 file changed, 17 insertions(+), 17 deletions(-) diff --git a/src/pie_modules/taskmodules/pointer_network_for_end2end_re.py b/src/pie_modules/taskmodules/pointer_network_for_end2end_re.py index f25755ab0..57fe34a7c 100644 --- a/src/pie_modules/taskmodules/pointer_network_for_end2end_re.py +++ b/src/pie_modules/taskmodules/pointer_network_for_end2end_re.py @@ -473,6 +473,21 @@ def reverse_relation(self, relation: Annotation) -> BinaryRelation: else: raise Exception(f"reversing of relations of type {type(relation)} is not supported") + def unreverse_relation(self, relation: Annotation) -> BinaryRelation: + if isinstance(relation, BinaryRelation): + head, tail, label = relation.head, relation.tail, relation.label + # if the relation is symmetric, we sort head and tail to ensure consistent order + if relation.label in self.symmetric_relations: + head, tail = sorted([head, tail], key=lambda x: (x.start, x.end)) + # if the relation was reversed, we need to reconstruct the original label and swap head and tail + elif label.endswith(self.REVERSED_RELATION_LABEL_SUFFIX): + # reconstruct the original label and swap head and tail + label = label[: -len(self.REVERSED_RELATION_LABEL_SUFFIX)] + head, tail = tail, head + return relation.copy(head=head, tail=tail, label=label) + else: + raise Exception(f"un-reversing of relations of type {type(relation)} is not supported") + def encode_annotations( self, layers: Dict[str, List[Annotation]], metadata: Optional[Dict[str, Any]] = None ) -> TaskOutputType: @@ -901,22 +916,7 @@ def create_annotations_from_output( for annotation in annotations: # handle relations that may be reversed if layer_name == self.relation_layer_name and self.add_reversed_relations: - if not isinstance(annotation, BinaryRelation): - raise Exception( - f"expected BinaryRelation when handling reversed relations, but got: {annotation}" - ) - head, tail = annotation.head, annotation.tail - label = annotation.label - # if the relation is symmetric, we sort head and tail to ensure consistent order - if annotation.label in self.symmetric_relations: - head, tail = sorted( - [annotation.head, annotation.tail], key=lambda x: (x.start, x.end) - ) - # if the relation was reversed, we need to reconstruct the original label and swap head and tail - elif annotation.label.endswith(self.REVERSED_RELATION_LABEL_SUFFIX): - # reconstruct the original label and swap head and tail - label = annotation.label[: -len(self.REVERSED_RELATION_LABEL_SUFFIX)] - head, tail = tail, head - yield layer_name, annotation.copy(head=head, tail=tail, label=label) + unreversed_relation = self.unreverse_relation(annotation) + yield layer_name, unreversed_relation else: yield layer_name, annotation.copy() From 643c8014f219f6feba605e8abe13e6bce6e6df57 Mon Sep 17 00:00:00 2001 From: Arne Binder Date: Wed, 13 Nov 2024 17:44:56 +0100 Subject: [PATCH 10/20] simplify test_reverse_relation() and add test_unreverse_relation() --- .../test_pointer_network_for_end2end_re.py | 89 +++++++++++++++++-- 1 file changed, 81 insertions(+), 8 deletions(-) diff --git a/tests/taskmodules/test_pointer_network_for_end2end_re.py b/tests/taskmodules/test_pointer_network_for_end2end_re.py index de0f1a659..8b9a412d5 100644 --- a/tests/taskmodules/test_pointer_network_for_end2end_re.py +++ b/tests/taskmodules/test_pointer_network_for_end2end_re.py @@ -257,16 +257,89 @@ def task_encoding_without_target(taskmodule, document): return taskmodule.encode_input(document)[0] -def test_reverse_relation(taskmodule, document): - assert document.relations[0].resolve() == ( - "is_about", - (("content", "dummy text"), ("topic", "nothing")), +def test_reverse_relation(): + taskmodule = PointerNetworkTaskModuleForEnd2EndRE( + tokenizer_name_or_path="facebook/bart-base", + symmetric_relations=["symmetric_relation"], + ) + + rel = BinaryRelation( + head=LabeledSpan(start=10, end=20, label="content"), + tail=LabeledSpan(start=27, end=34, label="topic"), + label="is_about", + ) + reversed_relation = taskmodule.reverse_relation(relation=rel) + assert reversed_relation == BinaryRelation( + head=LabeledSpan(start=27, end=34, label="topic", score=1.0), + tail=LabeledSpan(start=10, end=20, label="content", score=1.0), + label="is_about_reversed", + score=1.0, + ) + + sym_rel = BinaryRelation( + head=LabeledSpan(start=10, end=20, label="content"), + tail=LabeledSpan(start=27, end=34, label="topic"), + label="symmetric_relation", + ) + reversed_sym_rel = taskmodule.reverse_relation(relation=sym_rel) + assert reversed_sym_rel == BinaryRelation( + head=LabeledSpan(start=27, end=34, label="topic", score=1.0), + tail=LabeledSpan(start=10, end=20, label="content", score=1.0), + label="symmetric_relation", + score=1.0, + ) + + +def test_unreverse_relation(): + taskmodule = PointerNetworkTaskModuleForEnd2EndRE( + tokenizer_name_or_path="facebook/bart-base", + symmetric_relations=["symmetric_relation"], + ) + + # nothing should change because the relation is not reversed + rel = BinaryRelation( + head=LabeledSpan(start=10, end=20, label="content"), + tail=LabeledSpan(start=27, end=34, label="topic"), + label="is_about", + ) + same_rel = taskmodule.unreverse_relation(relation=rel) + assert same_rel == rel + + # the relation is reversed, so it should be un-reversed + reversed_rel = BinaryRelation( + head=LabeledSpan(start=10, end=20, label="content"), + tail=LabeledSpan(start=27, end=34, label="topic"), + label="is_about_reversed", + ) + unreversed_relation = taskmodule.unreverse_relation(relation=reversed_rel) + assert unreversed_relation == BinaryRelation( + head=LabeledSpan(start=27, end=34, label="topic", score=1.0), + tail=LabeledSpan(start=10, end=20, label="content", score=1.0), + label="is_about", + score=1.0, ) - reversed_relation = taskmodule.reverse_relation(relation=document.relations[0]) - assert reversed_relation.resolve() == ( - "is_about_reversed", - (("topic", "nothing"), ("content", "dummy text")), + # nothing should change because the relation is symmetric and already ordered (head < tail) + ordered_sym_rel = BinaryRelation( + head=LabeledSpan(start=10, end=20, label="content"), + tail=LabeledSpan(start=27, end=34, label="topic"), + label="symmetric_relation", + ) + unreversed_ordered_sym_rel = taskmodule.unreverse_relation(relation=ordered_sym_rel) + assert ordered_sym_rel == unreversed_ordered_sym_rel + + # the relation is symmetric and unordered (head > tail), so it should be un-reversed + unordered_sym_rel = BinaryRelation( + head=LabeledSpan(start=27, end=34, label="topic"), + tail=LabeledSpan(start=10, end=20, label="content"), + label="symmetric_relation", + ) + unreversed_unordered_sym_rel = taskmodule.unreverse_relation(relation=unordered_sym_rel) + assert unreversed_unordered_sym_rel == BinaryRelation( + head=LabeledSpan(start=10, end=20, label="content", score=1.0), + tail=LabeledSpan(start=27, end=34, label="topic", score=1.0), + label="symmetric_relation", + score=1.0, ) From a88d435f820bf7d352a171904ab6326fd0b6f569 Mon Sep 17 00:00:00 2001 From: Arne Binder Date: Wed, 13 Nov 2024 17:45:52 +0100 Subject: [PATCH 11/20] clarify test_encode_with_add_reversed_relations() --- tests/taskmodules/test_pointer_network_for_end2end_re.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/taskmodules/test_pointer_network_for_end2end_re.py b/tests/taskmodules/test_pointer_network_for_end2end_re.py index 8b9a412d5..e33b06c23 100644 --- a/tests/taskmodules/test_pointer_network_for_end2end_re.py +++ b/tests/taskmodules/test_pointer_network_for_end2end_re.py @@ -407,7 +407,7 @@ def test_encode_with_add_reversed_relations(taskmodule_with_reversed_relations, ".", "", ] - if not taskmodule_with_reversed_relations.symmetric_relations: + if "is_about" in taskmodule_with_reversed_relations.symmetric_relations: decoded_annotations, statistics = taskmodule_with_reversed_relations.decode_annotations( task_encoding.targets ) @@ -427,7 +427,7 @@ def test_encode_with_add_reversed_relations(taskmodule_with_reversed_relations, BinaryRelation( head=LabeledSpan(start=7, end=8, label="topic", score=1.0), tail=LabeledSpan(start=4, end=6, label="content", score=1.0), - label="is_about_reversed", + label="is_about", score=1.0, ), ], @@ -452,7 +452,7 @@ def test_encode_with_add_reversed_relations(taskmodule_with_reversed_relations, BinaryRelation( head=LabeledSpan(start=7, end=8, label="topic", score=1.0), tail=LabeledSpan(start=4, end=6, label="content", score=1.0), - label="is_about", + label="is_about_reversed", score=1.0, ), ], From 03814dba7851585f2a4a29d7358340989ef7afcf Mon Sep 17 00:00:00 2001 From: Arne Binder Date: Wed, 13 Nov 2024 17:56:31 +0100 Subject: [PATCH 12/20] outsource add_reversed_relation_labels() --- .../pointer_network_for_end2end_re.py | 23 ++++++++++++------- .../test_pointer_network_for_end2end_re.py | 11 +++++++++ 2 files changed, 26 insertions(+), 8 deletions(-) diff --git a/src/pie_modules/taskmodules/pointer_network_for_end2end_re.py b/src/pie_modules/taskmodules/pointer_network_for_end2end_re.py index 57fe34a7c..de38348c9 100644 --- a/src/pie_modules/taskmodules/pointer_network_for_end2end_re.py +++ b/src/pie_modules/taskmodules/pointer_network_for_end2end_re.py @@ -268,6 +268,18 @@ def _prefix_allowed_tokens_fn_with_maximum( # convert to a list return allowed_indices.tolist() + def add_reversed_relation_labels(self, relation_labels: Iterable[str]) -> Set[str]: + result = set(relation_labels) + for rel_label in set(relation_labels): + if rel_label not in self.symmetric_relations: + reversed_label = rel_label + self.REVERSED_RELATION_LABEL_SUFFIX + if reversed_label in result: + raise ValueError( + f"reversed relation label {reversed_label} already exists in relation layer labels" + ) + result.add(reversed_label) + return result + def _prepare(self, documents: Sequence[DocumentType]) -> None: # collect all labels labels: Dict[str, Set[str]] = {layer_name: set() for layer_name in self.layer_names} @@ -279,14 +291,9 @@ def _prepare(self, documents: Sequence[DocumentType]) -> None: ) if self.add_reversed_relations: - for rel_label in set(labels[self.relation_layer_name]): - if rel_label not in self.symmetric_relations: - reversed_label = rel_label + self.REVERSED_RELATION_LABEL_SUFFIX - if reversed_label in labels[self.relation_layer_name]: - raise ValueError( - f"reversed relation label {reversed_label} already exists in relation layer labels" - ) - labels[self.relation_layer_name].add(reversed_label) + labels[self.relation_layer_name] = self.add_reversed_relation_labels( + relation_labels=labels[self.relation_layer_name] + ) self.labels_per_layer = { # sort labels to ensure deterministic order diff --git a/tests/taskmodules/test_pointer_network_for_end2end_re.py b/tests/taskmodules/test_pointer_network_for_end2end_re.py index e33b06c23..083519383 100644 --- a/tests/taskmodules/test_pointer_network_for_end2end_re.py +++ b/tests/taskmodules/test_pointer_network_for_end2end_re.py @@ -257,6 +257,17 @@ def task_encoding_without_target(taskmodule, document): return taskmodule.encode_input(document)[0] +def test_add_reversed_relation_labels(): + taskmodule = PointerNetworkTaskModuleForEnd2EndRE( + tokenizer_name_or_path="facebook/bart-base", + symmetric_relations=["symmetric_relation"], + ) + + labels = ["is_about", "symmetric_relation"] + labels_with_reversed = taskmodule.add_reversed_relation_labels(labels) + assert labels_with_reversed == {"is_about", "is_about_reversed", "symmetric_relation"} + + def test_reverse_relation(): taskmodule = PointerNetworkTaskModuleForEnd2EndRE( tokenizer_name_or_path="facebook/bart-base", From cac6dacaf16402ee499ffa631356e8024bd69914 Mon Sep 17 00:00:00 2001 From: Arne Binder Date: Wed, 13 Nov 2024 18:06:25 +0100 Subject: [PATCH 13/20] move adding reversed relations to encode_target() --- .../taskmodules/pointer_network_for_end2end_re.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/src/pie_modules/taskmodules/pointer_network_for_end2end_re.py b/src/pie_modules/taskmodules/pointer_network_for_end2end_re.py index de38348c9..afb3f73eb 100644 --- a/src/pie_modules/taskmodules/pointer_network_for_end2end_re.py +++ b/src/pie_modules/taskmodules/pointer_network_for_end2end_re.py @@ -504,16 +504,11 @@ def encode_annotations( if self.labels_per_layer is None: raise Exception("labels_per_layer is not defined. Call prepare() first or pass it in.") - relations = list(layers[self.relation_layer_name]) - if self.add_reversed_relations: - reversed_relations = [self.reverse_relation(rel) for rel in relations] - relations.extend(reversed_relations) - # encode relations all_relation_arguments = set() relation_arguments2label: Dict[Tuple[Annotation, ...], str] = dict() relation_encodings = dict() - for rel in relations: + for rel in layers[self.relation_layer_name]: if not isinstance(rel, BinaryRelation): raise Exception(f"expected BinaryRelation, but got: {rel}") if rel.label in self.labels_per_layer[self.relation_layer_name]: @@ -848,6 +843,13 @@ def encode_target(self, task_encoding: TaskEncodingType) -> Optional[TargetEncod layer_name: self.get_mapped_layer(document, layer_name=layer_name) for layer_name in self.layer_names } + + if self.add_reversed_relations: + reversed_relations = [ + self.reverse_relation(rel) for rel in layers[self.relation_layer_name] + ] + layers[self.relation_layer_name].extend(reversed_relations) + result = self.encode_annotations( layers=layers, metadata={ From 7dd49bc274531e8f0bfec9377b7bfc6c50a76f7d Mon Sep 17 00:00:00 2001 From: Arne Binder Date: Wed, 13 Nov 2024 19:18:35 +0100 Subject: [PATCH 14/20] fix merge issue --- .../taskmodules/pointer_network_for_end2end_re.py | 8 -------- 1 file changed, 8 deletions(-) diff --git a/src/pie_modules/taskmodules/pointer_network_for_end2end_re.py b/src/pie_modules/taskmodules/pointer_network_for_end2end_re.py index afb3f73eb..ce2176795 100644 --- a/src/pie_modules/taskmodules/pointer_network_for_end2end_re.py +++ b/src/pie_modules/taskmodules/pointer_network_for_end2end_re.py @@ -525,14 +525,6 @@ def encode_annotations( ) if encoded_relation is None: raise Exception(f"failed to encode relation: {rel}") - if (rel.head, rel.tail) in relation_arguments2label: - previous_label = relation_arguments2label[(rel.head, rel.tail)] - if previous_label != rel.label: - logger.warning( - f"relation {rel.head} -> {rel.tail} already exists, but has another label: " - f"{previous_label} (previous label: {rel.label}). Skipping." - ) - continue relation_encodings[rel] = encoded_relation all_relation_arguments.update([rel.head, rel.tail]) relation_arguments2label[(rel.head, rel.tail)] = rel.label From 7e448ba091caa5313fba24a94aa49951ad57f6db Mon Sep 17 00:00:00 2001 From: Arne Binder Date: Wed, 13 Nov 2024 19:42:03 +0100 Subject: [PATCH 15/20] improve typing --- src/pie_modules/taskmodules/pointer_network_for_end2end_re.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/pie_modules/taskmodules/pointer_network_for_end2end_re.py b/src/pie_modules/taskmodules/pointer_network_for_end2end_re.py index ce2176795..2ad037c4b 100644 --- a/src/pie_modules/taskmodules/pointer_network_for_end2end_re.py +++ b/src/pie_modules/taskmodules/pointer_network_for_end2end_re.py @@ -496,7 +496,7 @@ def unreverse_relation(self, relation: Annotation) -> BinaryRelation: raise Exception(f"un-reversing of relations of type {type(relation)} is not supported") def encode_annotations( - self, layers: Dict[str, List[Annotation]], metadata: Optional[Dict[str, Any]] = None + self, layers: Dict[str, Iterable[Annotation]], metadata: Optional[Dict[str, Any]] = None ) -> TaskOutputType: if not set(layers.keys()) == set(self.layer_names): raise Exception(f"unexpected layers: {layers.keys()}. expected: {self.layer_names}") From b798844710761b17d7a04c0700dfcbea1f537882 Mon Sep 17 00:00:00 2001 From: Arne Binder Date: Wed, 13 Nov 2024 19:42:34 +0100 Subject: [PATCH 16/20] copy annotation list before adding reversed --- .../taskmodules/pointer_network_for_end2end_re.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/pie_modules/taskmodules/pointer_network_for_end2end_re.py b/src/pie_modules/taskmodules/pointer_network_for_end2end_re.py index 2ad037c4b..48e7ea315 100644 --- a/src/pie_modules/taskmodules/pointer_network_for_end2end_re.py +++ b/src/pie_modules/taskmodules/pointer_network_for_end2end_re.py @@ -837,10 +837,10 @@ def encode_target(self, task_encoding: TaskEncodingType) -> Optional[TargetEncod } if self.add_reversed_relations: - reversed_relations = [ - self.reverse_relation(rel) for rel in layers[self.relation_layer_name] - ] - layers[self.relation_layer_name].extend(reversed_relations) + # create a copy to avoid modifying the annotation layer in the document + relations = list(layers[self.relation_layer_name]) + reversed_relations = [self.reverse_relation(rel) for rel in relations] + layers[self.relation_layer_name] = relations + reversed_relations result = self.encode_annotations( layers=layers, From ac92b42c3206467f3b936248082ce01155fef05a Mon Sep 17 00:00:00 2001 From: Arne Binder Date: Wed, 13 Nov 2024 19:42:47 +0100 Subject: [PATCH 17/20] complete test_encode_with_add_reversed_relations_already_exists --- .../test_pointer_network_for_end2end_re.py | 49 +++++++++++++++++-- 1 file changed, 46 insertions(+), 3 deletions(-) diff --git a/tests/taskmodules/test_pointer_network_for_end2end_re.py b/tests/taskmodules/test_pointer_network_for_end2end_re.py index 083519383..f6e385134 100644 --- a/tests/taskmodules/test_pointer_network_for_end2end_re.py +++ b/tests/taskmodules/test_pointer_network_for_end2end_re.py @@ -470,13 +470,56 @@ def test_encode_with_add_reversed_relations(taskmodule_with_reversed_relations, } -def test_encode_with_add_reversed_relations_already_exists(taskmodule_with_reversed_relations): +def test_encode_with_add_reversed_relations_already_exists(caplog): doc = ExampleDocument(text="This is a dummy text about nothing. Trust me.") doc.entities.append(LabeledSpan(start=10, end=20, label="content")) doc.entities.append(LabeledSpan(start=27, end=34, label="topic")) - rel = BinaryRelation(head=doc.entities[0], tail=doc.entities[1], label="is_about") + doc.relations.append( + BinaryRelation(head=doc.entities[0], tail=doc.entities[1], label="is_about") + ) + doc.relations.append( + BinaryRelation(head=doc.entities[1], tail=doc.entities[0], label="is_about") + ) - task_encodings = taskmodule_with_reversed_relations.encode(document, encode_target=True) + taskmodule = PointerNetworkTaskModuleForEnd2EndRE( + tokenizer_name_or_path="facebook/bart-base", + relation_layer_name="relations", + annotation_field_mapping={ + "entities": "labeled_spans", + "relations": "binary_relations", + }, + add_reversed_relations=True, + symmetric_relations=["is_about"], + ) + taskmodule.prepare(documents=[doc]) + + with caplog.at_level(logging.WARNING): + task_encodings = taskmodule.encode(doc, encode_target=True) + assert len(caplog.messages) == 0 + assert len(task_encodings) == 1 + task_encoding = task_encodings[0] + + decoded_annotations, statistics = taskmodule.decode_annotations(task_encoding.targets) + assert decoded_annotations == { + "entities": [ + LabeledSpan(start=4, end=6, label="content", score=1.0), + LabeledSpan(start=7, end=8, label="topic", score=1.0), + ], + "relations": [ + BinaryRelation( + head=LabeledSpan(start=4, end=6, label="content", score=1.0), + tail=LabeledSpan(start=7, end=8, label="topic", score=1.0), + label="is_about", + score=1.0, + ), + BinaryRelation( + head=LabeledSpan(start=7, end=8, label="topic", score=1.0), + tail=LabeledSpan(start=4, end=6, label="content", score=1.0), + label="is_about", + score=1.0, + ), + ], + } @pytest.fixture() From 2a9b2d69a8f3f5358c6e7683c761a74646604160 Mon Sep 17 00:00:00 2001 From: Arne Binder Date: Wed, 13 Nov 2024 20:03:06 +0100 Subject: [PATCH 18/20] add test_decode_with_add_reversed_relations --- .../test_pointer_network_for_end2end_re.py | 51 +++++++++++++++++++ 1 file changed, 51 insertions(+) diff --git a/tests/taskmodules/test_pointer_network_for_end2end_re.py b/tests/taskmodules/test_pointer_network_for_end2end_re.py index f6e385134..7ad6f15b0 100644 --- a/tests/taskmodules/test_pointer_network_for_end2end_re.py +++ b/tests/taskmodules/test_pointer_network_for_end2end_re.py @@ -522,6 +522,57 @@ def test_encode_with_add_reversed_relations_already_exists(caplog): } +def test_decode_with_add_reversed_relations(): + doc = ExampleDocument(text="This is a dummy text about nothing. Trust me.") + doc.entities.append(LabeledSpan(start=10, end=20, label="content")) + doc.entities.append(LabeledSpan(start=27, end=34, label="topic")) + doc.relations.append( + BinaryRelation(head=doc.entities[0], tail=doc.entities[1], label="is_about") + ) + + taskmodule = PointerNetworkTaskModuleForEnd2EndRE( + tokenizer_name_or_path="facebook/bart-base", + relation_layer_name="relations", + annotation_field_mapping={ + "entities": "labeled_spans", + "relations": "binary_relations", + }, + add_reversed_relations=True, + ) + taskmodule.prepare(documents=[doc]) + + task_encodings = taskmodule.encode(doc, encode_target=True) + assert len(task_encodings) == 1 + decoded_annotations, statistics = taskmodule.decode_annotations(task_encodings[0].targets) + assert decoded_annotations == { + "entities": [ + LabeledSpan(start=4, end=6, label="content", score=1.0), + LabeledSpan(start=7, end=8, label="topic", score=1.0), + ], + "relations": [ + BinaryRelation( + head=LabeledSpan(start=4, end=6, label="content", score=1.0), + tail=LabeledSpan(start=7, end=8, label="topic", score=1.0), + label="is_about", + score=1.0, + ), + BinaryRelation( + head=LabeledSpan(start=7, end=8, label="topic", score=1.0), + tail=LabeledSpan(start=4, end=6, label="content", score=1.0), + label="is_about_reversed", + score=1.0, + ), + ], + } + + task_outputs = [task_encoding.targets for task_encoding in task_encodings] + docs_with_predictions = taskmodule.decode(task_encodings, task_outputs) + assert len(docs_with_predictions) == 1 + doc_with_predictions: ExampleDocument = docs_with_predictions[0] + assert list(doc_with_predictions.entities.predictions) == list(doc_with_predictions.entities) + assert list(doc_with_predictions.relations.predictions) == list(doc_with_predictions.relations) + + @pytest.fixture() def target_encoding(taskmodule, task_encoding_without_target): return taskmodule.encode_target(task_encoding_without_target) From 69e289bf35e8ab8820868c9f7b5b45abdba0fa0d Mon Sep 17 00:00:00 2001 From: Arne Binder Date: Wed, 13 Nov 2024 21:55:52 +0100 Subject: [PATCH 19/20] use deduplicate_annotations() in test --- poetry.lock | 28 +++++++++++-------- pyproject.toml | 4 ++- .../test_pointer_network_for_end2end_re.py | 2 +- 3 files changed, 20 insertions(+), 14 deletions(-) diff --git a/poetry.lock b/poetry.lock index e63a68fd9..25febe387 100644 --- a/poetry.lock +++ b/poetry.lock @@ -2158,23 +2158,27 @@ files = [ [[package]] name = "pytorch-ie" -version = "0.31.2" +version = "0.31.3" description = "State-of-the-art Information Extraction in PyTorch" optional = false -python-versions = "<4.0,>=3.9" -files = [ - {file = "pytorch_ie-0.31.2-py3-none-any.whl", hash = "sha256:fef91fb3d4dff84b0b6fd973d5bf7e5be51e0e01393195a01d7d824688c8cb3e"}, - {file = "pytorch_ie-0.31.2.tar.gz", hash = "sha256:cd9683ef4ba0191854ff1843f22f431f4c38c8745962ee55ba5b5c52f27afd7c"}, -] +python-versions = "^3.9" +files = [] +develop = false [package.dependencies] -absl-py = ">=1.0.0,<2.0.0" +absl-py = "^1.0.0" fsspec = "<2023.9.0" -pandas = ">=2.0.0,<3.0.0" -pytorch-lightning = ">=2,<3" +pandas = "^2.0.0" +pytorch-lightning = "^2" torch = ">=1.10" -torchmetrics = ">=1,<2" -transformers = ">=4.18,<5.0" +torchmetrics = "^1" +transformers = "^4.18" + +[package.source] +type = "git" +url = "https://github.com/ArneBinder/pytorch-ie" +reference = "document/deduplicate_annotations" +resolved_reference = "57bb34386a2ec9922ea7c8e7f36a0a199b02848e" [[package]] name = "pytorch-lightning" @@ -3439,4 +3443,4 @@ test = ["big-O", "importlib-resources", "jaraco.functools", "jaraco.itertools", [metadata] lock-version = "2.0" python-versions = "^3.9" -content-hash = "9edc2e1c448159e3f55c1d7fb1c6fa1d2baa9a11b2ef6e8aa80d5f551789ecac" +content-hash = "ab8de51bf6d389468d923b5d391f339c6e7f383396565cf5db925039de957132" diff --git a/pyproject.toml b/pyproject.toml index 69593d7ab..d784a1584 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -24,7 +24,9 @@ classifiers = [ [tool.poetry.dependencies] python = "^3.9" -pytorch-ie = ">=0.31.2,<0.32.0" +#pytorch-ie = ">=0.31.4,<0.32.0" +# install from branch from https://github.com/ArneBinder/pytorch-ie/pull/436 +pytorch-ie = { git = "https://github.com/ArneBinder/pytorch-ie", branch = "document/deduplicate_annotations" } pytorch-lightning = "^2.1.0" torchmetrics = "^1" # >=4.35 because of BartModelWithDecoderPositionIds, <4.37 because of generation config diff --git a/tests/taskmodules/test_pointer_network_for_end2end_re.py b/tests/taskmodules/test_pointer_network_for_end2end_re.py index 7ad6f15b0..e49213109 100644 --- a/tests/taskmodules/test_pointer_network_for_end2end_re.py +++ b/tests/taskmodules/test_pointer_network_for_end2end_re.py @@ -568,7 +568,7 @@ def test_decode_with_add_reversed_relations(): task_outputs = [task_encoding.targets for task_encoding in task_encodings] docs_with_predictions = taskmodule.decode(task_encodings, task_outputs) assert len(docs_with_predictions) == 1 - doc_with_predictions: ExampleDocument = docs_with_predictions[0] + doc_with_predictions: ExampleDocument = docs_with_predictions[0].deduplicate_annotations() assert list(doc_with_predictions.entities.predictions) == list(doc_with_predictions.entities) assert list(doc_with_predictions.relations.predictions) == list(doc_with_predictions.relations) From 2d6d1d0f49c2275f8836ce68cb1d7aa008177734 Mon Sep 17 00:00:00 2001 From: Arne Binder Date: Wed, 13 Nov 2024 21:59:34 +0100 Subject: [PATCH 20/20] revert: use deduplicate_annotations() in test --- poetry.lock | 28 ++++++++----------- pyproject.toml | 4 +-- .../test_pointer_network_for_end2end_re.py | 6 ++-- 3 files changed, 16 insertions(+), 22 deletions(-) diff --git a/poetry.lock b/poetry.lock index 25febe387..e63a68fd9 100644 --- a/poetry.lock +++ b/poetry.lock @@ -2158,27 +2158,23 @@ files = [ [[package]] name = "pytorch-ie" -version = "0.31.3" +version = "0.31.2" description = "State-of-the-art Information Extraction in PyTorch" optional = false -python-versions = "^3.9" -files = [] -develop = false +python-versions = "<4.0,>=3.9" +files = [ + {file = "pytorch_ie-0.31.2-py3-none-any.whl", hash = "sha256:fef91fb3d4dff84b0b6fd973d5bf7e5be51e0e01393195a01d7d824688c8cb3e"}, + {file = "pytorch_ie-0.31.2.tar.gz", hash = "sha256:cd9683ef4ba0191854ff1843f22f431f4c38c8745962ee55ba5b5c52f27afd7c"}, +] [package.dependencies] -absl-py = "^1.0.0" +absl-py = ">=1.0.0,<2.0.0" fsspec = "<2023.9.0" -pandas = "^2.0.0" -pytorch-lightning = "^2" +pandas = ">=2.0.0,<3.0.0" +pytorch-lightning = ">=2,<3" torch = ">=1.10" -torchmetrics = "^1" -transformers = "^4.18" - -[package.source] -type = "git" -url = "https://github.com/ArneBinder/pytorch-ie" -reference = "document/deduplicate_annotations" -resolved_reference = "57bb34386a2ec9922ea7c8e7f36a0a199b02848e" +torchmetrics = ">=1,<2" +transformers = ">=4.18,<5.0" [[package]] name = "pytorch-lightning" @@ -3443,4 +3439,4 @@ test = ["big-O", "importlib-resources", "jaraco.functools", "jaraco.itertools", [metadata] lock-version = "2.0" python-versions = "^3.9" -content-hash = "ab8de51bf6d389468d923b5d391f339c6e7f383396565cf5db925039de957132" +content-hash = "9edc2e1c448159e3f55c1d7fb1c6fa1d2baa9a11b2ef6e8aa80d5f551789ecac" diff --git a/pyproject.toml b/pyproject.toml index d784a1584..69593d7ab 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -24,9 +24,7 @@ classifiers = [ [tool.poetry.dependencies] python = "^3.9" -#pytorch-ie = ">=0.31.4,<0.32.0" -# install from branch from https://github.com/ArneBinder/pytorch-ie/pull/436 -pytorch-ie = { git = "https://github.com/ArneBinder/pytorch-ie", branch = "document/deduplicate_annotations" } +pytorch-ie = ">=0.31.2,<0.32.0" pytorch-lightning = "^2.1.0" torchmetrics = "^1" # >=4.35 because of BartModelWithDecoderPositionIds, <4.37 because of generation config diff --git a/tests/taskmodules/test_pointer_network_for_end2end_re.py b/tests/taskmodules/test_pointer_network_for_end2end_re.py index e49213109..30c6cd6e7 100644 --- a/tests/taskmodules/test_pointer_network_for_end2end_re.py +++ b/tests/taskmodules/test_pointer_network_for_end2end_re.py @@ -568,9 +568,9 @@ def test_decode_with_add_reversed_relations(): task_outputs = [task_encoding.targets for task_encoding in task_encodings] docs_with_predictions = taskmodule.decode(task_encodings, task_outputs) assert len(docs_with_predictions) == 1 - doc_with_predictions: ExampleDocument = docs_with_predictions[0].deduplicate_annotations() - assert list(doc_with_predictions.entities.predictions) == list(doc_with_predictions.entities) - assert list(doc_with_predictions.relations.predictions) == list(doc_with_predictions.relations) + doc_with_predictions: ExampleDocument = docs_with_predictions[0] + assert set(doc_with_predictions.entities.predictions) == set(doc_with_predictions.entities) + assert set(doc_with_predictions.relations.predictions) == set(doc_with_predictions.relations) @pytest.fixture()