From 263ba8bd47343b1bca62b1dc24d0715ff99fa070 Mon Sep 17 00:00:00 2001 From: Arne Binder Date: Thu, 29 Feb 2024 16:08:28 +0100 Subject: [PATCH 1/7] handle dependent annotations in RelationArgumentSorter --- .../processing/relation_argument_sorter.py | 53 +++++++++++++------ 1 file changed, 36 insertions(+), 17 deletions(-) diff --git a/src/pie_modules/document/processing/relation_argument_sorter.py b/src/pie_modules/document/processing/relation_argument_sorter.py index a6e3ad99b..abe8529d1 100644 --- a/src/pie_modules/document/processing/relation_argument_sorter.py +++ b/src/pie_modules/document/processing/relation_argument_sorter.py @@ -50,47 +50,55 @@ class RelationArgumentSorter: Args: relation_layer: the name of the relation layer label_whitelist: if not None, only the relations with the label in the whitelist are sorted - inplace: if True, the sorting is done in place, otherwise the document is copied and the sorting is done - on the copy + verbose: if True, log warnings for relations with sorted arguments that are already present """ def __init__( - self, relation_layer: str, label_whitelist: list[str] | None = None, inplace: bool = True + self, + relation_layer: str, + label_whitelist: list[str] | None = None, + verbose: bool = True, ): self.relation_layer = relation_layer self.label_whitelist = label_whitelist - self.inplace = inplace + self.verbose = verbose def __call__(self, doc: D) -> D: - if not self.inplace: - doc = doc.copy() + # if not self.inplace: + # doc = doc.copy() rel_layer: AnnotationList[BinaryRelation] = doc[self.relation_layer] args2relations: dict[tuple[LabeledSpan, ...], BinaryRelation] = { get_relation_args(rel): rel for rel in rel_layer } + old2new_annotations = {} + # removed_annotation_ids = [] + # assert that no other layers depend on the relation layer - if has_dependent_layers(document=doc, layer=self.relation_layer): - raise ValueError( - f"the relation layer {self.relation_layer} has dependent layers, " - f"cannot sort the arguments of the relations" - ) + # if has_dependent_layers(document=doc, layer=self.relation_layer): + # raise ValueError( + # f"the relation layer {self.relation_layer} has dependent layers, " + # f"cannot sort the arguments of the relations" + # ) - rel_layer.clear() + # rel_layer.clear() for args, rel in args2relations.items(): if self.label_whitelist is not None and rel.label not in self.label_whitelist: # just add the relations whose label is not in the label whitelist (if a whitelist is present) - rel_layer.append(rel) + # rel_layer.append(rel) + old2new_annotations[rel._id] = rel else: args_sorted = tuple(sorted(args, key=lambda arg: (arg.start, arg.end))) if args == args_sorted: # if the relation args are already sorted, just add the relation - rel_layer.append(rel) + # rel_layer.append(rel) + old2new_annotations[rel._id] = rel else: if args_sorted not in args2relations: new_rel = construct_relation_with_new_args(rel, args_sorted) - rel_layer.append(new_rel) + # rel_layer.append(new_rel) + old2new_annotations[rel._id] = new_rel else: prev_rel = args2relations[args_sorted] if prev_rel.label != rel.label: @@ -103,5 +111,16 @@ def __call__(self, doc: D) -> D: f"do not add the new relation with sorted arguments, because it is already there: " f"{prev_rel}" ) - - return doc + # removed_annotation_ids.append(rel._id) + old2new_annotations[rel._id] = prev_rel + + result = doc.copy(with_annotations=False) + result[self.relation_layer].extend(old2new_annotations.values()) + result.add_all_annotations_from_other( + doc, + override_annotations={self.relation_layer: old2new_annotations}, + # removed_annotations={self.relation_layer: set(removed_annotation_ids)}, + verbose=self.verbose, + strict=True, + ) + return result From 3445afeb301067ee65662358e68589560048d467 Mon Sep 17 00:00:00 2001 From: Arne Binder Date: Thu, 29 Feb 2024 17:48:05 +0100 Subject: [PATCH 2/7] some fixes --- .../processing/relation_argument_sorter.py | 6 +++--- .../test_relation_argument_sorter.py | 19 ++++++++----------- 2 files changed, 11 insertions(+), 14 deletions(-) diff --git a/src/pie_modules/document/processing/relation_argument_sorter.py b/src/pie_modules/document/processing/relation_argument_sorter.py index abe8529d1..fa4a060fb 100644 --- a/src/pie_modules/document/processing/relation_argument_sorter.py +++ b/src/pie_modules/document/processing/relation_argument_sorter.py @@ -87,13 +87,13 @@ def __call__(self, doc: D) -> D: if self.label_whitelist is not None and rel.label not in self.label_whitelist: # just add the relations whose label is not in the label whitelist (if a whitelist is present) # rel_layer.append(rel) - old2new_annotations[rel._id] = rel + old2new_annotations[rel._id] = rel.copy() else: args_sorted = tuple(sorted(args, key=lambda arg: (arg.start, arg.end))) if args == args_sorted: # if the relation args are already sorted, just add the relation # rel_layer.append(rel) - old2new_annotations[rel._id] = rel + old2new_annotations[rel._id] = rel.copy() else: if args_sorted not in args2relations: new_rel = construct_relation_with_new_args(rel, args_sorted) @@ -112,7 +112,7 @@ def __call__(self, doc: D) -> D: f"{prev_rel}" ) # removed_annotation_ids.append(rel._id) - old2new_annotations[rel._id] = prev_rel + old2new_annotations[rel._id] = prev_rel.copy() result = doc.copy(with_annotations=False) result[self.relation_layer].extend(old2new_annotations.values()) diff --git a/tests/document/processing/test_relation_argument_sorter.py b/tests/document/processing/test_relation_argument_sorter.py index 4de03ade1..f73cbc49d 100644 --- a/tests/document/processing/test_relation_argument_sorter.py +++ b/tests/document/processing/test_relation_argument_sorter.py @@ -32,8 +32,7 @@ def document(): return doc -@pytest.mark.parametrize("inplace", [True, False]) -def test_relation_argument_sorter(document, inplace): +def test_relation_argument_sorter(document): # these arguments are not sorted document.binary_relations.append( BinaryRelation( @@ -47,7 +46,7 @@ def test_relation_argument_sorter(document, inplace): ) ) - arg_sorter = RelationArgumentSorter(relation_layer="binary_relations", inplace=inplace) + arg_sorter = RelationArgumentSorter(relation_layer="binary_relations") doc_sorted_args = arg_sorter(document) assert document.text == doc_sorted_args.text @@ -64,10 +63,7 @@ def test_relation_argument_sorter(document, inplace): assert str(doc_sorted_args.binary_relations[1].tail) == "I" assert doc_sorted_args.binary_relations[1].label == "founded" - if inplace: - assert document == doc_sorted_args - else: - assert document != doc_sorted_args + assert document != doc_sorted_args @pytest.fixture @@ -140,7 +136,8 @@ def test_relation_argument_sorter_with_label_whitelist(document): # we only want to sort the relations with the label "founded" arg_sorter = RelationArgumentSorter( - relation_layer="binary_relations", label_whitelist=["founded"], inplace=False + relation_layer="binary_relations", + label_whitelist=["founded"], ) doc_sorted_args = arg_sorter(document) @@ -169,7 +166,7 @@ def test_relation_argument_sorter_sorted_rel_already_exists_with_same_label(docu ) ) - arg_sorter = RelationArgumentSorter(relation_layer="binary_relations", inplace=False) + arg_sorter = RelationArgumentSorter(relation_layer="binary_relations") caplog.clear() with caplog.at_level(logging.WARNING): @@ -206,7 +203,7 @@ def test_relation_argument_sorter_sorted_rel_already_exists_with_different_label ) ) - arg_sorter = RelationArgumentSorter(relation_layer="binary_relations", inplace=False) + arg_sorter = RelationArgumentSorter(relation_layer="binary_relations") with pytest.raises(ValueError) as excinfo: arg_sorter(document) @@ -245,7 +242,7 @@ class ExampleDocument(TextBasedDocument): Attribute(annotation=doc.binary_relations[0], label="some_attribute") ) - arg_sorter = RelationArgumentSorter(relation_layer="binary_relations", inplace=False) + arg_sorter = RelationArgumentSorter(relation_layer="binary_relations") with pytest.raises(ValueError) as excinfo: arg_sorter(doc) From 51285ad0fc7c9822a1686928ceb11af0d4d4fec6 Mon Sep 17 00:00:00 2001 From: Arne Binder Date: Thu, 29 Feb 2024 17:51:33 +0100 Subject: [PATCH 3/7] deduplicate annotations --- .../document/processing/relation_argument_sorter.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/pie_modules/document/processing/relation_argument_sorter.py b/src/pie_modules/document/processing/relation_argument_sorter.py index fa4a060fb..c73826076 100644 --- a/src/pie_modules/document/processing/relation_argument_sorter.py +++ b/src/pie_modules/document/processing/relation_argument_sorter.py @@ -115,7 +115,11 @@ def __call__(self, doc: D) -> D: old2new_annotations[rel._id] = prev_rel.copy() result = doc.copy(with_annotations=False) - result[self.relation_layer].extend(old2new_annotations.values()) + annotations_deduplicated = [] + for annotation in old2new_annotations.values(): + if annotation not in annotations_deduplicated: + annotations_deduplicated.append(annotation) + result[self.relation_layer].extend(annotations_deduplicated) result.add_all_annotations_from_other( doc, override_annotations={self.relation_layer: old2new_annotations}, From 69de1e0fd5388270f1accbf6de00fe1d63b97f4a Mon Sep 17 00:00:00 2001 From: Arne Binder Date: Fri, 1 Mar 2024 16:25:40 +0100 Subject: [PATCH 4/7] do not deplicate new rels, but just collect rels to add separately --- .../processing/relation_argument_sorter.py | 22 ++++++++++--------- 1 file changed, 12 insertions(+), 10 deletions(-) diff --git a/src/pie_modules/document/processing/relation_argument_sorter.py b/src/pie_modules/document/processing/relation_argument_sorter.py index c73826076..b1fe0dda7 100644 --- a/src/pie_modules/document/processing/relation_argument_sorter.py +++ b/src/pie_modules/document/processing/relation_argument_sorter.py @@ -83,22 +83,28 @@ def __call__(self, doc: D) -> D: # ) # rel_layer.clear() + new_annotations = [] for args, rel in args2relations.items(): if self.label_whitelist is not None and rel.label not in self.label_whitelist: # just add the relations whose label is not in the label whitelist (if a whitelist is present) # rel_layer.append(rel) - old2new_annotations[rel._id] = rel.copy() + new_annotation = rel.copy() + old2new_annotations[rel._id] = new_annotation + new_annotations.append(new_annotation) else: args_sorted = tuple(sorted(args, key=lambda arg: (arg.start, arg.end))) if args == args_sorted: # if the relation args are already sorted, just add the relation # rel_layer.append(rel) - old2new_annotations[rel._id] = rel.copy() + new_annotation = rel.copy() + old2new_annotations[rel._id] = new_annotation + new_annotations.append(new_annotation) else: if args_sorted not in args2relations: - new_rel = construct_relation_with_new_args(rel, args_sorted) - # rel_layer.append(new_rel) - old2new_annotations[rel._id] = new_rel + new_annotation = construct_relation_with_new_args(rel, args_sorted) + # rel_layer.append(new_annotation) + old2new_annotations[rel._id] = new_annotation + new_annotations.append(new_annotation) else: prev_rel = args2relations[args_sorted] if prev_rel.label != rel.label: @@ -115,11 +121,7 @@ def __call__(self, doc: D) -> D: old2new_annotations[rel._id] = prev_rel.copy() result = doc.copy(with_annotations=False) - annotations_deduplicated = [] - for annotation in old2new_annotations.values(): - if annotation not in annotations_deduplicated: - annotations_deduplicated.append(annotation) - result[self.relation_layer].extend(annotations_deduplicated) + result[self.relation_layer].extend(new_annotations) result.add_all_annotations_from_other( doc, override_annotations={self.relation_layer: old2new_annotations}, From 0141be83692009f6eba301c3dea8c7b208055623 Mon Sep 17 00:00:00 2001 From: Arne Binder Date: Fri, 1 Mar 2024 16:26:00 +0100 Subject: [PATCH 5/7] fix test_relation_argument_sorter_with_dependent_layers() --- .../test_relation_argument_sorter.py | 20 ++++++++++++------- 1 file changed, 13 insertions(+), 7 deletions(-) diff --git a/tests/document/processing/test_relation_argument_sorter.py b/tests/document/processing/test_relation_argument_sorter.py index f73cbc49d..0c82190b4 100644 --- a/tests/document/processing/test_relation_argument_sorter.py +++ b/tests/document/processing/test_relation_argument_sorter.py @@ -238,16 +238,22 @@ class ExampleDocument(TextBasedDocument): doc.binary_relations.append( BinaryRelation(head=doc.labeled_spans[1], tail=doc.labeled_spans[0], label="worksAt") ) + assert str(doc.binary_relations[0].head) == "H" + assert str(doc.binary_relations[0].tail) == "Entity G" doc.relation_attributes.append( Attribute(annotation=doc.binary_relations[0], label="some_attribute") ) arg_sorter = RelationArgumentSorter(relation_layer="binary_relations") - with pytest.raises(ValueError) as excinfo: - arg_sorter(doc) - - assert ( - str(excinfo.value) - == "the relation layer binary_relations has dependent layers, cannot sort the arguments of the relations" - ) + doc_sorted_args = arg_sorter(doc) + + assert doc.text == doc_sorted_args.text + assert doc.labeled_spans == doc_sorted_args.labeled_spans + assert len(doc_sorted_args.relation_attributes) == len(doc.relation_attributes) == 1 + new_rel = doc_sorted_args.binary_relations[0] + assert str(new_rel.head) == "Entity G" + assert str(new_rel.tail) == "H" + assert len(doc_sorted_args.relation_attributes) == len(doc.relation_attributes) == 1 + assert doc_sorted_args.relation_attributes[0].annotation == new_rel + assert doc_sorted_args.relation_attributes[0].label == "some_attribute" From 5e90650d8abfcdf55946cc4d6b9188f1b12a9447 Mon Sep 17 00:00:00 2001 From: Arne Binder Date: Fri, 1 Mar 2024 17:02:22 +0100 Subject: [PATCH 6/7] cleanup --- .../processing/relation_argument_sorter.py | 41 +++++-------------- 1 file changed, 10 insertions(+), 31 deletions(-) diff --git a/src/pie_modules/document/processing/relation_argument_sorter.py b/src/pie_modules/document/processing/relation_argument_sorter.py index b1fe0dda7..05da08e95 100644 --- a/src/pie_modules/document/processing/relation_argument_sorter.py +++ b/src/pie_modules/document/processing/relation_argument_sorter.py @@ -38,10 +38,6 @@ def construct_relation_with_new_args( ) -def has_dependent_layers(document: D, layer: str) -> bool: - return layer not in document._annotation_graph["_artificial_root"] - - class RelationArgumentSorter: """Sorts the arguments of the relations in the given relation layer. The sorting is done by the start and end positions of the arguments. The relations with the same sorted arguments are @@ -64,47 +60,30 @@ def __init__( self.verbose = verbose def __call__(self, doc: D) -> D: - # if not self.inplace: - # doc = doc.copy() - rel_layer: AnnotationList[BinaryRelation] = doc[self.relation_layer] args2relations: dict[tuple[LabeledSpan, ...], BinaryRelation] = { get_relation_args(rel): rel for rel in rel_layer } old2new_annotations = {} - # removed_annotation_ids = [] - - # assert that no other layers depend on the relation layer - # if has_dependent_layers(document=doc, layer=self.relation_layer): - # raise ValueError( - # f"the relation layer {self.relation_layer} has dependent layers, " - # f"cannot sort the arguments of the relations" - # ) - - # rel_layer.clear() new_annotations = [] for args, rel in args2relations.items(): if self.label_whitelist is not None and rel.label not in self.label_whitelist: # just add the relations whose label is not in the label whitelist (if a whitelist is present) - # rel_layer.append(rel) - new_annotation = rel.copy() - old2new_annotations[rel._id] = new_annotation - new_annotations.append(new_annotation) + old2new_annotations[rel._id] = rel.copy() + new_annotations.append(old2new_annotations[rel._id]) else: args_sorted = tuple(sorted(args, key=lambda arg: (arg.start, arg.end))) if args == args_sorted: # if the relation args are already sorted, just add the relation - # rel_layer.append(rel) - new_annotation = rel.copy() - old2new_annotations[rel._id] = new_annotation - new_annotations.append(new_annotation) + old2new_annotations[rel._id] = rel.copy() + new_annotations.append(old2new_annotations[rel._id]) else: if args_sorted not in args2relations: - new_annotation = construct_relation_with_new_args(rel, args_sorted) - # rel_layer.append(new_annotation) - old2new_annotations[rel._id] = new_annotation - new_annotations.append(new_annotation) + old2new_annotations[rel._id] = construct_relation_with_new_args( + rel, args_sorted + ) + new_annotations.append(old2new_annotations[rel._id]) else: prev_rel = args2relations[args_sorted] if prev_rel.label != rel.label: @@ -117,7 +96,8 @@ def __call__(self, doc: D) -> D: f"do not add the new relation with sorted arguments, because it is already there: " f"{prev_rel}" ) - # removed_annotation_ids.append(rel._id) + # we use the previous relation with sorted arguments to re-map any annotations that + # depend on the current relation old2new_annotations[rel._id] = prev_rel.copy() result = doc.copy(with_annotations=False) @@ -125,7 +105,6 @@ def __call__(self, doc: D) -> D: result.add_all_annotations_from_other( doc, override_annotations={self.relation_layer: old2new_annotations}, - # removed_annotations={self.relation_layer: set(removed_annotation_ids)}, verbose=self.verbose, strict=True, ) From 1f8476095ff1e2b045c36f1df8b7f11246032a85 Mon Sep 17 00:00:00 2001 From: Arne Binder Date: Fri, 1 Mar 2024 17:23:27 +0100 Subject: [PATCH 7/7] allow for LabeledMultiSpans --- .../processing/relation_argument_sorter.py | 18 ++++++++++- .../test_relation_argument_sorter.py | 30 +++++++++++++++++++ 2 files changed, 47 insertions(+), 1 deletion(-) diff --git a/src/pie_modules/document/processing/relation_argument_sorter.py b/src/pie_modules/document/processing/relation_argument_sorter.py index 05da08e95..b5d9aac87 100644 --- a/src/pie_modules/document/processing/relation_argument_sorter.py +++ b/src/pie_modules/document/processing/relation_argument_sorter.py @@ -6,6 +6,8 @@ from pytorch_ie.annotations import BinaryRelation, LabeledSpan from pytorch_ie.core import Annotation, AnnotationList, Document +from pie_modules.annotations import LabeledMultiSpan + logger = logging.getLogger(__name__) @@ -21,6 +23,20 @@ def get_relation_args(relation: Annotation) -> tuple[Annotation, ...]: ) +def sort_annotations(annotations: tuple[Annotation, ...]) -> tuple[Annotation, ...]: + if len(annotations) <= 1: + return annotations + if all(isinstance(ann, LabeledSpan) for ann in annotations): + return tuple(sorted(annotations, key=lambda ann: (ann.start, ann.end, ann.label))) + elif all(isinstance(ann, LabeledMultiSpan) for ann in annotations): + return tuple(sorted(annotations, key=lambda ann: (ann.slices, ann.label))) + else: + raise TypeError( + f"annotations {annotations} have unknown types [{set(type(ann) for ann in annotations)}], " + f"cannot sort them" + ) + + def construct_relation_with_new_args( relation: Annotation, new_args: tuple[Annotation, ...] ) -> BinaryRelation: @@ -73,7 +89,7 @@ def __call__(self, doc: D) -> D: old2new_annotations[rel._id] = rel.copy() new_annotations.append(old2new_annotations[rel._id]) else: - args_sorted = tuple(sorted(args, key=lambda arg: (arg.start, arg.end))) + args_sorted = sort_annotations(args) if args == args_sorted: # if the relation args are already sorted, just add the relation old2new_annotations[rel._id] = rel.copy() diff --git a/tests/document/processing/test_relation_argument_sorter.py b/tests/document/processing/test_relation_argument_sorter.py index 0c82190b4..f57a17e8a 100644 --- a/tests/document/processing/test_relation_argument_sorter.py +++ b/tests/document/processing/test_relation_argument_sorter.py @@ -10,6 +10,7 @@ TextDocumentWithLabeledSpansAndBinaryRelations, ) +from pie_modules.annotations import LabeledMultiSpan from pie_modules.document.processing import RelationArgumentSorter from pie_modules.document.processing.relation_argument_sorter import ( construct_relation_with_new_args, @@ -257,3 +258,32 @@ class ExampleDocument(TextBasedDocument): assert len(doc_sorted_args.relation_attributes) == len(doc.relation_attributes) == 1 assert doc_sorted_args.relation_attributes[0].annotation == new_rel assert doc_sorted_args.relation_attributes[0].label == "some_attribute" + + +def test_relation_argument_sorter_with_labeled_multi_spans(): + @dataclasses.dataclass + class TestDocument(TextBasedDocument): + labeled_multi_spans: AnnotationLayer[LabeledMultiSpan] = annotation_field(target="text") + binary_relations: AnnotationLayer[BinaryRelation] = annotation_field( + target="labeled_multi_spans" + ) + + doc = TestDocument(text="Karl The Big Heinz loves what he does.") + karl = LabeledMultiSpan(slices=((0, 4), (13, 18)), label="PER") + doc.labeled_multi_spans.append(karl) + assert str(karl) == "('Karl', 'Heinz')" + he = LabeledMultiSpan(slices=((30, 32),), label="PER") + doc.labeled_multi_spans.append(he) + assert str(he) == "('he',)" + doc.binary_relations.append(BinaryRelation(head=he, tail=karl, label="coref")) + + arg_sorter = RelationArgumentSorter(relation_layer="binary_relations") + doc_sorted_args = arg_sorter(doc) + + assert doc.text == doc_sorted_args.text + assert doc.labeled_multi_spans == doc_sorted_args.labeled_multi_spans + assert len(doc_sorted_args.binary_relations) == len(doc.binary_relations) == 1 + new_rel = doc_sorted_args.binary_relations[0] + assert new_rel.head == karl + assert new_rel.tail == he + assert new_rel.label == "coref"