Skip to content

Commit

Permalink
implement target getters for BaseAnnotationList (#297)
Browse files Browse the repository at this point in the history
* implement target_layers and target_fields properties for BaseAnnotationList

* rename target_fields to targets; remove hasattr check; add target and target_layer

* add tests for targets, target, target_layers, and target_layer
  • Loading branch information
ArneBinder authored Jul 28, 2023
1 parent 92c47ca commit 3dcbb4c
Show file tree
Hide file tree
Showing 2 changed files with 86 additions and 0 deletions.
33 changes: 33 additions & 0 deletions src/pytorch_ie/core/document.py
Original file line number Diff line number Diff line change
Expand Up @@ -351,6 +351,39 @@ def pop(self, index: int = -1) -> T:
ann.set_targets(None)
return ann

@property
def targets(self) -> dict[str, Any]:
return {
target_field_name: getattr(self._document, target_field_name)
for target_field_name in self._targets
}

@property
def target(self) -> Any:
tgts = self.targets
if len(tgts) != 1:
raise ValueError(
f"The annotation layer has more or less than one target: {self._targets}"
)
return list(tgts.values())[0]

@property
def target_layers(self) -> dict[str, "AnnotationList"]:
return {
target_name: target
for target_name, target in self.targets.items()
if isinstance(target, AnnotationList)
}

@property
def target_layer(self) -> "AnnotationList":
tgt_layers = self.target_layers
if len(tgt_layers) != 1:
raise ValueError(
f"The annotation layer has more or less than one target layer: {list(tgt_layers.keys())}"
)
return list(tgt_layers.values())[0]


class AnnotationList(BaseAnnotationList[T]):
def __init__(self, document: "Document", targets: List["str"]):
Expand Down
53 changes: 53 additions & 0 deletions tests/test_document.py
Original file line number Diff line number Diff line change
Expand Up @@ -504,3 +504,56 @@ class TestDocument(Document):
),
):
doc = TestDocument(text="text1")


def test_annotation_list_targets():
@dataclasses.dataclass
class TestDocument(Document):
text: str
entities1: AnnotationList[LabeledSpan] = annotation_field(target="text")
entities2: AnnotationList[LabeledSpan] = annotation_field(target="text")
relations1: AnnotationList[BinaryRelation] = annotation_field(target="entities1")
relations2: AnnotationList[BinaryRelation] = annotation_field(
targets=["entities1", "entities2"]
)

doc = TestDocument(text="text1")

# test getting all targets
assert doc.entities1.targets == {"text": doc.text}
assert doc.entities2.targets == {"text": doc.text}
assert doc.relations1.targets == {"entities1": doc.entities1}
assert doc.relations2.targets == {"entities1": doc.entities1, "entities2": doc.entities2}

# test getting a single target
assert doc.entities1.target == doc.text
assert doc.entities2.target == doc.text
assert doc.relations1.target == doc.entities1
# check that the target of relations2 is not set because it has more than one target
with pytest.raises(ValueError) as excinfo:
doc.relations2.target
assert (
str(excinfo.value)
== "The annotation layer has more or less than one target: ['entities1', 'entities2']"
)

# test getting all target layers
assert doc.entities1.target_layers == {}
assert doc.entities2.target_layers == {}
assert doc.relations1.target_layers == {"entities1": doc.entities1}
assert doc.relations2.target_layers == {"entities1": doc.entities1, "entities2": doc.entities2}

# test getting a single target layer
with pytest.raises(ValueError) as excinfo:
doc.entities1.target_layer
assert str(excinfo.value) == "The annotation layer has more or less than one target layer: []"
with pytest.raises(ValueError) as excinfo:
doc.entities2.target_layer
assert str(excinfo.value) == "The annotation layer has more or less than one target layer: []"
assert doc.relations1.target_layer == doc.entities1
with pytest.raises(ValueError) as excinfo:
doc.relations2.target_layer
assert (
str(excinfo.value)
== "The annotation layer has more or less than one target layer: ['entities1', 'entities2']"
)

0 comments on commit 3dcbb4c

Please sign in to comment.