From bee1523789bd22f45f43606a981bccef1d9ba1a9 Mon Sep 17 00:00:00 2001 From: Arne Binder Date: Fri, 2 Feb 2024 17:48:13 +0100 Subject: [PATCH 1/8] implement document.processing.SpansViaRelationMerger --- poetry.lock | 2 +- pyproject.toml | 2 + .../document/processing/__init__.py | 1 + .../processing/merge_spans_via_relation.py | 166 ++++++++++++++++++ 4 files changed, 170 insertions(+), 1 deletion(-) create mode 100644 src/pie_modules/document/processing/merge_spans_via_relation.py diff --git a/poetry.lock b/poetry.lock index b1be5fc4d..5c190fca9 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1943,4 +1943,4 @@ multidict = ">=4.0" [metadata] lock-version = "2.0" python-versions = "^3.9" -content-hash = "e6a26cb8343b9be100a3d3a66471077b8abc5df470a33d05f33450b2d0e0a35b" +content-hash = "aecdcdc068d62a0c1630df94a488c481242188d73028084e81a1abc83357b762" diff --git a/pyproject.toml b/pyproject.toml index 282ea1610..47c1b1463 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -28,6 +28,8 @@ pytorch-ie = ">=0.29.8,<0.30.0" pytorch-lightning = "^2.1.0" torchmetrics = "^1" pytorch-crf = ">=0.7.2" +# for SpansViaRelationMerger +networkx = "^3.0.0" # because of BartModelWithDecoderPositionIds transformers = "^4.35.0" diff --git a/src/pie_modules/document/processing/__init__.py b/src/pie_modules/document/processing/__init__.py index f10a62d53..5f3c49e9b 100644 --- a/src/pie_modules/document/processing/__init__.py +++ b/src/pie_modules/document/processing/__init__.py @@ -1,3 +1,4 @@ +from .merge_spans_via_relation import SpansViaRelationMerger from .regex_partitioner import RegexPartitioner from .relation_argument_sorter import RelationArgumentSorter from .text_span_trimmer import TextSpanTrimmer diff --git a/src/pie_modules/document/processing/merge_spans_via_relation.py b/src/pie_modules/document/processing/merge_spans_via_relation.py new file mode 100644 index 000000000..d9425b2ee --- /dev/null +++ b/src/pie_modules/document/processing/merge_spans_via_relation.py @@ -0,0 +1,166 @@ +import logging +from typing import Sequence, Set, Tuple, TypeVar, Union + +import networkx as nx +from pytorch_ie import AnnotationLayer +from pytorch_ie.core import Document + +from pie_modules.annotations import BinaryRelation, LabeledMultiSpan, LabeledSpan + +logger = logging.getLogger(__name__) + + +D = TypeVar("D", bound=Document) + + +def _merge_spans_via_relation( + spans: Sequence[LabeledSpan], + relations: Sequence[BinaryRelation], + link_relation_label: str, + create_multi_spans: bool = True, +) -> Tuple[Union[Set[LabeledSpan], Set[LabeledMultiSpan]], Set[BinaryRelation]]: + # convert list of relations to a graph to easily calculate connected components to merge + g = nx.Graph() + link_relations = [] + other_relations = [] + for rel in relations: + if rel.label == link_relation_label: + link_relations.append(rel) + # never merge spans that have not the same label + if ( + not (isinstance(rel.head, LabeledSpan) or isinstance(rel.tail, LabeledSpan)) + or rel.head.label == rel.tail.label + ): + g.add_edge(rel.head, rel.tail) + else: + logger.debug( + f"spans to merge do not have the same label, do not merge them: {rel.head}, {rel.tail}" + ) + else: + other_relations.append(rel) + + span_mapping = {} + connected_components: Set[LabeledSpan] + for connected_components in nx.connected_components(g): + # all spans in a connected component have the same label + label = list(span.label for span in connected_components)[0] + connected_components_sorted = sorted(connected_components, key=lambda span: span.start) + if create_multi_spans: + new_span = LabeledMultiSpan( + slices=tuple((span.start, span.end) for span in connected_components_sorted), + label=label, + ) + else: + new_span = LabeledSpan( + start=min(span.start for span in connected_components_sorted), + end=max(span.end for span in connected_components_sorted), + label=label, + ) + for span in connected_components_sorted: + span_mapping[span] = new_span + for span in spans: + if span not in span_mapping: + if create_multi_spans: + span_mapping[span] = LabeledMultiSpan( + slices=((span.start, span.end),), label=span.label, score=span.score + ) + else: + span_mapping[span] = LabeledSpan( + start=span.start, end=span.end, label=span.label, score=span.score + ) + + new_spans = set(span_mapping.values()) + new_relations = { + BinaryRelation( + head=span_mapping[rel.head], + tail=span_mapping[rel.tail], + label=rel.label, + score=rel.score, + ) + for rel in other_relations + } + + return new_spans, new_relations + + +class SpansViaRelationMerger: + """Merge spans based on relations. + + This processor merges spans based on the relations with a specific label. The spans + are merged into a single span if they are connected via a relation with the specified + label. The processor can be used to merge spans based on predicted or gold relations. + + Args: + relation_layer: The name of the relation layer in the document. + link_relation_label: The label of the relation that should be used to merge spans. + result_document_type: The type of the document to return. + result_field_mapping: A mapping from the field names in the input document to the + field names in the result document. + create_multi_spans: Whether to create multi spans or not. If `True`, multi spans + will be created, otherwise single spans that cover the merged spans will be + created. + use_predicted_spans: Whether to use the predicted spans or the gold spans. + process_predictions: Whether to process the predictions or not. If `True`, the + predictions will be processed, otherwise only the gold annotations will be + processed. + """ + + def __init__( + self, + relation_layer: str, + link_relation_label: str, + result_document_type: type[Document], + result_field_mapping: dict[str, str], + create_multi_spans: bool = True, + use_predicted_spans: bool = False, + process_predictions: bool = True, + ): + self.relation_layer = relation_layer + self.link_relation_label = link_relation_label + self.result_document_type = result_document_type + self.result_field_mapping = result_field_mapping + self.create_multi_spans = create_multi_spans + self.use_predicted_spans = use_predicted_spans + self.process_predictions = process_predictions + + def __call__(self, document: D) -> D: + relations: AnnotationLayer[BinaryRelation] = document[self.relation_layer] + spans: AnnotationLayer[LabeledSpan] = document[self.relation_layer].target_layer + + new_gold_spans, new_gold_relations = _merge_spans_via_relation( + spans=spans, + relations=relations, + link_relation_label=self.link_relation_label, + create_multi_spans=self.create_multi_spans, + ) + if self.process_predictions: + new_pred_spans, new_pred_relations = _merge_spans_via_relation( + spans=spans.predictions if self.use_predicted_spans else spans, + relations=relations.predictions, + link_relation_label=self.link_relation_label, + create_multi_spans=self.create_multi_spans, + ) + else: + if self.use_predicted_spans: + raise ValueError("cannot use predicted spans without processing predictions") + new_pred_spans = set(spans.predictions.clear()) + new_pred_relations = set(relations.predictions.clear()) + + result = document.copy(with_annotations=False).as_type(new_type=self.result_document_type) + span_layer_name = document[self.relation_layer].target_name + result_span_layer_name = self.result_field_mapping[span_layer_name] + result_relation_layer_name = self.result_field_mapping[self.relation_layer] + result[result_span_layer_name].extend(new_gold_spans) + result[result_relation_layer_name].extend(new_gold_relations) + result[result_span_layer_name].predictions.extend(new_pred_spans) + result[result_relation_layer_name].predictions.extend(new_pred_relations) + + # copy over remaining fields mentioned in result_field_mapping + for field_name, result_field_name in self.result_field_mapping.items(): + if field_name not in [span_layer_name, self.relation_layer]: + for ann in document[field_name]: + result[result_field_name].append(ann.copy()) + if self.process_predictions: + for ann in document[field_name].predictions: + result[result_field_name].predictions.append(ann.copy()) + return result From c05b4118ac05ffda8e655e6f567a0f628f7ef731 Mon Sep 17 00:00:00 2001 From: Arne Binder Date: Mon, 5 Feb 2024 12:28:15 +0100 Subject: [PATCH 2/8] allow string for result_document_type --- .../document/processing/merge_spans_via_relation.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/src/pie_modules/document/processing/merge_spans_via_relation.py b/src/pie_modules/document/processing/merge_spans_via_relation.py index d9425b2ee..3bb93b13e 100644 --- a/src/pie_modules/document/processing/merge_spans_via_relation.py +++ b/src/pie_modules/document/processing/merge_spans_via_relation.py @@ -6,6 +6,7 @@ from pytorch_ie.core import Document from pie_modules.annotations import BinaryRelation, LabeledMultiSpan, LabeledSpan +from pie_modules.utils import resolve_type logger = logging.getLogger(__name__) @@ -93,7 +94,9 @@ class SpansViaRelationMerger: Args: relation_layer: The name of the relation layer in the document. link_relation_label: The label of the relation that should be used to merge spans. - result_document_type: The type of the document to return. + result_document_type: The type of the document to return. This can be a class or + a string that can be resolved to a class. The class must be a subclass of + `Document`. result_field_mapping: A mapping from the field names in the input document to the field names in the result document. create_multi_spans: Whether to create multi spans or not. If `True`, multi spans @@ -109,7 +112,7 @@ def __init__( self, relation_layer: str, link_relation_label: str, - result_document_type: type[Document], + result_document_type: Union[type[Document], str], result_field_mapping: dict[str, str], create_multi_spans: bool = True, use_predicted_spans: bool = False, @@ -117,7 +120,9 @@ def __init__( ): self.relation_layer = relation_layer self.link_relation_label = link_relation_label - self.result_document_type = result_document_type + self.result_document_type = resolve_type( + result_document_type, expected_super_type=Document + ) self.result_field_mapping = result_field_mapping self.create_multi_spans = create_multi_spans self.use_predicted_spans = use_predicted_spans From 58c4b5d564cf4eea797732fba06cfca9dd84a8bb Mon Sep 17 00:00:00 2001 From: Arne Binder Date: Mon, 5 Feb 2024 13:01:42 +0100 Subject: [PATCH 3/8] add test_merge_spans_via_relation() --- .../test_merge_spans_via_relation.py | 65 +++++++++++++++++++ 1 file changed, 65 insertions(+) create mode 100644 tests/document/processing/test_merge_spans_via_relation.py diff --git a/tests/document/processing/test_merge_spans_via_relation.py b/tests/document/processing/test_merge_spans_via_relation.py new file mode 100644 index 000000000..09beb34f0 --- /dev/null +++ b/tests/document/processing/test_merge_spans_via_relation.py @@ -0,0 +1,65 @@ +import pytest + +from pie_modules.annotations import BinaryRelation, LabeledMultiSpan, LabeledSpan +from pie_modules.document.processing.merge_spans_via_relation import ( + _merge_spans_via_relation, +) + + +@pytest.mark.parametrize( + "create_multi_spans", + [False, True], +) +def test_merge_spans_via_relation(create_multi_spans: bool): + # we have 6 spans and 4 relations + # spans 0, 2, 4 are connected via "link" relation, so they should be merged + # spans 3, 5 are connected via "link" relation, but they do not have the same label, + # so they should not be merged. But the relation should be removed + # spans 0, 3 are connected via "relation_x" relation, its head should be remapped to the new span + spans = [ + LabeledSpan(start=0, end=1, label="label_a"), + LabeledSpan(start=2, end=3, label="other"), + LabeledSpan(start=4, end=5, label="label_a"), + LabeledSpan(start=6, end=7, label="label_b"), + LabeledSpan(start=8, end=9, label="label_a"), + LabeledSpan(start=10, end=11, label="label_c"), + ] + relations = [ + BinaryRelation(head=spans[0], tail=spans[2], label="link"), + BinaryRelation(head=spans[0], tail=spans[3], label="relation_x"), + BinaryRelation(head=spans[2], tail=spans[4], label="link"), + BinaryRelation(head=spans[3], tail=spans[5], label="link"), + ] + + merged_spans, merged_relations = _merge_spans_via_relation( + spans=spans, + relations=relations, + link_relation_label="link", + create_multi_spans=create_multi_spans, + ) + if create_multi_spans: + head = LabeledMultiSpan( + slices=( + (0, 1), + (4, 5), + (8, 9), + ), + label="label_a", + ) + tail = LabeledMultiSpan(slices=((6, 7),), label="label_b") + assert merged_spans == { + head, + LabeledMultiSpan(slices=((2, 3),), label="other"), + tail, + LabeledMultiSpan(slices=((10, 11),), label="label_c"), + } + else: + head = LabeledSpan(start=0, end=9, label="label_a") + tail = LabeledSpan(start=6, end=7, label="label_b") + assert merged_spans == { + head, + tail, + LabeledSpan(start=2, end=3, label="other"), + LabeledSpan(start=10, end=11, label="label_c"), + } + assert merged_relations == {BinaryRelation(head=head, tail=tail, label="relation_x")} From 5bee473e44d819d5ef45757ddd8cc613d9f93fe1 Mon Sep 17 00:00:00 2001 From: Arne Binder Date: Mon, 5 Feb 2024 13:20:15 +0100 Subject: [PATCH 4/8] allow use_predicted_spans to be True when process_predictions is False --- .../document/processing/merge_spans_via_relation.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/src/pie_modules/document/processing/merge_spans_via_relation.py b/src/pie_modules/document/processing/merge_spans_via_relation.py index 3bb93b13e..5dc8f0647 100644 --- a/src/pie_modules/document/processing/merge_spans_via_relation.py +++ b/src/pie_modules/document/processing/merge_spans_via_relation.py @@ -102,10 +102,11 @@ class SpansViaRelationMerger: create_multi_spans: Whether to create multi spans or not. If `True`, multi spans will be created, otherwise single spans that cover the merged spans will be created. - use_predicted_spans: Whether to use the predicted spans or the gold spans. process_predictions: Whether to process the predictions or not. If `True`, the predictions will be processed, otherwise only the gold annotations will be processed. + use_predicted_spans: Whether to use the predicted spans or the gold spans when + processing predictions. """ def __init__( @@ -115,8 +116,8 @@ def __init__( result_document_type: Union[type[Document], str], result_field_mapping: dict[str, str], create_multi_spans: bool = True, - use_predicted_spans: bool = False, process_predictions: bool = True, + use_predicted_spans: bool = True, ): self.relation_layer = relation_layer self.link_relation_label = link_relation_label @@ -146,8 +147,6 @@ def __call__(self, document: D) -> D: create_multi_spans=self.create_multi_spans, ) else: - if self.use_predicted_spans: - raise ValueError("cannot use predicted spans without processing predictions") new_pred_spans = set(spans.predictions.clear()) new_pred_relations = set(relations.predictions.clear()) From 834b97ce5dd0a3bb8535ff98272203c57cce6c2a Mon Sep 17 00:00:00 2001 From: Arne Binder Date: Mon, 5 Feb 2024 13:23:33 +0100 Subject: [PATCH 5/8] simplify: remove process_predictions parameter --- .../processing/merge_spans_via_relation.py | 29 +++++++------------ 1 file changed, 11 insertions(+), 18 deletions(-) diff --git a/src/pie_modules/document/processing/merge_spans_via_relation.py b/src/pie_modules/document/processing/merge_spans_via_relation.py index 5dc8f0647..c10aaf257 100644 --- a/src/pie_modules/document/processing/merge_spans_via_relation.py +++ b/src/pie_modules/document/processing/merge_spans_via_relation.py @@ -102,9 +102,6 @@ class SpansViaRelationMerger: create_multi_spans: Whether to create multi spans or not. If `True`, multi spans will be created, otherwise single spans that cover the merged spans will be created. - process_predictions: Whether to process the predictions or not. If `True`, the - predictions will be processed, otherwise only the gold annotations will be - processed. use_predicted_spans: Whether to use the predicted spans or the gold spans when processing predictions. """ @@ -116,7 +113,6 @@ def __init__( result_document_type: Union[type[Document], str], result_field_mapping: dict[str, str], create_multi_spans: bool = True, - process_predictions: bool = True, use_predicted_spans: bool = True, ): self.relation_layer = relation_layer @@ -127,28 +123,26 @@ def __init__( self.result_field_mapping = result_field_mapping self.create_multi_spans = create_multi_spans self.use_predicted_spans = use_predicted_spans - self.process_predictions = process_predictions def __call__(self, document: D) -> D: relations: AnnotationLayer[BinaryRelation] = document[self.relation_layer] spans: AnnotationLayer[LabeledSpan] = document[self.relation_layer].target_layer + # process gold annotations new_gold_spans, new_gold_relations = _merge_spans_via_relation( spans=spans, relations=relations, link_relation_label=self.link_relation_label, create_multi_spans=self.create_multi_spans, ) - if self.process_predictions: - new_pred_spans, new_pred_relations = _merge_spans_via_relation( - spans=spans.predictions if self.use_predicted_spans else spans, - relations=relations.predictions, - link_relation_label=self.link_relation_label, - create_multi_spans=self.create_multi_spans, - ) - else: - new_pred_spans = set(spans.predictions.clear()) - new_pred_relations = set(relations.predictions.clear()) + + # process predicted annotations + new_pred_spans, new_pred_relations = _merge_spans_via_relation( + spans=spans.predictions if self.use_predicted_spans else spans, + relations=relations.predictions, + link_relation_label=self.link_relation_label, + create_multi_spans=self.create_multi_spans, + ) result = document.copy(with_annotations=False).as_type(new_type=self.result_document_type) span_layer_name = document[self.relation_layer].target_name @@ -164,7 +158,6 @@ def __call__(self, document: D) -> D: if field_name not in [span_layer_name, self.relation_layer]: for ann in document[field_name]: result[result_field_name].append(ann.copy()) - if self.process_predictions: - for ann in document[field_name].predictions: - result[result_field_name].predictions.append(ann.copy()) + for ann in document[field_name].predictions: + result[result_field_name].predictions.append(ann.copy()) return result From 6def9755cbe3940de11942e3b33d7f30561f6692 Mon Sep 17 00:00:00 2001 From: Arne Binder Date: Mon, 5 Feb 2024 15:20:29 +0100 Subject: [PATCH 6/8] add test_spans_via_relation_merger() and variants --- .../processing/merge_spans_via_relation.py | 37 +++-- .../test_merge_spans_via_relation.py | 134 ++++++++++++++++++ 2 files changed, 161 insertions(+), 10 deletions(-) diff --git a/src/pie_modules/document/processing/merge_spans_via_relation.py b/src/pie_modules/document/processing/merge_spans_via_relation.py index c10aaf257..3ca60f2bb 100644 --- a/src/pie_modules/document/processing/merge_spans_via_relation.py +++ b/src/pie_modules/document/processing/merge_spans_via_relation.py @@ -1,5 +1,5 @@ import logging -from typing import Sequence, Set, Tuple, TypeVar, Union +from typing import Optional, Sequence, Set, Tuple, TypeVar, Union import networkx as nx from pytorch_ie import AnnotationLayer @@ -110,18 +110,31 @@ def __init__( self, relation_layer: str, link_relation_label: str, - result_document_type: Union[type[Document], str], - result_field_mapping: dict[str, str], + result_document_type: Optional[Union[type[Document], str]] = None, + result_field_mapping: Optional[dict[str, str]] = None, create_multi_spans: bool = True, use_predicted_spans: bool = True, ): self.relation_layer = relation_layer self.link_relation_label = link_relation_label - self.result_document_type = resolve_type( - result_document_type, expected_super_type=Document - ) - self.result_field_mapping = result_field_mapping self.create_multi_spans = create_multi_spans + if self.create_multi_spans: + if result_document_type is None: + raise ValueError( + "result_document_type must be set when create_multi_spans is True" + ) + self.result_document_type: Optional[type[Document]] + if result_document_type is not None: + if result_field_mapping is None: + raise ValueError( + "result_field_mapping must be set when result_document_type is provided" + ) + self.result_document_type = resolve_type( + result_document_type, expected_super_type=Document + ) + else: + self.result_document_type = None + self.result_field_mapping = result_field_mapping or {} self.use_predicted_spans = use_predicted_spans def __call__(self, document: D) -> D: @@ -144,10 +157,14 @@ def __call__(self, document: D) -> D: create_multi_spans=self.create_multi_spans, ) - result = document.copy(with_annotations=False).as_type(new_type=self.result_document_type) + result = document.copy(with_annotations=False) + if self.result_document_type is not None: + result = result.as_type(new_type=self.result_document_type) span_layer_name = document[self.relation_layer].target_name - result_span_layer_name = self.result_field_mapping[span_layer_name] - result_relation_layer_name = self.result_field_mapping[self.relation_layer] + result_span_layer_name = self.result_field_mapping.get(span_layer_name, span_layer_name) + result_relation_layer_name = self.result_field_mapping.get( + self.relation_layer, self.relation_layer + ) result[result_span_layer_name].extend(new_gold_spans) result[result_relation_layer_name].extend(new_gold_relations) result[result_span_layer_name].predictions.extend(new_pred_spans) diff --git a/tests/document/processing/test_merge_spans_via_relation.py b/tests/document/processing/test_merge_spans_via_relation.py index 09beb34f0..ec2ccd56d 100644 --- a/tests/document/processing/test_merge_spans_via_relation.py +++ b/tests/document/processing/test_merge_spans_via_relation.py @@ -1,9 +1,18 @@ import pytest +from pytorch_ie.documents import ( + TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions, +) from pie_modules.annotations import BinaryRelation, LabeledMultiSpan, LabeledSpan +from pie_modules.document.processing import SpansViaRelationMerger from pie_modules.document.processing.merge_spans_via_relation import ( _merge_spans_via_relation, ) +from pie_modules.documents import ( + TextDocumentWithLabeledMultiSpansAndBinaryRelations, + TextDocumentWithLabeledMultiSpansBinaryRelationsAndLabeledPartitions, + TextDocumentWithLabeledSpansAndBinaryRelations, +) @pytest.mark.parametrize( @@ -63,3 +72,128 @@ def test_merge_spans_via_relation(create_multi_spans: bool): LabeledSpan(start=10, end=11, label="label_c"), } assert merged_relations == {BinaryRelation(head=head, tail=tail, label="relation_x")} + + +def sort_spans(spans): + if len(spans) == 0: + return [] + if isinstance(spans[0], LabeledSpan): + return sorted(spans, key=lambda span: (span.start, span.end, span.label)) + else: + return sorted(spans, key=lambda span: (span.slices, span.label)) + + +def resolve_spans(spans): + if len(spans) == 0: + return [] + if isinstance(spans[0], LabeledSpan): + return [(span.target[span.start : span.end], span.label) for span in spans] + else: + return [ + (tuple(span.target[start:end] for start, end in span.slices), span.label) + for span in spans + ] + + +@pytest.mark.parametrize("create_multi_spans", [False, True]) +def test_spans_via_relation_merger(create_multi_spans): + doc = TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions( + text="This text, however, is about nothing (see here)." + ) + doc.labeled_partitions.append(LabeledSpan(start=0, end=48, label="sentence")) + assert str(doc.labeled_partitions[0]) == "This text, however, is about nothing (see here)." + doc.labeled_spans.extend( + [ + LabeledSpan(start=0, end=9, label="claim"), + LabeledSpan(start=11, end=18, label="other"), + LabeledSpan(start=20, end=36, label="claim"), + LabeledSpan(start=38, end=46, label="data"), + ] + ) + assert str(doc.labeled_spans[0]) == "This text" + assert str(doc.labeled_spans[1]) == "however" + assert str(doc.labeled_spans[2]) == "is about nothing" + assert str(doc.labeled_spans[3]) == "see here" + doc.binary_relations.extend( + [ + BinaryRelation(head=doc.labeled_spans[0], tail=doc.labeled_spans[2], label="link"), + BinaryRelation(head=doc.labeled_spans[3], tail=doc.labeled_spans[2], label="support"), + ] + ) + # after merging, that should be the same as in the gold data + doc.binary_relations.predictions.extend( + [ + BinaryRelation(head=doc.labeled_spans[0], tail=doc.labeled_spans[2], label="link"), + BinaryRelation(head=doc.labeled_spans[3], tail=doc.labeled_spans[0], label="support"), + ] + ) + + processor = SpansViaRelationMerger( + relation_layer="binary_relations", + link_relation_label="link", + use_predicted_spans=False, + create_multi_spans=create_multi_spans, + result_document_type=TextDocumentWithLabeledMultiSpansBinaryRelationsAndLabeledPartitions + if create_multi_spans + else None, + result_field_mapping={ + "labeled_spans": "labeled_multi_spans" if create_multi_spans else "labeled_spans", + "labeled_partitions": "labeled_partitions", + }, + ) + result = processor(doc) + if create_multi_spans: + assert isinstance( + result, TextDocumentWithLabeledMultiSpansBinaryRelationsAndLabeledPartitions + ) + sorted_spans = sort_spans(result.labeled_multi_spans) + sorted_spans_resolved = resolve_spans(sorted_spans) + assert sorted_spans_resolved == [ + (("This text", "is about nothing"), "claim"), + (("however",), "other"), + (("see here",), "data"), + ] + else: + assert isinstance(result, TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions) + sorted_spans = sort_spans(result.labeled_spans) + sorted_spans_resolved = resolve_spans(sorted_spans) + assert sorted_spans_resolved == [ + ("This text, however, is about nothing", "claim"), + ("however", "other"), + ("see here", "data"), + ] + # check gold and predicted relations + for relations in [result.binary_relations, result.binary_relations.predictions]: + assert len(relations) == 1 + assert relations[0].head == sorted_spans[2] + assert relations[0].tail == sorted_spans[0] + assert relations[0].label == "support" + + # check the labeled partitions + assert len(result.labeled_partitions) == 1 + assert str(result.labeled_partitions[0]) == "This text, however, is about nothing (see here)." + + +def test_spans_via_relation_merger_create_multi_span_missing_result_document_type(): + with pytest.raises(ValueError) as exc_info: + SpansViaRelationMerger( + relation_layer="binary_relations", + link_relation_label="link", + create_multi_spans=True, + ) + assert ( + str(exc_info.value) == "result_document_type must be set when create_multi_spans is True" + ) + + +def test_spans_via_relation_merger_with_result_document_type_missing_result_field_mapping(): + with pytest.raises(ValueError) as exc_info: + SpansViaRelationMerger( + relation_layer="binary_relations", + link_relation_label="link", + result_document_type=TextDocumentWithLabeledMultiSpansBinaryRelationsAndLabeledPartitions, + ) + assert ( + str(exc_info.value) + == "result_field_mapping must be set when result_document_type is provided" + ) From 01ba1091f0634f5a12f96deb22ca6bb624642b39 Mon Sep 17 00:00:00 2001 From: Arne Binder Date: Mon, 5 Feb 2024 15:24:33 +0100 Subject: [PATCH 7/8] improve docs --- .../processing/merge_spans_via_relation.py | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/src/pie_modules/document/processing/merge_spans_via_relation.py b/src/pie_modules/document/processing/merge_spans_via_relation.py index 3ca60f2bb..45c11582f 100644 --- a/src/pie_modules/document/processing/merge_spans_via_relation.py +++ b/src/pie_modules/document/processing/merge_spans_via_relation.py @@ -89,19 +89,22 @@ class SpansViaRelationMerger: This processor merges spans based on the relations with a specific label. The spans are merged into a single span if they are connected via a relation with the specified - label. The processor can be used to merge spans based on predicted or gold relations. + label. The processor handles both gold and predicted annotations. Args: relation_layer: The name of the relation layer in the document. link_relation_label: The label of the relation that should be used to merge spans. - result_document_type: The type of the document to return. This can be a class or - a string that can be resolved to a class. The class must be a subclass of - `Document`. - result_field_mapping: A mapping from the field names in the input document to the - field names in the result document. create_multi_spans: Whether to create multi spans or not. If `True`, multi spans will be created, otherwise single spans that cover the merged spans will be created. + result_document_type: The type of the document to return. This can be a class or + a string that can be resolved to a class. The class must be a subclass of + `Document`. Required when `create_multi_spans` is `True`. + result_field_mapping: A mapping from the field names in the input document to the + field names in the result document. This is used to copy over fields from the + input document to the result document. The keys are the field names in the + input document and the values are the field names in the result document. + Required when `result_document_type` is provided. use_predicted_spans: Whether to use the predicted spans or the gold spans when processing predictions. """ From 1ec012191d9fe385e1daf5388045585cf83b45f0 Mon Sep 17 00:00:00 2001 From: Arne Binder Date: Mon, 5 Feb 2024 15:41:26 +0100 Subject: [PATCH 8/8] improve docs --- .../document/processing/merge_spans_via_relation.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/pie_modules/document/processing/merge_spans_via_relation.py b/src/pie_modules/document/processing/merge_spans_via_relation.py index 45c11582f..efd8f72f9 100644 --- a/src/pie_modules/document/processing/merge_spans_via_relation.py +++ b/src/pie_modules/document/processing/merge_spans_via_relation.py @@ -87,9 +87,9 @@ def _merge_spans_via_relation( class SpansViaRelationMerger: """Merge spans based on relations. - This processor merges spans based on the relations with a specific label. The spans - are merged into a single span if they are connected via a relation with the specified - label. The processor handles both gold and predicted annotations. + This processor merges spans based on binary relations. The spans are merged into a + single span if they are connected via a relation with the specified link label. The + processor handles both gold and predicted annotations. Args: relation_layer: The name of the relation layer in the document.