Skip to content

Commit

Permalink
Merge pull request #73 from ArneBinder/improve_RelationArgumentSorter
Browse files Browse the repository at this point in the history
improve `RelationArgumentSorter`
  • Loading branch information
ArneBinder authored Mar 1, 2024
2 parents 60f5cfc + 1f84760 commit aaef772
Show file tree
Hide file tree
Showing 2 changed files with 96 additions and 43 deletions.
72 changes: 46 additions & 26 deletions src/pie_modules/document/processing/relation_argument_sorter.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
from pytorch_ie.annotations import BinaryRelation, LabeledSpan
from pytorch_ie.core import Annotation, AnnotationList, Document

from pie_modules.annotations import LabeledMultiSpan

logger = logging.getLogger(__name__)


Expand All @@ -21,6 +23,20 @@ def get_relation_args(relation: Annotation) -> tuple[Annotation, ...]:
)


def sort_annotations(annotations: tuple[Annotation, ...]) -> tuple[Annotation, ...]:
    """Return the given annotations as a sorted tuple.

    A tuple with at most one entry is handed back unchanged. When every entry is a
    `LabeledSpan`, entries are ordered by `(start, end, label)`; when every entry is a
    `LabeledMultiSpan`, they are ordered by `(slices, label)`.

    Args:
        annotations: the annotations to sort

    Returns:
        a tuple holding the same annotations in sorted order

    Raises:
        TypeError: if the entries are neither all `LabeledSpan` nor all `LabeledMultiSpan`
    """

    def _span_key(ann):
        return ann.start, ann.end, ann.label

    def _multi_span_key(ann):
        return ann.slices, ann.label

    # zero or one annotation: there is nothing to reorder
    if len(annotations) <= 1:
        return annotations

    # pick the sort key by the (homogeneous) annotation type; the LabeledSpan check comes first
    if all(isinstance(ann, LabeledSpan) for ann in annotations):
        sort_key = _span_key
    elif all(isinstance(ann, LabeledMultiSpan) for ann in annotations):
        sort_key = _multi_span_key
    else:
        found_types = {type(ann) for ann in annotations}
        raise TypeError(
            f"annotations {annotations} have unknown types [{found_types}], "
            f"cannot sort them"
        )

    return tuple(sorted(annotations, key=sort_key))


def construct_relation_with_new_args(
relation: Annotation, new_args: tuple[Annotation, ...]
) -> BinaryRelation:
Expand All @@ -38,10 +54,6 @@ def construct_relation_with_new_args(
)


def has_dependent_layers(document: D, layer: str) -> bool:
    """Report whether `layer` is missing from the `"_artificial_root"` entry of the
    document's annotation graph.

    NOTE(review): presumably the layers listed under the artificial root are the ones no
    other layer depends on - confirm against the `Document` implementation.

    Args:
        document: the document whose `_annotation_graph` is inspected
        layer: the name of the layer to look up

    Returns:
        True if `layer` is not contained in `document._annotation_graph["_artificial_root"]`,
        False otherwise
    """
    root_entries = document._annotation_graph["_artificial_root"]
    return layer not in root_entries


class RelationArgumentSorter:
"""Sorts the arguments of the relations in the given relation layer. The sorting is done by the
start and end positions of the arguments. The relations with the same sorted arguments are
Expand All @@ -50,47 +62,44 @@ class RelationArgumentSorter:
Args:
relation_layer: the name of the relation layer
label_whitelist: if not None, only the relations with the label in the whitelist are sorted
inplace: if True, the sorting is done in place, otherwise the document is copied and the sorting is done
on the copy
verbose: if True, log warnings for relations with sorted arguments that are already present
"""

def __init__(
self, relation_layer: str, label_whitelist: list[str] | None = None, inplace: bool = True
self,
relation_layer: str,
label_whitelist: list[str] | None = None,
verbose: bool = True,
):
self.relation_layer = relation_layer
self.label_whitelist = label_whitelist
self.inplace = inplace
self.verbose = verbose

def __call__(self, doc: D) -> D:
if not self.inplace:
doc = doc.copy()

rel_layer: AnnotationList[BinaryRelation] = doc[self.relation_layer]
args2relations: dict[tuple[LabeledSpan, ...], BinaryRelation] = {
get_relation_args(rel): rel for rel in rel_layer
}

# assert that no other layers depend on the relation layer
if has_dependent_layers(document=doc, layer=self.relation_layer):
raise ValueError(
f"the relation layer {self.relation_layer} has dependent layers, "
f"cannot sort the arguments of the relations"
)

rel_layer.clear()
old2new_annotations = {}
new_annotations = []
for args, rel in args2relations.items():
if self.label_whitelist is not None and rel.label not in self.label_whitelist:
# just add the relations whose label is not in the label whitelist (if a whitelist is present)
rel_layer.append(rel)
old2new_annotations[rel._id] = rel.copy()
new_annotations.append(old2new_annotations[rel._id])
else:
args_sorted = tuple(sorted(args, key=lambda arg: (arg.start, arg.end)))
args_sorted = sort_annotations(args)
if args == args_sorted:
# if the relation args are already sorted, just add the relation
rel_layer.append(rel)
old2new_annotations[rel._id] = rel.copy()
new_annotations.append(old2new_annotations[rel._id])
else:
if args_sorted not in args2relations:
new_rel = construct_relation_with_new_args(rel, args_sorted)
rel_layer.append(new_rel)
old2new_annotations[rel._id] = construct_relation_with_new_args(
rel, args_sorted
)
new_annotations.append(old2new_annotations[rel._id])
else:
prev_rel = args2relations[args_sorted]
if prev_rel.label != rel.label:
Expand All @@ -103,5 +112,16 @@ def __call__(self, doc: D) -> D:
f"do not add the new relation with sorted arguments, because it is already there: "
f"{prev_rel}"
)

return doc
# we use the previous relation with sorted arguments to re-map any annotations that
# depend on the current relation
old2new_annotations[rel._id] = prev_rel.copy()

result = doc.copy(with_annotations=False)
result[self.relation_layer].extend(new_annotations)
result.add_all_annotations_from_other(
doc,
override_annotations={self.relation_layer: old2new_annotations},
verbose=self.verbose,
strict=True,
)
return result
67 changes: 50 additions & 17 deletions tests/document/processing/test_relation_argument_sorter.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
TextDocumentWithLabeledSpansAndBinaryRelations,
)

from pie_modules.annotations import LabeledMultiSpan
from pie_modules.document.processing import RelationArgumentSorter
from pie_modules.document.processing.relation_argument_sorter import (
construct_relation_with_new_args,
Expand All @@ -32,8 +33,7 @@ def document():
return doc


@pytest.mark.parametrize("inplace", [True, False])
def test_relation_argument_sorter(document, inplace):
def test_relation_argument_sorter(document):
# these arguments are not sorted
document.binary_relations.append(
BinaryRelation(
Expand All @@ -47,7 +47,7 @@ def test_relation_argument_sorter(document, inplace):
)
)

arg_sorter = RelationArgumentSorter(relation_layer="binary_relations", inplace=inplace)
arg_sorter = RelationArgumentSorter(relation_layer="binary_relations")
doc_sorted_args = arg_sorter(document)

assert document.text == doc_sorted_args.text
Expand All @@ -64,10 +64,7 @@ def test_relation_argument_sorter(document, inplace):
assert str(doc_sorted_args.binary_relations[1].tail) == "I"
assert doc_sorted_args.binary_relations[1].label == "founded"

if inplace:
assert document == doc_sorted_args
else:
assert document != doc_sorted_args
assert document != doc_sorted_args


@pytest.fixture
Expand Down Expand Up @@ -140,7 +137,8 @@ def test_relation_argument_sorter_with_label_whitelist(document):

# we only want to sort the relations with the label "founded"
arg_sorter = RelationArgumentSorter(
relation_layer="binary_relations", label_whitelist=["founded"], inplace=False
relation_layer="binary_relations",
label_whitelist=["founded"],
)
doc_sorted_args = arg_sorter(document)

Expand Down Expand Up @@ -169,7 +167,7 @@ def test_relation_argument_sorter_sorted_rel_already_exists_with_same_label(docu
)
)

arg_sorter = RelationArgumentSorter(relation_layer="binary_relations", inplace=False)
arg_sorter = RelationArgumentSorter(relation_layer="binary_relations")

caplog.clear()
with caplog.at_level(logging.WARNING):
Expand Down Expand Up @@ -206,7 +204,7 @@ def test_relation_argument_sorter_sorted_rel_already_exists_with_different_label
)
)

arg_sorter = RelationArgumentSorter(relation_layer="binary_relations", inplace=False)
arg_sorter = RelationArgumentSorter(relation_layer="binary_relations")

with pytest.raises(ValueError) as excinfo:
arg_sorter(document)
Expand Down Expand Up @@ -241,16 +239,51 @@ class ExampleDocument(TextBasedDocument):
doc.binary_relations.append(
BinaryRelation(head=doc.labeled_spans[1], tail=doc.labeled_spans[0], label="worksAt")
)
assert str(doc.binary_relations[0].head) == "H"
assert str(doc.binary_relations[0].tail) == "Entity G"
doc.relation_attributes.append(
Attribute(annotation=doc.binary_relations[0], label="some_attribute")
)

arg_sorter = RelationArgumentSorter(relation_layer="binary_relations", inplace=False)
arg_sorter = RelationArgumentSorter(relation_layer="binary_relations")

with pytest.raises(ValueError) as excinfo:
arg_sorter(doc)
doc_sorted_args = arg_sorter(doc)

assert (
str(excinfo.value)
== "the relation layer binary_relations has dependent layers, cannot sort the arguments of the relations"
)
assert doc.text == doc_sorted_args.text
assert doc.labeled_spans == doc_sorted_args.labeled_spans
assert len(doc_sorted_args.relation_attributes) == len(doc.relation_attributes) == 1
new_rel = doc_sorted_args.binary_relations[0]
assert str(new_rel.head) == "Entity G"
assert str(new_rel.tail) == "H"
assert len(doc_sorted_args.relation_attributes) == len(doc.relation_attributes) == 1
assert doc_sorted_args.relation_attributes[0].annotation == new_rel
assert doc_sorted_args.relation_attributes[0].label == "some_attribute"


def test_relation_argument_sorter_with_labeled_multi_spans():
    """Check that relation arguments of type `LabeledMultiSpan` get sorted."""

    @dataclasses.dataclass
    class TestDocument(TextBasedDocument):
        labeled_multi_spans: AnnotationLayer[LabeledMultiSpan] = annotation_field(target="text")
        binary_relations: AnnotationLayer[BinaryRelation] = annotation_field(
            target="labeled_multi_spans"
        )

    original = TestDocument(text="Karl The Big Heinz loves what he does.")

    # a mention made of two slices, the first one starting at offset 0
    karl = LabeledMultiSpan(slices=((0, 4), (13, 18)), label="PER")
    original.labeled_multi_spans.append(karl)
    assert str(karl) == "('Karl', 'Heinz')"

    # a single-slice mention located later in the text
    he = LabeledMultiSpan(slices=((30, 32),), label="PER")
    original.labeled_multi_spans.append(he)
    assert str(he) == "('he',)"

    # the relation is created with the later mention as head, i.e. with unsorted arguments
    original.binary_relations.append(BinaryRelation(head=he, tail=karl, label="coref"))

    sorter = RelationArgumentSorter(relation_layer="binary_relations")
    sorted_doc = sorter(original)

    # text and span layer are unchanged in the result
    assert original.text == sorted_doc.text
    assert original.labeled_multi_spans == sorted_doc.labeled_multi_spans

    # still exactly one relation on both sides
    assert len(sorted_doc.binary_relations) == len(original.binary_relations)
    assert len(original.binary_relations) == 1

    # head and tail are swapped in the result, the label is kept
    sorted_rel = sorted_doc.binary_relations[0]
    assert sorted_rel.head == karl
    assert sorted_rel.tail == he
    assert sorted_rel.label == "coref"

0 comments on commit aaef772

Please sign in to comment.