-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
cleanup documents and annotations (#134)
* cleanup documents and annotations * remove conversion.py
- Loading branch information
1 parent
689fa3c
commit df0f950
Showing
3 changed files
with
48 additions
and
111 deletions.
There are no files selected for viewing
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,74 +1,42 @@ | ||
import dataclasses | ||
from typing import Optional | ||
|
||
from pytorch_ie.annotations import BinaryRelation, LabeledSpan, Span, _post_init_single_label | ||
from pytorch_ie.annotations import LabeledSpan | ||
from pytorch_ie.core import Annotation, AnnotationList, annotation_field | ||
from pytorch_ie.documents import TextBasedDocument, TokenBasedDocument | ||
from pytorch_ie.documents import TextBasedDocument, TextDocumentWithLabeledEntitiesAndRelations | ||
|
||
# =========================== Annotation Types ============================= # | ||
|
||
|
||
@dataclasses.dataclass(eq=True, frozen=True) | ||
class Attribute(Annotation): | ||
target_annotation: Annotation | ||
annotation: Annotation | ||
label: str | ||
value: Optional[str] = None | ||
score: float = 1.0 | ||
type: Optional[str] = None | ||
score: Optional[float] = dataclasses.field(default=None, compare=False) | ||
|
||
def __post_init__(self) -> None: | ||
_post_init_single_label(self) | ||
if not isinstance(self.label, str): | ||
raise ValueError("label must be a single string.") | ||
if not (self.score is None or isinstance(self.score, float)): | ||
raise ValueError("score must be a single float.") | ||
|
||
def __str__(self) -> str: | ||
if self.target is not None: | ||
result = f"label={self.label},annotation={self.annotation}" | ||
else: | ||
result = f"label={self.label}" | ||
if self.type is not None: | ||
result += f",type={self.type}" | ||
if self.score is not None: | ||
result += f",score={self.score}" | ||
return f"{self.__class__.__name__}({result})" | ||
|
||
|
||
# ============================= Document Types ============================= # | ||
|
||
|
||
@dataclasses.dataclass | ||
class TokenDocumentWithEntitiesAndRelations(TokenBasedDocument): | ||
entities: AnnotationList[Span] = annotation_field(target="tokens") | ||
relations: AnnotationList[BinaryRelation] = annotation_field(target="entities") | ||
|
||
|
||
@dataclasses.dataclass | ||
class TokenDocumentWithLabeledEntitiesAndRelations(TokenBasedDocument): | ||
entities: AnnotationList[LabeledSpan] = annotation_field(target="tokens") | ||
relations: AnnotationList[BinaryRelation] = annotation_field(target="entities") | ||
|
||
|
||
@dataclasses.dataclass | ||
class TextDocumentWithEntityMentions(TextBasedDocument): | ||
entity_mentions: AnnotationList[LabeledSpan] = annotation_field(target="text") | ||
|
||
|
||
@dataclasses.dataclass | ||
class TextDocumentWithEntitiesAndRelations(TextBasedDocument): | ||
"""Possible input class for TransformerRETextClassificationTaskModule.""" | ||
|
||
entities: AnnotationList[Span] = annotation_field(target="text") | ||
relations: AnnotationList[BinaryRelation] = annotation_field(target="entities") | ||
|
||
|
||
@dataclasses.dataclass | ||
class TextDocumentWithLabeledEntitiesAndRelations(TextBasedDocument): | ||
"""Possible input class for TransformerRETextClassificationTaskModule.""" | ||
|
||
class TextDocumentWithLabeledEntitiesAndEntityAttributes(TextBasedDocument): | ||
entities: AnnotationList[LabeledSpan] = annotation_field(target="text") | ||
relations: AnnotationList[BinaryRelation] = annotation_field(target="entities") | ||
|
||
|
||
@dataclasses.dataclass | ||
class DocumentWithEntitiesRelationsAndLabeledPartitions(TextBasedDocument): | ||
"""Possible input class for TransformerRETextClassificationTaskModule.""" | ||
|
||
entities: AnnotationList[LabeledSpan] = annotation_field(target="text") | ||
relations: AnnotationList[BinaryRelation] = annotation_field(target="entities") | ||
partitions: AnnotationList[LabeledSpan] = annotation_field(target="text") | ||
|
||
|
||
@dataclasses.dataclass | ||
class BratDocument(TextBasedDocument): | ||
"""Possible input class for TransformerRETextClassificationTaskModule.""" | ||
|
||
spans: AnnotationList[LabeledSpan] = annotation_field(target="text") | ||
relations: AnnotationList[BinaryRelation] = annotation_field(target="spans") | ||
span_attributions: AnnotationList[Attribute] = annotation_field(target="spans") | ||
relation_attributions: AnnotationList[Attribute] = annotation_field(target="relations") | ||
entity_attributes: AnnotationList[Attribute] = annotation_field(target="entities") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,26 @@ | ||
from pytorch_ie.annotations import LabeledSpan | ||
|
||
from src.document.types import Attribute, TextDocumentWithLabeledEntitiesAndEntityAttributes | ||
|
||
|
||
def test_attribute(): | ||
entity = LabeledSpan(start=0, end=1, label="PER") | ||
attribute = Attribute(annotation=entity, label="FACT") | ||
|
||
assert attribute.annotation == entity | ||
assert attribute.label == "FACT" | ||
assert attribute.type is None | ||
assert attribute.score is None | ||
|
||
assert str(attribute) == "Attribute(label=FACT)" | ||
|
||
|
||
def test_document(): | ||
doc = TextDocumentWithLabeledEntitiesAndEntityAttributes("He is really a person.") | ||
entity = LabeledSpan(start=0, end=2, label="PER") | ||
doc.entities.append(entity) | ||
attribute = Attribute(annotation=entity, label="FACT") | ||
doc.entity_attributes.append(attribute) | ||
|
||
assert str(doc.entities[0]) == "He" | ||
assert str(doc.entity_attributes[0]) == "Attribute(label=FACT,annotation=He)" |