From df0f95013c37374d82b5d01d4e64795eda7deb82 Mon Sep 17 00:00:00 2001 From: ArneBinder Date: Tue, 19 Sep 2023 16:47:49 +0200 Subject: [PATCH] cleanup documents and annotations (#134) * cleanup documents and annotations * remove conversion.py --- src/document/conversion.py | 57 ----------------------- src/document/types.py | 76 +++++++++---------------------- tests/unit/document/test_types.py | 26 +++++++++++ 3 files changed, 48 insertions(+), 111 deletions(-) delete mode 100644 src/document/conversion.py create mode 100644 tests/unit/document/test_types.py diff --git a/src/document/conversion.py b/src/document/conversion.py deleted file mode 100644 index 0fcfc49..0000000 --- a/src/document/conversion.py +++ /dev/null @@ -1,57 +0,0 @@ -from __future__ import annotations - -import logging -from copy import deepcopy -from typing import Dict, List, Tuple, Union - -from pytorch_ie.annotations import LabeledSpan, Span - -from src.document.types import ( - TextDocumentWithEntitiesAndRelations, - TextDocumentWithLabeledEntitiesAndRelations, - TokenDocumentWithEntitiesAndRelations, -) - -logger = logging.getLogger(__name__) - - -def token_based_document_with_entities_and_relations_to_text_based( - doc: TokenDocumentWithEntitiesAndRelations, - token_field: str = "tokens", - entity_layer: str = "entities", - token_separator: str = " ", -) -> Union[TextDocumentWithEntitiesAndRelations, TextDocumentWithLabeledEntitiesAndRelations]: - start = 0 - token_offsets: List[Tuple[int, int]] = [] - tokens = getattr(doc, token_field) - for token in tokens: - end = start + len(token) - token_offsets.append((start, end)) - # we add the separator after each token - start = end + len(token_separator) - - text = token_separator.join([token for token in tokens]) - - entity_map: Dict[int, Span] = {} - entities_have_labels = False - for entity in doc[entity_layer]: - char_start = token_offsets[entity.start][0] - char_end = token_offsets[entity.end - 1][1] - char_offset_entity = entity.copy(start=char_start, end=char_end) - if isinstance(entity, LabeledSpan): - entities_have_labels = True - entity_map[entity._id] = char_offset_entity - - if entities_have_labels: - new_doc = TextDocumentWithLabeledEntitiesAndRelations( - text=text, id=doc.id, metadata=deepcopy(doc.metadata) - ) - else: - new_doc = TextDocumentWithEntitiesAndRelations( - text=text, id=doc.id, metadata=deepcopy(doc.metadata) - ) - new_doc.entities.extend(entity_map.values()) - new_doc.add_all_annotations_from_other( - doc, override_annotation_mapping={"entities": entity_map} - ) - return new_doc diff --git a/src/document/types.py b/src/document/types.py index 78c62e1..5d19fbb 100644 --- a/src/document/types.py +++ b/src/document/types.py @@ -1,74 +1,42 @@ import dataclasses from typing import Optional -from pytorch_ie.annotations import BinaryRelation, LabeledSpan, Span, _post_init_single_label +from pytorch_ie.annotations import LabeledSpan from pytorch_ie.core import Annotation, AnnotationList, annotation_field -from pytorch_ie.documents import TextBasedDocument, TokenBasedDocument +from pytorch_ie.documents import TextBasedDocument, TextDocumentWithLabeledEntitiesAndRelations # =========================== Annotation Types ============================= # @dataclasses.dataclass(eq=True, frozen=True) class Attribute(Annotation): - target_annotation: Annotation + annotation: Annotation label: str - value: Optional[str] = None - score: float = 1.0 + type: Optional[str] = None + score: Optional[float] = dataclasses.field(default=None, compare=False) def __post_init__(self) -> None: - _post_init_single_label(self) + if not isinstance(self.label, str): + raise ValueError("label must be a single string.") + if not (self.score is None or isinstance(self.score, float)): + raise ValueError("score must be a single float.") + + def __str__(self) -> str: + if self.target is not None: + result = f"label={self.label},annotation={self.annotation}" + else: + result = f"label={self.label}" + if self.type is not None: + result += f",type={self.type}" + if self.score is not None: + result += f",score={self.score}" + return f"{self.__class__.__name__}({result})" # ============================= Document Types ============================= # @dataclasses.dataclass -class TokenDocumentWithEntitiesAndRelations(TokenBasedDocument): - entities: AnnotationList[Span] = annotation_field(target="tokens") - relations: AnnotationList[BinaryRelation] = annotation_field(target="entities") - - -@dataclasses.dataclass -class TokenDocumentWithLabeledEntitiesAndRelations(TokenBasedDocument): - entities: AnnotationList[LabeledSpan] = annotation_field(target="tokens") - relations: AnnotationList[BinaryRelation] = annotation_field(target="entities") - - -@dataclasses.dataclass -class TextDocumentWithEntityMentions(TextBasedDocument): - entity_mentions: AnnotationList[LabeledSpan] = annotation_field(target="text") - - -@dataclasses.dataclass -class TextDocumentWithEntitiesAndRelations(TextBasedDocument): - """Possible input class for TransformerRETextClassificationTaskModule.""" - - entities: AnnotationList[Span] = annotation_field(target="text") - relations: AnnotationList[BinaryRelation] = annotation_field(target="entities") - - -@dataclasses.dataclass -class TextDocumentWithLabeledEntitiesAndRelations(TextBasedDocument): - """Possible input class for TransformerRETextClassificationTaskModule.""" - +class TextDocumentWithLabeledEntitiesAndEntityAttributes(TextBasedDocument): entities: AnnotationList[LabeledSpan] = annotation_field(target="text") - relations: AnnotationList[BinaryRelation] = annotation_field(target="entities") - - -@dataclasses.dataclass -class DocumentWithEntitiesRelationsAndLabeledPartitions(TextBasedDocument): - """Possible input class for TransformerRETextClassificationTaskModule.""" - - entities: AnnotationList[LabeledSpan] = annotation_field(target="text") - relations: AnnotationList[BinaryRelation] = annotation_field(target="entities") - partitions: AnnotationList[LabeledSpan] = annotation_field(target="text") - - -@dataclasses.dataclass -class BratDocument(TextBasedDocument): - """Possible input class for TransformerRETextClassificationTaskModule.""" - - spans: AnnotationList[LabeledSpan] = annotation_field(target="text") - relations: AnnotationList[BinaryRelation] = annotation_field(target="spans") - span_attributions: AnnotationList[Attribute] = annotation_field(target="spans") - relation_attributions: AnnotationList[Attribute] = annotation_field(target="relations") + entity_attributes: AnnotationList[Attribute] = annotation_field(target="entities") diff --git a/tests/unit/document/test_types.py b/tests/unit/document/test_types.py new file mode 100644 index 0000000..e1086ab --- /dev/null +++ b/tests/unit/document/test_types.py @@ -0,0 +1,26 @@ +from pytorch_ie.annotations import LabeledSpan + +from src.document.types import Attribute, TextDocumentWithLabeledEntitiesAndEntityAttributes + + +def test_attribute(): + entity = LabeledSpan(start=0, end=1, label="PER") + attribute = Attribute(annotation=entity, label="FACT") + + assert attribute.annotation == entity + assert attribute.label == "FACT" + assert attribute.type is None + assert attribute.score is None + + assert str(attribute) == "Attribute(label=FACT)" + + +def test_document(): + doc = TextDocumentWithLabeledEntitiesAndEntityAttributes("He is really a person.") + entity = LabeledSpan(start=0, end=2, label="PER") + doc.entities.append(entity) + attribute = Attribute(annotation=entity, label="FACT") + doc.entity_attributes.append(attribute) + + assert str(doc.entities[0]) == "He" + assert str(doc.entity_attributes[0]) == "Attribute(label=FACT,annotation=He)"