diff --git a/src/pie_modules/metrics/__init__.py b/src/pie_modules/metrics/__init__.py index 9cf45196f..3468c807c 100644 --- a/src/pie_modules/metrics/__init__.py +++ b/src/pie_modules/metrics/__init__.py @@ -6,5 +6,6 @@ TokenCountCollector, ) +from .span_coverage_collector import SpanCoverageCollector from .span_length_collector import SpanLengthCollector from .squad_f1 import SQuADF1 diff --git a/src/pie_modules/metrics/span_coverage_collector.py b/src/pie_modules/metrics/span_coverage_collector.py new file mode 100644 index 000000000..72e495565 --- /dev/null +++ b/src/pie_modules/metrics/span_coverage_collector.py @@ -0,0 +1,105 @@ +import logging +from typing import Any, Dict, List, Optional, Set, Type, Union + +from pytorch_ie.annotations import Span +from pytorch_ie.core import Document, DocumentStatistic +from pytorch_ie.documents import TextBasedDocument, TokenBasedDocument +from transformers import AutoTokenizer, PreTrainedTokenizer + +from pie_modules.annotations import LabeledMultiSpan +from pie_modules.document.processing import tokenize_document +from pie_modules.utils import resolve_type + +logger = logging.getLogger(__name__) + + +class SpanCoverageCollector(DocumentStatistic): + """Collects the coverage of Span annotations. It can handle overlapping spans. + + If a tokenizer is provided, the span coverage is calculated in means of tokens, otherwise in + means of characters. + + Args: + layer: The annotation layer of the document to calculate the span coverage for. + tokenize: Whether to tokenize the document before calculating the span coverage. Default is False. + tokenizer: The tokenizer to use for tokenization. Should be a PreTrainedTokenizer or a string + representing the name of a pre-trained tokenizer, e.g. "bert-base-uncased". Required if + tokenize is True. + tokenized_document_type: The type of the tokenized document or a string that can be resolved + to such a type. Required if tokenize is True. + tokenize_kwargs: Additional keyword arguments for the tokenization. + labels: If provided, only spans with these labels are considered. + label_attribute: The attribute of the span to consider as label. Default is "label". + """ + + DEFAULT_AGGREGATION_FUNCTIONS = ["len", "mean", "std", "min", "max"] + + def __init__( + self, + layer: str, + tokenize: bool = False, + tokenizer: Optional[Union[str, PreTrainedTokenizer]] = None, + tokenized_document_type: Optional[Union[str, Type[TokenBasedDocument]]] = None, + labels: Optional[Union[List[str], str]] = None, + label_attribute: str = "label", + tokenize_kwargs: Optional[Dict[str, Any]] = None, + **kwargs, + ): + super().__init__(**kwargs) + self.layer = layer + self.labels = labels + self.label_field = label_attribute + self.tokenize = tokenize + if self.tokenize: + if tokenizer is None: + raise ValueError( + "tokenizer must be provided to calculate the span coverage in means of tokens" + ) + if isinstance(tokenizer, str): + tokenizer = AutoTokenizer.from_pretrained(tokenizer) + self.tokenizer = tokenizer + if tokenized_document_type is None: + raise ValueError( + "tokenized_document_type must be provided to calculate the span coverage in means of tokens" + ) + self.tokenized_document_type = resolve_type( + tokenized_document_type, expected_super_type=TokenBasedDocument + ) + self.tokenize_kwargs = tokenize_kwargs or {} + + def _collect(self, doc: Document) -> float: + docs: Union[List[Document], List[TokenBasedDocument]] + if self.tokenize: + if not isinstance(doc, TextBasedDocument): + raise ValueError( + "doc must be a TextBasedDocument to calculate the span coverage in means of tokens" + ) + docs = tokenize_document( + doc, + tokenizer=self.tokenizer, + result_document_type=self.tokenized_document_type, + **self.tokenize_kwargs, + ) + if len(docs) != 1: + raise ValueError( + "tokenization of a single document must result in a single document to calculate the " + "span coverage correctly. Please check your tokenization settings, especially that " + "no windowing is applied because of max input length restrictions." + ) + doc = docs[0] + + layer_obj = getattr(doc, self.layer) + target = layer_obj.target + covered_indices: Set[int] = set() + for span in layer_obj: + if self.labels is not None and getattr(span, self.label_field) not in self.labels: + continue + if isinstance(span, Span): + covered_indices.update(range(span.start, span.end)) + elif isinstance(span, LabeledMultiSpan): + for start, end in span.slices: + covered_indices.update(range(start, end)) + else: + raise TypeError(f"span coverage calculation is not yet supported for {type(span)}") + + return len(covered_indices) / len(target) diff --git a/tests/metrics/test_span_coverage_collector.py b/tests/metrics/test_span_coverage_collector.py new file mode 100644 index 000000000..7329a867d --- /dev/null +++ b/tests/metrics/test_span_coverage_collector.py @@ -0,0 +1,155 @@ +import dataclasses + +import pytest +from pytorch_ie import Annotation, Document +from pytorch_ie.annotations import LabeledSpan +from pytorch_ie.core import AnnotationList, annotation_field +from pytorch_ie.documents import TextBasedDocument, TokenBasedDocument + +from pie_modules.annotations import LabeledMultiSpan +from pie_modules.metrics import SpanCoverageCollector + + +@dataclasses.dataclass +class TestDocument(TextBasedDocument): + entities: AnnotationList[LabeledSpan] = annotation_field(target="text") + + +def test_span_coverage_collector(): + doc = TestDocument(text="A and O.") + doc.entities.append(LabeledSpan(start=0, end=1, label="entity")) + doc.entities.append(LabeledSpan(start=6, end=7, label="entity")) + + statistic = SpanCoverageCollector(layer="entities") + values = statistic(doc) + assert values == {"len": 1, "max": 0.25, "mean": 0.25, "min": 0.25, "std": 0.0} + + +def test_span_coverage_collector_with_multi_span(): + @dataclasses.dataclass + class TestDocument(TextBasedDocument): + entities: AnnotationList[LabeledMultiSpan] = annotation_field(target="text") + + doc = TestDocument(text="A and O.") + doc.entities.append(LabeledMultiSpan(slices=((0, 1),), label="entity")) + doc.entities.append(LabeledMultiSpan(slices=((6, 7),), label="entity")) + + statistic = SpanCoverageCollector(layer="entities") + values = statistic(doc) + assert values == { + "len": 1, + "max": 0.25, + "mean": 0.25, + "min": 0.25, + "std": 0.0, + } + + +def test_span_coverage_collector_with_labels(): + doc = TestDocument(text="A and O.") + doc.entities.append(LabeledSpan(start=0, end=1, label="entity")) + doc.entities.append(LabeledSpan(start=6, end=7, label="no_entity")) + + statistic = SpanCoverageCollector(layer="entities", labels=["entity"]) + values = statistic(doc) + assert values == {"len": 1, "max": 0.125, "mean": 0.125, "min": 0.125, "std": 0.0} + + +def test_span_coverage_collector_with_tokenize(): + doc = TestDocument(text="A and O.") + doc.entities.append(LabeledSpan(start=0, end=1, label="entity")) + doc.entities.append(LabeledSpan(start=6, end=7, label="entity")) + + @dataclasses.dataclass + class TokenizedTestDocument(TokenBasedDocument): + entities: AnnotationList[LabeledSpan] = annotation_field(target="tokens") + + statistic = SpanCoverageCollector( + layer="entities", + tokenize=True, + tokenizer="bert-base-uncased", + tokenized_document_type=TokenizedTestDocument, + ) + values = statistic(doc) + assert values == { + "len": 1, + "max": 0.3333333333333333, + "mean": 0.3333333333333333, + "min": 0.3333333333333333, + "std": 0.0, + } + + +def test_span_coverage_collector_with_tokenize_missing_tokenizer(): + with pytest.raises(ValueError) as excinfo: + SpanCoverageCollector( + layer="entities", + tokenize=True, + tokenized_document_type=TokenBasedDocument, + ) + assert ( + str(excinfo.value) + == "tokenizer must be provided to calculate the span coverage in means of tokens" + ) + + +def test_span_coverage_collector_with_tokenize_missing_tokenized_document_type(): + with pytest.raises(ValueError) as excinfo: + SpanCoverageCollector( + layer="entities", + tokenize=True, + tokenizer="bert-base-uncased", + ) + assert ( + str(excinfo.value) + == "tokenized_document_type must be provided to calculate the span coverage in means of tokens" + ) + + +def test_span_coverage_collector_with_tokenize_wrong_document_type(): + @dataclasses.dataclass + class TestDocument(Document): + data: str + entities: AnnotationList[LabeledSpan] = annotation_field(target="data") + + doc = TestDocument(data="A and O") + + @dataclasses.dataclass + class TokenizedTestDocument(TokenBasedDocument): + entities: AnnotationList[LabeledSpan] = annotation_field(target="tokens") + + statistic = SpanCoverageCollector( + layer="entities", + tokenize=True, + tokenizer="bert-base-uncased", + tokenized_document_type=TokenizedTestDocument, + ) + + with pytest.raises(ValueError) as excinfo: + statistic(doc) + assert ( + str(excinfo.value) + == "doc must be a TextBasedDocument to calculate the span coverage in means of tokens" + ) + + +def test_span_coverage_collector_with_tokenize_wrong_annotation_type(): + @dataclasses.dataclass(eq=True, frozen=True) + class UnknownSpan(Annotation): + start: int + end: int + + @dataclasses.dataclass + class TestDocument(TextBasedDocument): + labeled_multi_spans: AnnotationList[UnknownSpan] = annotation_field(target="text") + + doc = TestDocument(text="First sentence. Entity M works at N. And it founded O.") + doc.labeled_multi_spans.append(UnknownSpan(start=16, end=24)) + + statistic = SpanCoverageCollector(layer="labeled_multi_spans") + + with pytest.raises(TypeError) as excinfo: + statistic(doc) + assert ( + str(excinfo.value) == f"span coverage calculation is not yet supported for {UnknownSpan}" + )