From 3dcbb4c75acd44d96320eac480c0704a4bfb3fbf Mon Sep 17 00:00:00 2001
From: ArneBinder
Date: Fri, 28 Jul 2023 15:14:24 +0200
Subject: [PATCH] implement target getters for BaseAnnotationList (#297)

* implement target_layers and target_fields properties for BaseAnnotationList

* rename target_fields to targets; remove hasattr check; add target and target_layer

* add tests for targets, target, target_layers, and target_layer
---
 src/pytorch_ie/core/document.py | 33 ++++++++++++++++++++
 tests/test_document.py          | 53 +++++++++++++++++++++++++++++++++
 2 files changed, 86 insertions(+)

diff --git a/src/pytorch_ie/core/document.py b/src/pytorch_ie/core/document.py
index 4c61a416..11e144a2 100644
--- a/src/pytorch_ie/core/document.py
+++ b/src/pytorch_ie/core/document.py
@@ -351,6 +351,39 @@ def pop(self, index: int = -1) -> T:
         ann.set_targets(None)
         return ann
 
+    @property
+    def targets(self) -> dict[str, Any]:
+        return {
+            target_field_name: getattr(self._document, target_field_name)
+            for target_field_name in self._targets
+        }
+
+    @property
+    def target(self) -> Any:
+        tgts = self.targets
+        if len(tgts) != 1:
+            raise ValueError(
+                f"The annotation layer has more or less than one target: {self._targets}"
+            )
+        return list(tgts.values())[0]
+
+    @property
+    def target_layers(self) -> dict[str, "AnnotationList"]:
+        return {
+            target_name: target
+            for target_name, target in self.targets.items()
+            if isinstance(target, AnnotationList)
+        }
+
+    @property
+    def target_layer(self) -> "AnnotationList":
+        tgt_layers = self.target_layers
+        if len(tgt_layers) != 1:
+            raise ValueError(
+                f"The annotation layer has more or less than one target layer: {list(tgt_layers.keys())}"
+            )
+        return list(tgt_layers.values())[0]
+
 
 class AnnotationList(BaseAnnotationList[T]):
     def __init__(self, document: "Document", targets: List["str"]):
diff --git a/tests/test_document.py b/tests/test_document.py
index 50f6d99a..18cf1885 100644
--- a/tests/test_document.py
+++ b/tests/test_document.py
@@ -504,3 +504,56 @@ class TestDocument(Document):
         ),
     ):
         doc = TestDocument(text="text1")
+
+
+def test_annotation_list_targets():
+    @dataclasses.dataclass
+    class TestDocument(Document):
+        text: str
+        entities1: AnnotationList[LabeledSpan] = annotation_field(target="text")
+        entities2: AnnotationList[LabeledSpan] = annotation_field(target="text")
+        relations1: AnnotationList[BinaryRelation] = annotation_field(target="entities1")
+        relations2: AnnotationList[BinaryRelation] = annotation_field(
+            targets=["entities1", "entities2"]
+        )
+
+    doc = TestDocument(text="text1")
+
+    # test getting all targets
+    assert doc.entities1.targets == {"text": doc.text}
+    assert doc.entities2.targets == {"text": doc.text}
+    assert doc.relations1.targets == {"entities1": doc.entities1}
+    assert doc.relations2.targets == {"entities1": doc.entities1, "entities2": doc.entities2}
+
+    # test getting a single target
+    assert doc.entities1.target == doc.text
+    assert doc.entities2.target == doc.text
+    assert doc.relations1.target == doc.entities1
+    # check that the target of relations2 is not set because it has more than one target
+    with pytest.raises(ValueError) as excinfo:
+        doc.relations2.target
+    assert (
+        str(excinfo.value)
+        == "The annotation layer has more or less than one target: ['entities1', 'entities2']"
+    )
+
+    # test getting all target layers
+    assert doc.entities1.target_layers == {}
+    assert doc.entities2.target_layers == {}
+    assert doc.relations1.target_layers == {"entities1": doc.entities1}
+    assert doc.relations2.target_layers == {"entities1": doc.entities1, "entities2": doc.entities2}
+
+    # test getting a single target layer
+    with pytest.raises(ValueError) as excinfo:
+        doc.entities1.target_layer
+    assert str(excinfo.value) == "The annotation layer has more or less than one target layer: []"
+    with pytest.raises(ValueError) as excinfo:
+        doc.entities2.target_layer
+    assert str(excinfo.value) == "The annotation layer has more or less than one target layer: []"
+    assert doc.relations1.target_layer == doc.entities1
+    with pytest.raises(ValueError) as excinfo:
+        doc.relations2.target_layer
+    assert (
+        str(excinfo.value)
+        == "The annotation layer has more or less than one target layer: ['entities1', 'entities2']"
+    )