diff --git a/src/pie_modules/taskmodules/pointer_network_for_end2end_re.py b/src/pie_modules/taskmodules/pointer_network_for_end2end_re.py
index 2042f0ec4..1ae43bc98 100644
--- a/src/pie_modules/taskmodules/pointer_network_for_end2end_re.py
+++ b/src/pie_modules/taskmodules/pointer_network_for_end2end_re.py
@@ -106,6 +106,8 @@ class PointerNetworkTaskModuleForEnd2EndRE(
     ],
 ):
     PREPARED_ATTRIBUTES = ["labels_per_layer"]
+    # suffix appended to the label of a non-symmetric relation whose arguments were swapped
+    REVERSED_RELATION_LABEL_SUFFIX = "_reversed"
 
     def __init__(
         self,
@@ -114,6 +116,8 @@ def __init__(
         document_type: str = "pytorch_ie.documents.TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions",
         tokenized_document_type: str = "pie_modules.documents.TokenDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions",
         relation_layer_name: str = "binary_relations",
+        add_reversed_relations: bool = False,
+        symmetric_relations: Optional[List[str]] = None,
         none_label: str = "none",
         loop_dummy_relation_name: str = "loop",
         constrained_generation: bool = False,
@@ -167,6 +171,8 @@ def __init__(
         self.span_layer_name = annotation_field_mapping_inv.get(
             relation_layer_target, relation_layer_target
         )
+        self.add_reversed_relations = add_reversed_relations
+        self.symmetric_relations = set(symmetric_relations or [])
         self.none_label = none_label
         self.loop_dummy_relation_name = loop_dummy_relation_name
         self.constrained_generation = constrained_generation
@@ -454,6 +460,8 @@ def encode_annotations(
 
         # encode relations
         all_relation_arguments = set()
+        # maps the (head, tail) pair of each relation seen so far to its label
+        relation_arguments2label = dict()
         relation_encodings = dict()
         for rel in layers[self.relation_layer_name]:
             if not isinstance(rel, BinaryRelation):
@@ -466,6 +474,31 @@ def encode_annotations(
                 raise Exception(f"failed to encode relation: {rel}")
             relation_encodings[rel] = encoded_relation
             all_relation_arguments.update([rel.head, rel.tail])
+            relation_arguments2label[(rel.head, rel.tail)] = rel.label
+            if self.add_reversed_relations:
+                # symmetric relations keep their label, all others get the suffix
+                reversed_label = rel.label
+                if reversed_label not in self.symmetric_relations:
+                    reversed_label += self.REVERSED_RELATION_LABEL_SUFFIX
+                # skip if a relation with the swapped arguments was already seen in this loop
+                if (rel.tail, rel.head) in relation_arguments2label:
+                    previous_label = relation_arguments2label[(rel.tail, rel.head)]
+                    logger.warning(
+                        f"reversed relation already exists: {rel.tail} -> {rel.head} with "
+                        f"label {previous_label} (reversed label: {reversed_label}). Skipping."
+                    )
+                    continue
+                reversed_rel = BinaryRelation(
+                    head=rel.tail,
+                    tail=rel.head,
+                    # use reversed_label so that symmetric relations are not suffixed
+                    label=reversed_label,
+                )
+                encoded_reversed_rel = self.relation_encoder_decoder.encode(
+                    annotation=reversed_rel, metadata=metadata
+                )
+                if encoded_reversed_rel is not None:
+                    relation_encodings[reversed_rel] = encoded_reversed_rel
 
         # encode spans that are not arguments of any relation
         no_relation_spans = [
@@ -846,4 +879,24 @@ def create_annotations_from_output(
         for layer_name in layers:
             annotations = self.get_mapped_layer(untokenized_document, layer_name=layer_name)
             for annotation in annotations:
-                yield layer_name, annotation.copy()
+                # handle relations that may be reversed
+                if layer_name == self.relation_layer_name and self.add_reversed_relations:
+                    if not isinstance(annotation, BinaryRelation):
+                        raise Exception(
+                            f"expected BinaryRelation when handling reversed relations, but got: {annotation}"
+                        )
+                    head, tail = annotation.head, annotation.tail
+                    label = annotation.label
+                    # if the relation is symmetric, we sort head and tail to ensure consistent order
+                    if annotation.label in self.symmetric_relations:
+                        head, tail = sorted(
+                            [annotation.head, annotation.tail], key=lambda x: (x.start, x.end)
+                        )
+                    # if the relation was reversed, we need to reconstruct the original label and swap head and tail
+                    elif annotation.label.endswith(self.REVERSED_RELATION_LABEL_SUFFIX):
+                        # reconstruct the original label and swap head and tail
+                        label = annotation.label[: -len(self.REVERSED_RELATION_LABEL_SUFFIX)]
+                        head, tail = tail, head
+                    yield layer_name, annotation.copy(head=head, tail=tail, label=label)
+                else:
+                    yield layer_name, annotation.copy()