diff --git a/poetry.lock b/poetry.lock index b1be5fc4d..5c190fca9 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1943,4 +1943,4 @@ multidict = ">=4.0" [metadata] lock-version = "2.0" python-versions = "^3.9" -content-hash = "e6a26cb8343b9be100a3d3a66471077b8abc5df470a33d05f33450b2d0e0a35b" +content-hash = "aecdcdc068d62a0c1630df94a488c481242188d73028084e81a1abc83357b762" diff --git a/pyproject.toml b/pyproject.toml index 282ea1610..47c1b1463 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -28,6 +28,8 @@ pytorch-ie = ">=0.29.8,<0.30.0" pytorch-lightning = "^2.1.0" torchmetrics = "^1" pytorch-crf = ">=0.7.2" +# for SpansViaRelationMerger +networkx = "^3.0.0" # because of BartModelWithDecoderPositionIds transformers = "^4.35.0" diff --git a/src/pie_modules/document/processing/__init__.py b/src/pie_modules/document/processing/__init__.py index f10a62d53..5f3c49e9b 100644 --- a/src/pie_modules/document/processing/__init__.py +++ b/src/pie_modules/document/processing/__init__.py @@ -1,3 +1,4 @@ +from .merge_spans_via_relation import SpansViaRelationMerger from .regex_partitioner import RegexPartitioner from .relation_argument_sorter import RelationArgumentSorter from .text_span_trimmer import TextSpanTrimmer diff --git a/src/pie_modules/document/processing/merge_spans_via_relation.py b/src/pie_modules/document/processing/merge_spans_via_relation.py new file mode 100644 index 000000000..efd8f72f9 --- /dev/null +++ b/src/pie_modules/document/processing/merge_spans_via_relation.py @@ -0,0 +1,183 @@ +import logging +from typing import Optional, Sequence, Set, Tuple, TypeVar, Union + +import networkx as nx +from pytorch_ie import AnnotationLayer +from pytorch_ie.core import Document + +from pie_modules.annotations import BinaryRelation, LabeledMultiSpan, LabeledSpan +from pie_modules.utils import resolve_type + +logger = logging.getLogger(__name__) + + +D = TypeVar("D", bound=Document) + + +def _merge_spans_via_relation( + spans: Sequence[LabeledSpan], + relations: Sequence[BinaryRelation], + link_relation_label: str, + create_multi_spans: bool = True, +) -> Tuple[Union[Set[LabeledSpan], Set[LabeledMultiSpan]], Set[BinaryRelation]]: + # convert list of relations to a graph to easily calculate connected components to merge + g = nx.Graph() + link_relations = [] + other_relations = [] + for rel in relations: + if rel.label == link_relation_label: + link_relations.append(rel) + # never merge spans that have not the same label + if ( + not (isinstance(rel.head, LabeledSpan) or isinstance(rel.tail, LabeledSpan)) + or rel.head.label == rel.tail.label + ): + g.add_edge(rel.head, rel.tail) + else: + logger.debug( + f"spans to merge do not have the same label, do not merge them: {rel.head}, {rel.tail}" + ) + else: + other_relations.append(rel) + + span_mapping = {} + connected_components: Set[LabeledSpan] + for connected_components in nx.connected_components(g): + # all spans in a connected component have the same label + label = list(span.label for span in connected_components)[0] + connected_components_sorted = sorted(connected_components, key=lambda span: span.start) + if create_multi_spans: + new_span = LabeledMultiSpan( + slices=tuple((span.start, span.end) for span in connected_components_sorted), + label=label, + ) + else: + new_span = LabeledSpan( + start=min(span.start for span in connected_components_sorted), + end=max(span.end for span in connected_components_sorted), + label=label, + ) + for span in connected_components_sorted: + span_mapping[span] = new_span + for span in spans: + if span not in span_mapping: + if create_multi_spans: + span_mapping[span] = LabeledMultiSpan( + slices=((span.start, span.end),), label=span.label, score=span.score + ) + else: + span_mapping[span] = LabeledSpan( + start=span.start, end=span.end, label=span.label, score=span.score + ) + + new_spans = set(span_mapping.values()) + new_relations = { + BinaryRelation( + head=span_mapping[rel.head], + tail=span_mapping[rel.tail], + label=rel.label, + score=rel.score, + ) + for rel in other_relations + } + + return new_spans, new_relations + + +class SpansViaRelationMerger: + """Merge spans based on relations. + + This processor merges spans based on binary relations. The spans are merged into a + single span if they are connected via a relation with the specified link label. The + processor handles both gold and predicted annotations. + + Args: + relation_layer: The name of the relation layer in the document. + link_relation_label: The label of the relation that should be used to merge spans. + create_multi_spans: Whether to create multi spans or not. If `True`, multi spans + will be created, otherwise single spans that cover the merged spans will be + created. + result_document_type: The type of the document to return. This can be a class or + a string that can be resolved to a class. The class must be a subclass of + `Document`. Required when `create_multi_spans` is `True`. + result_field_mapping: A mapping from the field names in the input document to the + field names in the result document. This is used to copy over fields from the + input document to the result document. The keys are the field names in the + input document and the values are the field names in the result document. + Required when `result_document_type` is provided. + use_predicted_spans: Whether to use the predicted spans or the gold spans when + processing predictions. + """ + + def __init__( + self, + relation_layer: str, + link_relation_label: str, + result_document_type: Optional[Union[type[Document], str]] = None, + result_field_mapping: Optional[dict[str, str]] = None, + create_multi_spans: bool = True, + use_predicted_spans: bool = True, + ): + self.relation_layer = relation_layer + self.link_relation_label = link_relation_label + self.create_multi_spans = create_multi_spans + if self.create_multi_spans: + if result_document_type is None: + raise ValueError( + "result_document_type must be set when create_multi_spans is True" + ) + self.result_document_type: Optional[type[Document]] + if result_document_type is not None: + if result_field_mapping is None: + raise ValueError( + "result_field_mapping must be set when result_document_type is provided" + ) + self.result_document_type = resolve_type( + result_document_type, expected_super_type=Document + ) + else: + self.result_document_type = None + self.result_field_mapping = result_field_mapping or {} + self.use_predicted_spans = use_predicted_spans + + def __call__(self, document: D) -> D: + relations: AnnotationLayer[BinaryRelation] = document[self.relation_layer] + spans: AnnotationLayer[LabeledSpan] = document[self.relation_layer].target_layer + + # process gold annotations + new_gold_spans, new_gold_relations = _merge_spans_via_relation( + spans=spans, + relations=relations, + link_relation_label=self.link_relation_label, + create_multi_spans=self.create_multi_spans, + ) + + # process predicted annotations + new_pred_spans, new_pred_relations = _merge_spans_via_relation( + spans=spans.predictions if self.use_predicted_spans else spans, + relations=relations.predictions, + link_relation_label=self.link_relation_label, + create_multi_spans=self.create_multi_spans, + ) + + result = document.copy(with_annotations=False) + if self.result_document_type is not None: + result = result.as_type(new_type=self.result_document_type) + span_layer_name = document[self.relation_layer].target_name + result_span_layer_name = self.result_field_mapping.get(span_layer_name, span_layer_name) + result_relation_layer_name = self.result_field_mapping.get( + self.relation_layer, self.relation_layer + ) + result[result_span_layer_name].extend(new_gold_spans) + result[result_relation_layer_name].extend(new_gold_relations) + result[result_span_layer_name].predictions.extend(new_pred_spans) + result[result_relation_layer_name].predictions.extend(new_pred_relations) + + # copy over remaining fields mentioned in result_field_mapping + for field_name, result_field_name in self.result_field_mapping.items(): + if field_name not in [span_layer_name, self.relation_layer]: + for ann in document[field_name]: + result[result_field_name].append(ann.copy()) + for ann in document[field_name].predictions: + result[result_field_name].predictions.append(ann.copy()) + return result diff --git a/tests/document/processing/test_merge_spans_via_relation.py b/tests/document/processing/test_merge_spans_via_relation.py new file mode 100644 index 000000000..ec2ccd56d --- /dev/null +++ b/tests/document/processing/test_merge_spans_via_relation.py @@ -0,0 +1,199 @@ +import pytest +from pytorch_ie.documents import ( + TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions, +) + +from pie_modules.annotations import BinaryRelation, LabeledMultiSpan, LabeledSpan +from pie_modules.document.processing import SpansViaRelationMerger +from pie_modules.document.processing.merge_spans_via_relation import ( + _merge_spans_via_relation, +) +from pie_modules.documents import ( + TextDocumentWithLabeledMultiSpansAndBinaryRelations, + TextDocumentWithLabeledMultiSpansBinaryRelationsAndLabeledPartitions, + TextDocumentWithLabeledSpansAndBinaryRelations, +) + + +@pytest.mark.parametrize( + "create_multi_spans", + [False, True], +) +def test_merge_spans_via_relation(create_multi_spans: bool): + # we have 6 spans and 4 relations + # spans 0, 2, 4 are connected via "link" relation, so they should be merged + # spans 3, 5 are connected via "link" relation, but they do not have the same label, + # so they should not be merged. But the relation should be removed + # spans 0, 3 are connected via "relation_x" relation, its head should be remapped to the new span + spans = [ + LabeledSpan(start=0, end=1, label="label_a"), + LabeledSpan(start=2, end=3, label="other"), + LabeledSpan(start=4, end=5, label="label_a"), + LabeledSpan(start=6, end=7, label="label_b"), + LabeledSpan(start=8, end=9, label="label_a"), + LabeledSpan(start=10, end=11, label="label_c"), + ] + relations = [ + BinaryRelation(head=spans[0], tail=spans[2], label="link"), + BinaryRelation(head=spans[0], tail=spans[3], label="relation_x"), + BinaryRelation(head=spans[2], tail=spans[4], label="link"), + BinaryRelation(head=spans[3], tail=spans[5], label="link"), + ] + + merged_spans, merged_relations = _merge_spans_via_relation( + spans=spans, + relations=relations, + link_relation_label="link", + create_multi_spans=create_multi_spans, + ) + if create_multi_spans: + head = LabeledMultiSpan( + slices=( + (0, 1), + (4, 5), + (8, 9), + ), + label="label_a", + ) + tail = LabeledMultiSpan(slices=((6, 7),), label="label_b") + assert merged_spans == { + head, + LabeledMultiSpan(slices=((2, 3),), label="other"), + tail, + LabeledMultiSpan(slices=((10, 11),), label="label_c"), + } + else: + head = LabeledSpan(start=0, end=9, label="label_a") + tail = LabeledSpan(start=6, end=7, label="label_b") + assert merged_spans == { + head, + tail, + LabeledSpan(start=2, end=3, label="other"), + LabeledSpan(start=10, end=11, label="label_c"), + } + assert merged_relations == {BinaryRelation(head=head, tail=tail, label="relation_x")} + + +def sort_spans(spans): + if len(spans) == 0: + return [] + if isinstance(spans[0], LabeledSpan): + return sorted(spans, key=lambda span: (span.start, span.end, span.label)) + else: + return sorted(spans, key=lambda span: (span.slices, span.label)) + + +def resolve_spans(spans): + if len(spans) == 0: + return [] + if isinstance(spans[0], LabeledSpan): + return [(span.target[span.start : span.end], span.label) for span in spans] + else: + return [ + (tuple(span.target[start:end] for start, end in span.slices), span.label) + for span in spans + ] + + +@pytest.mark.parametrize("create_multi_spans", [False, True]) +def test_spans_via_relation_merger(create_multi_spans): + doc = TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions( + text="This text, however, is about nothing (see here)." + ) + doc.labeled_partitions.append(LabeledSpan(start=0, end=48, label="sentence")) + assert str(doc.labeled_partitions[0]) == "This text, however, is about nothing (see here)." + doc.labeled_spans.extend( + [ + LabeledSpan(start=0, end=9, label="claim"), + LabeledSpan(start=11, end=18, label="other"), + LabeledSpan(start=20, end=36, label="claim"), + LabeledSpan(start=38, end=46, label="data"), + ] + ) + assert str(doc.labeled_spans[0]) == "This text" + assert str(doc.labeled_spans[1]) == "however" + assert str(doc.labeled_spans[2]) == "is about nothing" + assert str(doc.labeled_spans[3]) == "see here" + doc.binary_relations.extend( + [ + BinaryRelation(head=doc.labeled_spans[0], tail=doc.labeled_spans[2], label="link"), + BinaryRelation(head=doc.labeled_spans[3], tail=doc.labeled_spans[2], label="support"), + ] + ) + # after merging, that should be the same as in the gold data + doc.binary_relations.predictions.extend( + [ + BinaryRelation(head=doc.labeled_spans[0], tail=doc.labeled_spans[2], label="link"), + BinaryRelation(head=doc.labeled_spans[3], tail=doc.labeled_spans[0], label="support"), + ] + ) + + processor = SpansViaRelationMerger( + relation_layer="binary_relations", + link_relation_label="link", + use_predicted_spans=False, + create_multi_spans=create_multi_spans, + result_document_type=TextDocumentWithLabeledMultiSpansBinaryRelationsAndLabeledPartitions + if create_multi_spans + else None, + result_field_mapping={ + "labeled_spans": "labeled_multi_spans" if create_multi_spans else "labeled_spans", + "labeled_partitions": "labeled_partitions", + }, + ) + result = processor(doc) + if create_multi_spans: + assert isinstance( + result, TextDocumentWithLabeledMultiSpansBinaryRelationsAndLabeledPartitions + ) + sorted_spans = sort_spans(result.labeled_multi_spans) + sorted_spans_resolved = resolve_spans(sorted_spans) + assert sorted_spans_resolved == [ + (("This text", "is about nothing"), "claim"), + (("however",), "other"), + (("see here",), "data"), + ] + else: + assert isinstance(result, TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions) + sorted_spans = sort_spans(result.labeled_spans) + sorted_spans_resolved = resolve_spans(sorted_spans) + assert sorted_spans_resolved == [ + ("This text, however, is about nothing", "claim"), + ("however", "other"), + ("see here", "data"), + ] + # check gold and predicted relations + for relations in [result.binary_relations, result.binary_relations.predictions]: + assert len(relations) == 1 + assert relations[0].head == sorted_spans[2] + assert relations[0].tail == sorted_spans[0] + assert relations[0].label == "support" + + # check the labeled partitions + assert len(result.labeled_partitions) == 1 + assert str(result.labeled_partitions[0]) == "This text, however, is about nothing (see here)." + + +def test_spans_via_relation_merger_create_multi_span_missing_result_document_type(): + with pytest.raises(ValueError) as exc_info: + SpansViaRelationMerger( + relation_layer="binary_relations", + link_relation_label="link", + create_multi_spans=True, + ) + assert ( + str(exc_info.value) == "result_document_type must be set when create_multi_spans is True" + ) + + +def test_spans_via_relation_merger_with_result_document_type_missing_result_field_mapping(): + with pytest.raises(ValueError) as exc_info: + SpansViaRelationMerger( + relation_layer="binary_relations", + link_relation_label="link", + result_document_type=TextDocumentWithLabeledMultiSpansBinaryRelationsAndLabeledPartitions, + ) + assert ( + str(exc_info.value) + == "result_field_mapping must be set when result_document_type is provided" + )