diff --git a/src/pie_modules/document/processing/relation_argument_sorter.py b/src/pie_modules/document/processing/relation_argument_sorter.py index a6e3ad99b..b5d9aac87 100644 --- a/src/pie_modules/document/processing/relation_argument_sorter.py +++ b/src/pie_modules/document/processing/relation_argument_sorter.py @@ -6,6 +6,8 @@ from pytorch_ie.annotations import BinaryRelation, LabeledSpan from pytorch_ie.core import Annotation, AnnotationList, Document +from pie_modules.annotations import LabeledMultiSpan + logger = logging.getLogger(__name__) @@ -21,6 +23,20 @@ def get_relation_args(relation: Annotation) -> tuple[Annotation, ...]: ) +def sort_annotations(annotations: tuple[Annotation, ...]) -> tuple[Annotation, ...]: + if len(annotations) <= 1: + return annotations + if all(isinstance(ann, LabeledSpan) for ann in annotations): + return tuple(sorted(annotations, key=lambda ann: (ann.start, ann.end, ann.label))) + elif all(isinstance(ann, LabeledMultiSpan) for ann in annotations): + return tuple(sorted(annotations, key=lambda ann: (ann.slices, ann.label))) + else: + raise TypeError( + f"annotations {annotations} have unknown types [{set(type(ann) for ann in annotations)}], " + f"cannot sort them" + ) + + def construct_relation_with_new_args( relation: Annotation, new_args: tuple[Annotation, ...] ) -> BinaryRelation: @@ -38,10 +54,6 @@ def construct_relation_with_new_args( ) -def has_dependent_layers(document: D, layer: str) -> bool: - return layer not in document._annotation_graph["_artificial_root"] - - class RelationArgumentSorter: """Sorts the arguments of the relations in the given relation layer. The sorting is done by the start and end positions of the arguments. The relations with the same sorted arguments are @@ -50,47 +62,44 @@ class RelationArgumentSorter: Args: relation_layer: the name of the relation layer label_whitelist: if not None, only the relations with the label in the whitelist are sorted - inplace: if True, the sorting is done in place, otherwise the document is copied and the sorting is done - on the copy + verbose: if True, log warnings for relations with sorted arguments that are already present """ def __init__( - self, relation_layer: str, label_whitelist: list[str] | None = None, inplace: bool = True + self, + relation_layer: str, + label_whitelist: list[str] | None = None, + verbose: bool = True, ): self.relation_layer = relation_layer self.label_whitelist = label_whitelist - self.inplace = inplace + self.verbose = verbose def __call__(self, doc: D) -> D: - if not self.inplace: - doc = doc.copy() - rel_layer: AnnotationList[BinaryRelation] = doc[self.relation_layer] args2relations: dict[tuple[LabeledSpan, ...], BinaryRelation] = { get_relation_args(rel): rel for rel in rel_layer } - # assert that no other layers depend on the relation layer - if has_dependent_layers(document=doc, layer=self.relation_layer): - raise ValueError( - f"the relation layer {self.relation_layer} has dependent layers, " - f"cannot sort the arguments of the relations" - ) - - rel_layer.clear() + old2new_annotations = {} + new_annotations = [] for args, rel in args2relations.items(): if self.label_whitelist is not None and rel.label not in self.label_whitelist: # just add the relations whose label is not in the label whitelist (if a whitelist is present) - rel_layer.append(rel) + old2new_annotations[rel._id] = rel.copy() + new_annotations.append(old2new_annotations[rel._id]) else: - args_sorted = tuple(sorted(args, key=lambda arg: (arg.start, arg.end))) + args_sorted = sort_annotations(args) if args == args_sorted: # if the relation args are already sorted, just add the relation - rel_layer.append(rel) + old2new_annotations[rel._id] = rel.copy() + new_annotations.append(old2new_annotations[rel._id]) else: if args_sorted not in args2relations: - new_rel = construct_relation_with_new_args(rel, args_sorted) - rel_layer.append(new_rel) + old2new_annotations[rel._id] = construct_relation_with_new_args( + rel, args_sorted + ) + new_annotations.append(old2new_annotations[rel._id]) else: prev_rel = args2relations[args_sorted] if prev_rel.label != rel.label: @@ -103,5 +112,16 @@ def __call__(self, doc: D) -> D: f"do not add the new relation with sorted arguments, because it is already there: " f"{prev_rel}" ) - - return doc + # we use the previous relation with sorted arguments to re-map any annotations that + # depend on the current relation + old2new_annotations[rel._id] = prev_rel.copy() + + result = doc.copy(with_annotations=False) + result[self.relation_layer].extend(new_annotations) + result.add_all_annotations_from_other( + doc, + override_annotations={self.relation_layer: old2new_annotations}, + verbose=self.verbose, + strict=True, + ) + return result diff --git a/tests/document/processing/test_relation_argument_sorter.py b/tests/document/processing/test_relation_argument_sorter.py index 4de03ade1..f57a17e8a 100644 --- a/tests/document/processing/test_relation_argument_sorter.py +++ b/tests/document/processing/test_relation_argument_sorter.py @@ -10,6 +10,7 @@ TextDocumentWithLabeledSpansAndBinaryRelations, ) +from pie_modules.annotations import LabeledMultiSpan from pie_modules.document.processing import RelationArgumentSorter from pie_modules.document.processing.relation_argument_sorter import ( construct_relation_with_new_args, @@ -32,8 +33,7 @@ def document(): return doc -@pytest.mark.parametrize("inplace", [True, False]) -def test_relation_argument_sorter(document, inplace): +def test_relation_argument_sorter(document): # these arguments are not sorted document.binary_relations.append( BinaryRelation( @@ -47,7 +47,7 @@ def test_relation_argument_sorter(document, inplace): ) ) - arg_sorter = RelationArgumentSorter(relation_layer="binary_relations", inplace=inplace) + arg_sorter = RelationArgumentSorter(relation_layer="binary_relations") doc_sorted_args = arg_sorter(document) assert document.text == doc_sorted_args.text @@ -64,10 +64,7 @@ def test_relation_argument_sorter(document, inplace): assert str(doc_sorted_args.binary_relations[1].tail) == "I" assert doc_sorted_args.binary_relations[1].label == "founded" - if inplace: - assert document == doc_sorted_args - else: - assert document != doc_sorted_args + assert document != doc_sorted_args @pytest.fixture @@ -140,7 +137,8 @@ def test_relation_argument_sorter_with_label_whitelist(document): # we only want to sort the relations with the label "founded" arg_sorter = RelationArgumentSorter( - relation_layer="binary_relations", label_whitelist=["founded"], inplace=False + relation_layer="binary_relations", + label_whitelist=["founded"], ) doc_sorted_args = arg_sorter(document) @@ -169,7 +167,7 @@ def test_relation_argument_sorter_sorted_rel_already_exists_with_same_label(docu ) ) - arg_sorter = RelationArgumentSorter(relation_layer="binary_relations", inplace=False) + arg_sorter = RelationArgumentSorter(relation_layer="binary_relations") caplog.clear() with caplog.at_level(logging.WARNING): @@ -206,7 +204,7 @@ def test_relation_argument_sorter_sorted_rel_already_exists_with_different_label ) ) - arg_sorter = RelationArgumentSorter(relation_layer="binary_relations", inplace=False) + arg_sorter = RelationArgumentSorter(relation_layer="binary_relations") with pytest.raises(ValueError) as excinfo: arg_sorter(document) @@ -241,16 +239,51 @@ class ExampleDocument(TextBasedDocument): doc.binary_relations.append( BinaryRelation(head=doc.labeled_spans[1], tail=doc.labeled_spans[0], label="worksAt") ) + assert str(doc.binary_relations[0].head) == "H" + assert str(doc.binary_relations[0].tail) == "Entity G" doc.relation_attributes.append( Attribute(annotation=doc.binary_relations[0], label="some_attribute") ) - arg_sorter = RelationArgumentSorter(relation_layer="binary_relations", inplace=False) + arg_sorter = RelationArgumentSorter(relation_layer="binary_relations") - with pytest.raises(ValueError) as excinfo: - arg_sorter(doc) + doc_sorted_args = arg_sorter(doc) - assert ( - str(excinfo.value) - == "the relation layer binary_relations has dependent layers, cannot sort the arguments of the relations" - ) + assert doc.text == doc_sorted_args.text + assert doc.labeled_spans == doc_sorted_args.labeled_spans + assert len(doc_sorted_args.relation_attributes) == len(doc.relation_attributes) == 1 + new_rel = doc_sorted_args.binary_relations[0] + assert str(new_rel.head) == "Entity G" + assert str(new_rel.tail) == "H" + assert len(doc_sorted_args.relation_attributes) == len(doc.relation_attributes) == 1 + assert doc_sorted_args.relation_attributes[0].annotation == new_rel + assert doc_sorted_args.relation_attributes[0].label == "some_attribute" + + +def test_relation_argument_sorter_with_labeled_multi_spans(): + @dataclasses.dataclass + class TestDocument(TextBasedDocument): + labeled_multi_spans: AnnotationLayer[LabeledMultiSpan] = annotation_field(target="text") + binary_relations: AnnotationLayer[BinaryRelation] = annotation_field( + target="labeled_multi_spans" + ) + + doc = TestDocument(text="Karl The Big Heinz loves what he does.") + karl = LabeledMultiSpan(slices=((0, 4), (13, 18)), label="PER") + doc.labeled_multi_spans.append(karl) + assert str(karl) == "('Karl', 'Heinz')" + he = LabeledMultiSpan(slices=((30, 32),), label="PER") + doc.labeled_multi_spans.append(he) + assert str(he) == "('he',)" + doc.binary_relations.append(BinaryRelation(head=he, tail=karl, label="coref")) + + arg_sorter = RelationArgumentSorter(relation_layer="binary_relations") + doc_sorted_args = arg_sorter(doc) + + assert doc.text == doc_sorted_args.text + assert doc.labeled_multi_spans == doc_sorted_args.labeled_multi_spans + assert len(doc_sorted_args.binary_relations) == len(doc.binary_relations) == 1 + new_rel = doc_sorted_args.binary_relations[0] + assert new_rel.head == karl + assert new_rel.tail == he + assert new_rel.label == "coref"