Merge pull request #68 from ArneBinder/span_coverage_statistic
implement `SpanCoverageCollector` statistic
ArneBinder authored Feb 27, 2024
2 parents e932a37 + ebdee69 commit d87c82e
Showing 3 changed files with 261 additions and 0 deletions.
1 change: 1 addition & 0 deletions src/pie_modules/metrics/__init__.py
@@ -6,5 +6,6 @@
TokenCountCollector,
)

from .span_coverage_collector import SpanCoverageCollector
from .span_length_collector import SpanLengthCollector
from .squad_f1 import SQuADF1
105 changes: 105 additions & 0 deletions src/pie_modules/metrics/span_coverage_collector.py
@@ -0,0 +1,105 @@
import logging
from typing import Any, Dict, List, Optional, Set, Type, Union

from pytorch_ie.annotations import Span
from pytorch_ie.core import Document, DocumentStatistic
from pytorch_ie.documents import TextBasedDocument, TokenBasedDocument
from transformers import AutoTokenizer, PreTrainedTokenizer

from pie_modules.annotations import LabeledMultiSpan
from pie_modules.document.processing import tokenize_document
from pie_modules.utils import resolve_type

logger = logging.getLogger(__name__)


class SpanCoverageCollector(DocumentStatistic):
"""Collects the coverage of Span annotations. It can handle overlapping spans.
If a tokenizer is provided, the span coverage is calculated in means of tokens, otherwise in
means of characters.
Args:
layer: The annotation layer of the document to calculate the span coverage for.
tokenize: Whether to tokenize the document before calculating the span coverage. Default is False.
tokenizer: The tokenizer to use for tokenization. Should be a PreTrainedTokenizer or a string
representing the name of a pre-trained tokenizer, e.g. "bert-base-uncased". Required if
tokenize is True.
tokenized_document_type: The type of the tokenized document or a string that can be resolved
to such a type. Required if tokenize is True.
tokenize_kwargs: Additional keyword arguments for the tokenization.
labels: If provided, only spans with these labels are considered.
label_attribute: The attribute of the span to consider as label. Default is "label".
"""

DEFAULT_AGGREGATION_FUNCTIONS = ["len", "mean", "std", "min", "max"]

def __init__(
self,
layer: str,
tokenize: bool = False,
tokenizer: Optional[Union[str, PreTrainedTokenizer]] = None,
tokenized_document_type: Optional[Union[str, Type[TokenBasedDocument]]] = None,
labels: Optional[Union[List[str], str]] = None,
label_attribute: str = "label",
tokenize_kwargs: Optional[Dict[str, Any]] = None,
**kwargs,
):
super().__init__(**kwargs)
self.layer = layer
self.labels = labels
self.label_field = label_attribute
self.tokenize = tokenize
if self.tokenize:
if tokenizer is None:
raise ValueError(
"tokenizer must be provided to calculate the span coverage in means of tokens"
)
if isinstance(tokenizer, str):
tokenizer = AutoTokenizer.from_pretrained(tokenizer)
self.tokenizer = tokenizer
if tokenized_document_type is None:
raise ValueError(
"tokenized_document_type must be provided to calculate the span coverage in means of tokens"
)
self.tokenized_document_type = resolve_type(
tokenized_document_type, expected_super_type=TokenBasedDocument
)
self.tokenize_kwargs = tokenize_kwargs or {}

def _collect(self, doc: Document) -> float:
docs: Union[List[Document], List[TokenBasedDocument]]
if self.tokenize:
if not isinstance(doc, TextBasedDocument):
raise ValueError(
"doc must be a TextBasedDocument to calculate the span coverage in means of tokens"
)
docs = tokenize_document(
doc,
tokenizer=self.tokenizer,
result_document_type=self.tokenized_document_type,
**self.tokenize_kwargs,
)
if len(docs) != 1:
raise ValueError(
"tokenization of a single document must result in a single document to calculate the "
"span coverage correctly. Please check your tokenization settings, especially that "
"no windowing is applied because of max input length restrictions."
)
doc = docs[0]

layer_obj = getattr(doc, self.layer)
target = layer_obj.target
covered_indices: Set[int] = set()
for span in layer_obj:
if self.labels is not None and getattr(span, self.label_field) not in self.labels:
continue
if isinstance(span, Span):
covered_indices.update(range(span.start, span.end))
elif isinstance(span, LabeledMultiSpan):
for start, end in span.slices:
covered_indices.update(range(start, end))
else:
raise TypeError(f"span coverage calculation is not yet supported for {type(span)}")

return len(covered_indices) / len(target)
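For orientation (not part of the diff above), here is a minimal usage sketch of the new statistic in its character-based mode, mirroring the first test below; the document type `ExampleDocument` is illustrative. With two single-character spans over the 8-character text "A and O.", the coverage is 2/8 = 0.25.

import dataclasses

from pytorch_ie.annotations import LabeledSpan
from pytorch_ie.core import AnnotationList, annotation_field
from pytorch_ie.documents import TextBasedDocument

from pie_modules.metrics import SpanCoverageCollector


# Illustrative document type with a single span layer targeting the text.
@dataclasses.dataclass
class ExampleDocument(TextBasedDocument):
    entities: AnnotationList[LabeledSpan] = annotation_field(target="text")


doc = ExampleDocument(text="A and O.")
doc.entities.append(LabeledSpan(start=0, end=1, label="entity"))  # covers "A"
doc.entities.append(LabeledSpan(start=6, end=7, label="entity"))  # covers "O"

# Character-based coverage: 2 of 8 characters are covered -> 0.25.
statistic = SpanCoverageCollector(layer="entities")
assert statistic(doc) == {"len": 1, "mean": 0.25, "std": 0.0, "min": 0.25, "max": 0.25}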
155 changes: 155 additions & 0 deletions tests/metrics/test_span_coverage_collector.py
@@ -0,0 +1,155 @@
import dataclasses

import pytest
from pytorch_ie import Annotation, Document
from pytorch_ie.annotations import LabeledSpan
from pytorch_ie.core import AnnotationList, annotation_field
from pytorch_ie.documents import TextBasedDocument, TokenBasedDocument

from pie_modules.annotations import LabeledMultiSpan
from pie_modules.metrics import SpanCoverageCollector


@dataclasses.dataclass
class TestDocument(TextBasedDocument):
entities: AnnotationList[LabeledSpan] = annotation_field(target="text")


def test_span_coverage_collector():
doc = TestDocument(text="A and O.")
doc.entities.append(LabeledSpan(start=0, end=1, label="entity"))
doc.entities.append(LabeledSpan(start=6, end=7, label="entity"))

statistic = SpanCoverageCollector(layer="entities")
values = statistic(doc)
assert values == {"len": 1, "max": 0.25, "mean": 0.25, "min": 0.25, "std": 0.0}


def test_span_coverage_collector_with_multi_span():
@dataclasses.dataclass
class TestDocument(TextBasedDocument):
entities: AnnotationList[LabeledMultiSpan] = annotation_field(target="text")

doc = TestDocument(text="A and O.")
doc.entities.append(LabeledMultiSpan(slices=((0, 1),), label="entity"))
doc.entities.append(LabeledMultiSpan(slices=((6, 7),), label="entity"))

statistic = SpanCoverageCollector(layer="entities")
values = statistic(doc)
assert values == {
"len": 1,
"max": 0.25,
"mean": 0.25,
"min": 0.25,
"std": 0.0,
}


def test_span_coverage_collector_with_labels():
doc = TestDocument(text="A and O.")
doc.entities.append(LabeledSpan(start=0, end=1, label="entity"))
doc.entities.append(LabeledSpan(start=6, end=7, label="no_entity"))

statistic = SpanCoverageCollector(layer="entities", labels=["entity"])
values = statistic(doc)
assert values == {"len": 1, "max": 0.125, "mean": 0.125, "min": 0.125, "std": 0.0}


def test_span_coverage_collector_with_tokenize():
doc = TestDocument(text="A and O.")
doc.entities.append(LabeledSpan(start=0, end=1, label="entity"))
doc.entities.append(LabeledSpan(start=6, end=7, label="entity"))

@dataclasses.dataclass
class TokenizedTestDocument(TokenBasedDocument):
entities: AnnotationList[LabeledSpan] = annotation_field(target="tokens")

statistic = SpanCoverageCollector(
layer="entities",
tokenize=True,
tokenizer="bert-base-uncased",
tokenized_document_type=TokenizedTestDocument,
)
values = statistic(doc)
assert values == {
"len": 1,
"max": 0.3333333333333333,
"mean": 0.3333333333333333,
"min": 0.3333333333333333,
"std": 0.0,
}


def test_span_coverage_collector_with_tokenize_missing_tokenizer():
with pytest.raises(ValueError) as excinfo:
SpanCoverageCollector(
layer="entities",
tokenize=True,
tokenized_document_type=TokenBasedDocument,
)
assert (
str(excinfo.value)
== "tokenizer must be provided to calculate the span coverage in means of tokens"
)


def test_span_coverage_collector_with_tokenize_missing_tokenized_document_type():
with pytest.raises(ValueError) as excinfo:
SpanCoverageCollector(
layer="entities",
tokenize=True,
tokenizer="bert-base-uncased",
)
assert (
str(excinfo.value)
== "tokenized_document_type must be provided to calculate the span coverage in means of tokens"
)


def test_span_coverage_collector_with_tokenize_wrong_document_type():
@dataclasses.dataclass
class TestDocument(Document):
data: str
entities: AnnotationList[LabeledSpan] = annotation_field(target="data")

doc = TestDocument(data="A and O")

@dataclasses.dataclass
class TokenizedTestDocument(TokenBasedDocument):
entities: AnnotationList[LabeledSpan] = annotation_field(target="tokens")

statistic = SpanCoverageCollector(
layer="entities",
tokenize=True,
tokenizer="bert-base-uncased",
tokenized_document_type=TokenizedTestDocument,
)

with pytest.raises(ValueError) as excinfo:
statistic(doc)
assert (
str(excinfo.value)
== "doc must be a TextBasedDocument to calculate the span coverage in means of tokens"
)


def test_span_coverage_collector_with_tokenize_wrong_annotation_type():
@dataclasses.dataclass(eq=True, frozen=True)
class UnknownSpan(Annotation):
start: int
end: int

@dataclasses.dataclass
class TestDocument(TextBasedDocument):
labeled_multi_spans: AnnotationList[UnknownSpan] = annotation_field(target="text")

doc = TestDocument(text="First sentence. Entity M works at N. And it founded O.")
doc.labeled_multi_spans.append(UnknownSpan(start=16, end=24))

statistic = SpanCoverageCollector(layer="labeled_multi_spans")

with pytest.raises(TypeError) as excinfo:
statistic(doc)
assert (
str(excinfo.value) == f"span coverage calculation is not yet supported for {UnknownSpan}"
)
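A side note on the 0.333… value asserted in test_span_coverage_collector_with_tokenize above: bert-base-uncased presumably splits "A and O." into six tokens including the special [CLS] and [SEP] tokens, and the two entity spans cover the tokens "a" and "o", giving 2/6 ≈ 0.333. A quick way to check that assumption:

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
encoding = tokenizer("A and O.")
# Expected: ['[CLS]', 'a', 'and', 'o', '.', '[SEP]'] -> 6 tokens, 2 of them covered by spans.
print(tokenizer.convert_ids_to_tokens(encoding["input_ids"]))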
