Skip to content

Commit

Permalink
cleanup documents and annotations (#134)
Browse files Browse the repository at this point in the history
* cleanup documents and annotations

* remove conversion.py
  • Loading branch information
ArneBinder authored Sep 19, 2023
1 parent 689fa3c commit df0f950
Show file tree
Hide file tree
Showing 3 changed files with 48 additions and 111 deletions.
57 changes: 0 additions & 57 deletions src/document/conversion.py

This file was deleted.

76 changes: 22 additions & 54 deletions src/document/types.py
Original file line number Diff line number Diff line change
@@ -1,74 +1,42 @@
import dataclasses
from typing import Optional

from pytorch_ie.annotations import BinaryRelation, LabeledSpan, Span, _post_init_single_label
from pytorch_ie.annotations import LabeledSpan
from pytorch_ie.core import Annotation, AnnotationList, annotation_field
from pytorch_ie.documents import TextBasedDocument, TokenBasedDocument
from pytorch_ie.documents import TextBasedDocument, TextDocumentWithLabeledEntitiesAndRelations

# =========================== Annotation Types ============================= #


@dataclasses.dataclass(eq=True, frozen=True)
class Attribute(Annotation):
target_annotation: Annotation
annotation: Annotation
label: str
value: Optional[str] = None
score: float = 1.0
type: Optional[str] = None
score: Optional[float] = dataclasses.field(default=None, compare=False)

def __post_init__(self) -> None:
_post_init_single_label(self)
if not isinstance(self.label, str):
raise ValueError("label must be a single string.")
if not (self.score is None or isinstance(self.score, float)):
raise ValueError("score must be a single float.")

def __str__(self) -> str:
if self.target is not None:
result = f"label={self.label},annotation={self.annotation}"
else:
result = f"label={self.label}"
if self.type is not None:
result += f",type={self.type}"
if self.score is not None:
result += f",score={self.score}"
return f"{self.__class__.__name__}({result})"


# ============================= Document Types ============================= #


@dataclasses.dataclass
class TokenDocumentWithEntitiesAndRelations(TokenBasedDocument):
entities: AnnotationList[Span] = annotation_field(target="tokens")
relations: AnnotationList[BinaryRelation] = annotation_field(target="entities")


@dataclasses.dataclass
class TokenDocumentWithLabeledEntitiesAndRelations(TokenBasedDocument):
entities: AnnotationList[LabeledSpan] = annotation_field(target="tokens")
relations: AnnotationList[BinaryRelation] = annotation_field(target="entities")


@dataclasses.dataclass
class TextDocumentWithEntityMentions(TextBasedDocument):
entity_mentions: AnnotationList[LabeledSpan] = annotation_field(target="text")


@dataclasses.dataclass
class TextDocumentWithEntitiesAndRelations(TextBasedDocument):
"""Possible input class for TransformerRETextClassificationTaskModule."""

entities: AnnotationList[Span] = annotation_field(target="text")
relations: AnnotationList[BinaryRelation] = annotation_field(target="entities")


@dataclasses.dataclass
class TextDocumentWithLabeledEntitiesAndRelations(TextBasedDocument):
"""Possible input class for TransformerRETextClassificationTaskModule."""

class TextDocumentWithLabeledEntitiesAndEntityAttributes(TextBasedDocument):
entities: AnnotationList[LabeledSpan] = annotation_field(target="text")
relations: AnnotationList[BinaryRelation] = annotation_field(target="entities")


@dataclasses.dataclass
class DocumentWithEntitiesRelationsAndLabeledPartitions(TextBasedDocument):
"""Possible input class for TransformerRETextClassificationTaskModule."""

entities: AnnotationList[LabeledSpan] = annotation_field(target="text")
relations: AnnotationList[BinaryRelation] = annotation_field(target="entities")
partitions: AnnotationList[LabeledSpan] = annotation_field(target="text")


@dataclasses.dataclass
class BratDocument(TextBasedDocument):
"""Possible input class for TransformerRETextClassificationTaskModule."""

spans: AnnotationList[LabeledSpan] = annotation_field(target="text")
relations: AnnotationList[BinaryRelation] = annotation_field(target="spans")
span_attributions: AnnotationList[Attribute] = annotation_field(target="spans")
relation_attributions: AnnotationList[Attribute] = annotation_field(target="relations")
entity_attributes: AnnotationList[Attribute] = annotation_field(target="entities")
26 changes: 26 additions & 0 deletions tests/unit/document/test_types.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
from pytorch_ie.annotations import LabeledSpan

from src.document.types import Attribute, TextDocumentWithLabeledEntitiesAndEntityAttributes


def test_attribute():
entity = LabeledSpan(start=0, end=1, label="PER")
attribute = Attribute(annotation=entity, label="FACT")

assert attribute.annotation == entity
assert attribute.label == "FACT"
assert attribute.type is None
assert attribute.score is None

assert str(attribute) == "Attribute(label=FACT)"


def test_document():
doc = TextDocumentWithLabeledEntitiesAndEntityAttributes("He is really a person.")
entity = LabeledSpan(start=0, end=2, label="PER")
doc.entities.append(entity)
attribute = Attribute(annotation=entity, label="FACT")
doc.entity_attributes.append(attribute)

assert str(doc.entities[0]) == "He"
assert str(doc.entity_attributes[0]) == "Attribute(label=FACT,annotation=He)"

0 comments on commit df0f950

Please sign in to comment.