diff --git a/src/pie_modules/taskmodules/pointer_network_for_end2end_re.py b/src/pie_modules/taskmodules/pointer_network_for_end2end_re.py index 7feb1604c..48e7ea315 100644 --- a/src/pie_modules/taskmodules/pointer_network_for_end2end_re.py +++ b/src/pie_modules/taskmodules/pointer_network_for_end2end_re.py @@ -106,6 +106,7 @@ class PointerNetworkTaskModuleForEnd2EndRE( ], ): PREPARED_ATTRIBUTES = ["labels_per_layer"] + REVERSED_RELATION_LABEL_SUFFIX = "_reversed" def __init__( self, @@ -114,6 +115,8 @@ def __init__( document_type: str = "pytorch_ie.documents.TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions", tokenized_document_type: str = "pie_modules.documents.TokenDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions", relation_layer_name: str = "binary_relations", + add_reversed_relations: bool = False, + symmetric_relations: Optional[List[str]] = None, none_label: str = "none", loop_dummy_relation_name: str = "loop", constrained_generation: bool = False, @@ -167,6 +170,8 @@ def __init__( self.span_layer_name = annotation_field_mapping_inv.get( relation_layer_target, relation_layer_target ) + self.add_reversed_relations = add_reversed_relations + self.symmetric_relations = set(symmetric_relations or []) self.none_label = none_label self.loop_dummy_relation_name = loop_dummy_relation_name self.constrained_generation = constrained_generation @@ -263,6 +268,18 @@ def _prefix_allowed_tokens_fn_with_maximum( # convert to a list return allowed_indices.tolist() + def add_reversed_relation_labels(self, relation_labels: Iterable[str]) -> Set[str]: + result = set(relation_labels) + for rel_label in set(relation_labels): + if rel_label not in self.symmetric_relations: + reversed_label = rel_label + self.REVERSED_RELATION_LABEL_SUFFIX + if reversed_label in result: + raise ValueError( + f"reversed relation label {reversed_label} already exists in relation layer labels" + ) + result.add(reversed_label) + return result + def _prepare(self, documents: Sequence[DocumentType]) -> None: # collect all labels labels: Dict[str, Set[str]] = {layer_name: set() for layer_name in self.layer_names} @@ -273,6 +290,11 @@ def _prepare(self, documents: Sequence[DocumentType]) -> None: ac.label for ac in doc[layer_name] if ac.label not in exclude_labels ) + if self.add_reversed_relations: + labels[self.relation_layer_name] = self.add_reversed_relation_labels( + relation_labels=labels[self.relation_layer_name] + ) + self.labels_per_layer = { # sort labels to ensure deterministic order layer_name: sorted(labels) @@ -443,8 +465,38 @@ def decode_relations( return encodings, dict(errors), current_encoding + def reverse_relation(self, relation: Annotation) -> BinaryRelation: + if isinstance(relation, BinaryRelation): + reversed_label = relation.label + if ( + reversed_label not in self.symmetric_relations + and reversed_label != self.none_label + ): + reversed_label += self.REVERSED_RELATION_LABEL_SUFFIX + reversed_rel = relation.copy( + head=relation.tail, tail=relation.head, label=reversed_label + ) + return reversed_rel + else: + raise Exception(f"reversing of relations of type {type(relation)} is not supported") + + def unreverse_relation(self, relation: Annotation) -> BinaryRelation: + if isinstance(relation, BinaryRelation): + head, tail, label = relation.head, relation.tail, relation.label + # if the relation is symmetric, we sort head and tail to ensure consistent order + if relation.label in self.symmetric_relations: + head, tail = sorted([head, tail], key=lambda x: (x.start, x.end)) + # if the relation was reversed, we need to reconstruct the original label and swap head and tail + elif label.endswith(self.REVERSED_RELATION_LABEL_SUFFIX): + # reconstruct the original label and swap head and tail + label = label[: -len(self.REVERSED_RELATION_LABEL_SUFFIX)] + head, tail = tail, head + return relation.copy(head=head, tail=tail, label=label) + else: + raise Exception(f"un-reversing of relations of type {type(relation)} is not supported") + def encode_annotations( - self, layers: Dict[str, List[Annotation]], metadata: Optional[Dict[str, Any]] = None + self, layers: Dict[str, Iterable[Annotation]], metadata: Optional[Dict[str, Any]] = None ) -> TaskOutputType: if not set(layers.keys()) == set(self.layer_names): raise Exception(f"unexpected layers: {layers.keys()}. expected: {self.layer_names}") @@ -783,6 +835,13 @@ def encode_target(self, task_encoding: TaskEncodingType) -> Optional[TargetEncod layer_name: self.get_mapped_layer(document, layer_name=layer_name) for layer_name in self.layer_names } + + if self.add_reversed_relations: + # create a copy to avoid modifying the annotation layer in the document + relations = list(layers[self.relation_layer_name]) + reversed_relations = [self.reverse_relation(rel) for rel in relations] + layers[self.relation_layer_name] = relations + reversed_relations + result = self.encode_annotations( layers=layers, metadata={ @@ -856,4 +915,9 @@ def create_annotations_from_output( for layer_name in layers: annotations = self.get_mapped_layer(untokenized_document, layer_name=layer_name) for annotation in annotations: - yield layer_name, annotation.copy() + # handle relations that may be reversed + if layer_name == self.relation_layer_name and self.add_reversed_relations: + unreversed_relation = self.unreverse_relation(annotation) + yield layer_name, unreversed_relation + else: + yield layer_name, annotation.copy() diff --git a/tests/taskmodules/test_pointer_network_for_end2end_re.py b/tests/taskmodules/test_pointer_network_for_end2end_re.py index ea1195480..30c6cd6e7 100644 --- a/tests/taskmodules/test_pointer_network_for_end2end_re.py +++ b/tests/taskmodules/test_pointer_network_for_end2end_re.py @@ -194,6 +194,7 @@ def test_prepared_config(taskmodule, config): assert taskmodule._config() == { "taskmodule_type": "PointerNetworkTaskModuleForEnd2EndRE", "relation_layer_name": "relations", + "symmetric_relations": None, "none_label": "none", "loop_dummy_relation_name": "loop", "labels_per_layer": { @@ -208,6 +209,7 @@ def test_prepared_config(taskmodule, config): "tokenizer_init_kwargs": None, "tokenizer_kwargs": {"strict_span_conversion": False}, "partition_layer_name": None, + "add_reversed_relations": False, "annotation_field_mapping": { "entities": "labeled_spans", "relations": "binary_relations", @@ -221,6 +223,7 @@ def test_prepared_config(taskmodule, config): assert taskmodule._config() == { "taskmodule_type": "PointerNetworkTaskModuleForEnd2EndRE", "relation_layer_name": "relations", + "symmetric_relations": None, "none_label": "none", "loop_dummy_relation_name": "loop", "labels_per_layer": { @@ -235,6 +238,7 @@ def test_prepared_config(taskmodule, config): "tokenizer_init_kwargs": None, "tokenizer_kwargs": {"strict_span_conversion": False}, "partition_layer_name": "sentences", + "add_reversed_relations": False, "annotation_field_mapping": { "entities": "labeled_spans", "relations": "binary_relations", @@ -253,23 +257,320 @@ def task_encoding_without_target(taskmodule, document): return taskmodule.encode_input(document)[0] -def test_input_encoding(task_encoding_without_target, taskmodule): - assert task_encoding_without_target is not None - tokens = taskmodule.tokenizer.convert_ids_to_tokens( - task_encoding_without_target.inputs.input_ids +def test_add_reversed_relation_labels(): + taskmodule = PointerNetworkTaskModuleForEnd2EndRE( + tokenizer_name_or_path="facebook/bart-base", + symmetric_relations=["symmetric_relation"], ) - if taskmodule.partition_layer_name is None: - assert asdict(task_encoding_without_target.inputs) == { - "input_ids": [0, 713, 16, 10, 34759, 2788, 59, 1085, 4, 3101, 162, 4, 2], - "attention_mask": [1] * 13, + + labels = ["is_about", "symmetric_relation"] + labels_with_reversed = taskmodule.add_reversed_relation_labels(labels) + assert labels_with_reversed == {"is_about", "is_about_reversed", "symmetric_relation"} + + +def test_reverse_relation(): + taskmodule = PointerNetworkTaskModuleForEnd2EndRE( + tokenizer_name_or_path="facebook/bart-base", + symmetric_relations=["symmetric_relation"], + ) + + rel = BinaryRelation( + head=LabeledSpan(start=10, end=20, label="content"), + tail=LabeledSpan(start=27, end=34, label="topic"), + label="is_about", + ) + reversed_relation = taskmodule.reverse_relation(relation=rel) + assert reversed_relation == BinaryRelation( + head=LabeledSpan(start=27, end=34, label="topic", score=1.0), + tail=LabeledSpan(start=10, end=20, label="content", score=1.0), + label="is_about_reversed", + score=1.0, + ) + + sym_rel = BinaryRelation( + head=LabeledSpan(start=10, end=20, label="content"), + tail=LabeledSpan(start=27, end=34, label="topic"), + label="symmetric_relation", + ) + reversed_sym_rel = taskmodule.reverse_relation(relation=sym_rel) + assert reversed_sym_rel == BinaryRelation( + head=LabeledSpan(start=27, end=34, label="topic", score=1.0), + tail=LabeledSpan(start=10, end=20, label="content", score=1.0), + label="symmetric_relation", + score=1.0, + ) + + +def test_unreverse_relation(): + taskmodule = PointerNetworkTaskModuleForEnd2EndRE( + tokenizer_name_or_path="facebook/bart-base", + symmetric_relations=["symmetric_relation"], + ) + + # nothing should change because the relation is not reversed + rel = BinaryRelation( + head=LabeledSpan(start=10, end=20, label="content"), + tail=LabeledSpan(start=27, end=34, label="topic"), + label="is_about", + ) + same_rel = taskmodule.unreverse_relation(relation=rel) + assert same_rel == rel + + # the relation is reversed, so it should be un-reversed + reversed_rel = BinaryRelation( + head=LabeledSpan(start=10, end=20, label="content"), + tail=LabeledSpan(start=27, end=34, label="topic"), + label="is_about_reversed", + ) + unreversed_relation = taskmodule.unreverse_relation(relation=reversed_rel) + assert unreversed_relation == BinaryRelation( + head=LabeledSpan(start=27, end=34, label="topic", score=1.0), + tail=LabeledSpan(start=10, end=20, label="content", score=1.0), + label="is_about", + score=1.0, + ) + + # nothing should change because the relation is symmetric and already ordered (head < tail) + ordered_sym_rel = BinaryRelation( + head=LabeledSpan(start=10, end=20, label="content"), + tail=LabeledSpan(start=27, end=34, label="topic"), + label="symmetric_relation", + ) + unreversed_ordered_sym_rel = taskmodule.unreverse_relation(relation=ordered_sym_rel) + assert ordered_sym_rel == unreversed_ordered_sym_rel + + # the relation is symmetric and unordered (head > tail), so it should be un-reversed + unordered_sym_rel = BinaryRelation( + head=LabeledSpan(start=27, end=34, label="topic"), + tail=LabeledSpan(start=10, end=20, label="content"), + label="symmetric_relation", + ) + unreversed_unordered_sym_rel = taskmodule.unreverse_relation(relation=unordered_sym_rel) + assert unreversed_unordered_sym_rel == BinaryRelation( + head=LabeledSpan(start=10, end=20, label="content", score=1.0), + tail=LabeledSpan(start=27, end=34, label="topic", score=1.0), + label="symmetric_relation", + score=1.0, + ) + + +@pytest.fixture(params=[False, True]) +def taskmodule_with_reversed_relations(document, request) -> PointerNetworkTaskModuleForEnd2EndRE: + is_about_is_symmetric = request.param + taskmodule = PointerNetworkTaskModuleForEnd2EndRE( + tokenizer_name_or_path="facebook/bart-base", + relation_layer_name="relations", + exclude_labels_per_layer={"relations": ["no_relation"]}, + annotation_field_mapping={ + "entities": "labeled_spans", + "relations": "binary_relations", + }, + create_constraints=True, + tokenizer_kwargs={"strict_span_conversion": False}, + add_reversed_relations=True, + symmetric_relations=["is_about"] if is_about_is_symmetric else None, + ) + + taskmodule.prepare(documents=[document]) + assert taskmodule.is_prepared + if is_about_is_symmetric: + assert taskmodule.prepared_attributes == { + "labels_per_layer": { + "entities": ["content", "person", "topic"], + "relations": ["is_about"], + } } - elif taskmodule.partition_layer_name == "sentences": - assert asdict(task_encoding_without_target.inputs) == { - "input_ids": [0, 713, 16, 10, 34759, 2788, 59, 1085, 4, 2], - "attention_mask": [1] * 10, + else: + assert taskmodule.prepared_attributes == { + "labels_per_layer": { + "entities": ["content", "person", "topic"], + "relations": ["is_about", "is_about_reversed"], + } + } + + return taskmodule + + +def test_encode_with_add_reversed_relations(taskmodule_with_reversed_relations, document): + task_encodings = taskmodule_with_reversed_relations.encode(document, encode_target=True) + assert len(task_encodings) == 1 + task_encoding = task_encodings[0] + assert task_encoding is not None + assert asdict(task_encoding.inputs) == { + "input_ids": [0, 713, 16, 10, 34759, 2788, 59, 1085, 4, 3101, 162, 4, 2], + "attention_mask": [1] * 13, + } + tokens = taskmodule_with_reversed_relations.tokenizer.convert_ids_to_tokens( + task_encoding.inputs.input_ids + ) + assert tokens == [ + "", + "This", + "Ġis", + "Ġa", + "Ġdummy", + "Ġtext", + "Ġabout", + "Ġnothing", + ".", + "ĠTrust", + "Ġme", + ".", + "", + ] + if "is_about" in taskmodule_with_reversed_relations.symmetric_relations: + decoded_annotations, statistics = taskmodule_with_reversed_relations.decode_annotations( + task_encoding.targets + ) + assert decoded_annotations == { + "entities": [ + LabeledSpan(start=4, end=6, label="content", score=1.0), + LabeledSpan(start=7, end=8, label="topic", score=1.0), + LabeledSpan(start=10, end=11, label="person", score=1.0), + ], + "relations": [ + BinaryRelation( + head=LabeledSpan(start=4, end=6, label="content", score=1.0), + tail=LabeledSpan(start=7, end=8, label="topic", score=1.0), + label="is_about", + score=1.0, + ), + BinaryRelation( + head=LabeledSpan(start=7, end=8, label="topic", score=1.0), + tail=LabeledSpan(start=4, end=6, label="content", score=1.0), + label="is_about", + score=1.0, + ), + ], } else: - raise Exception(f"unknown partition_layer_name: {taskmodule.partition_layer_name}") + decoded_annotations, statistics = taskmodule_with_reversed_relations.decode_annotations( + task_encoding.targets + ) + assert decoded_annotations == { + "entities": [ + LabeledSpan(start=4, end=6, label="content", score=1.0), + LabeledSpan(start=7, end=8, label="topic", score=1.0), + LabeledSpan(start=10, end=11, label="person", score=1.0), + ], + "relations": [ + BinaryRelation( + head=LabeledSpan(start=4, end=6, label="content", score=1.0), + tail=LabeledSpan(start=7, end=8, label="topic", score=1.0), + label="is_about", + score=1.0, + ), + BinaryRelation( + head=LabeledSpan(start=7, end=8, label="topic", score=1.0), + tail=LabeledSpan(start=4, end=6, label="content", score=1.0), + label="is_about_reversed", + score=1.0, + ), + ], + } + + +def test_encode_with_add_reversed_relations_already_exists(caplog): + doc = ExampleDocument(text="This is a dummy text about nothing. Trust me.") + doc.entities.append(LabeledSpan(start=10, end=20, label="content")) + doc.entities.append(LabeledSpan(start=27, end=34, label="topic")) + doc.relations.append( + BinaryRelation(head=doc.entities[0], tail=doc.entities[1], label="is_about") + ) + doc.relations.append( + BinaryRelation(head=doc.entities[1], tail=doc.entities[0], label="is_about") + ) + + taskmodule = PointerNetworkTaskModuleForEnd2EndRE( + tokenizer_name_or_path="facebook/bart-base", + relation_layer_name="relations", + annotation_field_mapping={ + "entities": "labeled_spans", + "relations": "binary_relations", + }, + add_reversed_relations=True, + symmetric_relations=["is_about"], + ) + taskmodule.prepare(documents=[doc]) + + with caplog.at_level(logging.WARNING): + task_encodings = taskmodule.encode(doc, encode_target=True) + assert len(caplog.messages) == 0 + assert len(task_encodings) == 1 + task_encoding = task_encodings[0] + + decoded_annotations, statistics = taskmodule.decode_annotations(task_encoding.targets) + assert decoded_annotations == { + "entities": [ + LabeledSpan(start=4, end=6, label="content", score=1.0), + LabeledSpan(start=7, end=8, label="topic", score=1.0), + ], + "relations": [ + BinaryRelation( + head=LabeledSpan(start=4, end=6, label="content", score=1.0), + tail=LabeledSpan(start=7, end=8, label="topic", score=1.0), + label="is_about", + score=1.0, + ), + BinaryRelation( + head=LabeledSpan(start=7, end=8, label="topic", score=1.0), + tail=LabeledSpan(start=4, end=6, label="content", score=1.0), + label="is_about", + score=1.0, + ), + ], + } + + +def test_decode_with_add_reversed_relations(): + doc = ExampleDocument(text="This is a dummy text about nothing. Trust me.") + doc.entities.append(LabeledSpan(start=10, end=20, label="content")) + doc.entities.append(LabeledSpan(start=27, end=34, label="topic")) + doc.relations.append( + BinaryRelation(head=doc.entities[0], tail=doc.entities[1], label="is_about") + ) + + taskmodule = PointerNetworkTaskModuleForEnd2EndRE( + tokenizer_name_or_path="facebook/bart-base", + relation_layer_name="relations", + annotation_field_mapping={ + "entities": "labeled_spans", + "relations": "binary_relations", + }, + add_reversed_relations=True, + ) + taskmodule.prepare(documents=[doc]) + + task_encodings = taskmodule.encode(doc, encode_target=True) + assert len(task_encodings) == 1 + decoded_annotations, statistics = taskmodule.decode_annotations(task_encodings[0].targets) + assert decoded_annotations == { + "entities": [ + LabeledSpan(start=4, end=6, label="content", score=1.0), + LabeledSpan(start=7, end=8, label="topic", score=1.0), + ], + "relations": [ + BinaryRelation( + head=LabeledSpan(start=4, end=6, label="content", score=1.0), + tail=LabeledSpan(start=7, end=8, label="topic", score=1.0), + label="is_about", + score=1.0, + ), + BinaryRelation( + head=LabeledSpan(start=7, end=8, label="topic", score=1.0), + tail=LabeledSpan(start=4, end=6, label="content", score=1.0), + label="is_about_reversed", + score=1.0, + ), + ], + } + + task_outputs = [task_encoding.targets for task_encoding in task_encodings] + docs_with_predictions = taskmodule.decode(task_encodings, task_outputs) + assert len(docs_with_predictions) == 1 + doc_with_predictions: ExampleDocument = docs_with_predictions[0] + assert set(doc_with_predictions.entities.predictions) == set(doc_with_predictions.entities) + assert set(doc_with_predictions.relations.predictions) == set(doc_with_predictions.relations) @pytest.fixture()