Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

remove some statistics #35

Merged
merged 1 commit into from
Nov 8, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/pie_datasets/statistics/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .span_length_collector import SpanLengthCollector
Original file line number Diff line number Diff line change
Expand Up @@ -13,70 +13,6 @@
logger = logging.getLogger(__name__)


class TokenCountCollector(DocumentStatistic):
    """Counts the tokens produced when a document field is run through a Huggingface
    tokenizer.

    The configured field must hold a string.
    """

    def __init__(
        self,
        tokenizer: Union[str, PreTrainedTokenizer],
        text_field: str = "text",
        tokenizer_kwargs: Optional[Dict[str, Any]] = None,
        document_type: Optional[Type[Document]] = None,
        **kwargs,
    ):
        # When the default "text" field is used, restrict to text-based documents.
        if document_type is None and text_field == "text":
            document_type = TextBasedDocument
        super().__init__(document_type=document_type, **kwargs)
        # Accept either a model identifier or an already-constructed tokenizer.
        if isinstance(tokenizer, str):
            self.tokenizer = AutoTokenizer.from_pretrained(tokenizer)
        else:
            self.tokenizer = tokenizer
        self.tokenizer_kwargs = {} if tokenizer_kwargs is None else tokenizer_kwargs
        self.text_field = text_field

    def _collect(self, doc: Document) -> int:
        """Return the number of tokens in the document's configured text field."""
        content = getattr(doc, self.text_field)
        encoding = self.tokenizer(content, **self.tokenizer_kwargs)
        return len(encoding.tokens())


class FieldLengthCollector(DocumentStatistic):
    """Collects the length of a single document field, e.g. the number of characters
    in the input text.

    The field value must support ``len()``.
    """

    def __init__(self, field: str, **kwargs):
        super().__init__(**kwargs)
        self.field = field

    def _collect(self, doc: Document) -> int:
        """Return ``len()`` of the configured field on the given document."""
        return len(getattr(doc, self.field))


class SubFieldLengthCollector(DocumentStatistic):
    """Collects, for every entry in a field, the length of one of its subfields — e.g.
    the number of arguments of n-ary relations."""

    def __init__(self, field: str, subfield: str, **kwargs):
        super().__init__(**kwargs)
        self.field = field
        self.subfield = subfield

    def _collect(self, doc: Document) -> List[int]:
        """Return one length per entry of the configured field."""
        entries = getattr(doc, self.field)
        return [len(getattr(entry, self.subfield)) for entry in entries]


class SpanLengthCollector(DocumentStatistic):
"""Collects the lengths of Span annotations. If labels are provided, the lengths collected per
label.
Expand Down Expand Up @@ -184,64 +120,3 @@ def _collect(self, doc: Document) -> Union[List[int], Dict[str, List[int]]]:
values[label].append(length)

return values if self.labels is not None else values["ALL"]


class DummyCollector(DocumentStatistic):
    """A trivial collector that yields 1 for every document.

    Combined with the "sum" aggregation this counts the documents in a dataset.
    """

    # Only "sum" is meaningful for a constant per-document value.
    DEFAULT_AGGREGATION_FUNCTIONS = ["sum"]

    def _collect(self, doc: Document) -> int:
        """Always return 1, regardless of the document's content."""
        return 1


class LabelCountCollector(DocumentStatistic):
    """Counts the entries of a field per label, e.g. the number of entities per type.

    The field must be a list of elements that carry a label attribute.

    Important: To make correct use of the result data, missing values need to be filled with 0, e.g.:
    {("ORG",): [2, 3], ("LOC",): [2]} -> {("ORG",): [2, 3], ("LOC",): [2, 0]}
    """

    DEFAULT_AGGREGATION_FUNCTIONS = ["mean", "std", "min", "max", "len", "sum"]

    def __init__(
        self, field: str, labels: Union[List[str], str], label_attribute: str = "label", **kwargs
    ):
        super().__init__(**kwargs)
        self.field = field
        self.label_attribute = label_attribute
        if not (isinstance(labels, list) or labels == "INFERRED"):
            raise ValueError("labels must be a list of strings or 'INFERRED'")
        if labels == "INFERRED":
            logger.warning(
                f"Inferring labels with {self.__class__.__name__} from data produces wrong results "
                f"for certain aggregation functions (e.g. 'mean', 'std', 'min') because zero values "
                f"are not included in the calculation. We remove these aggregation functions from "
                f"this collector, but be aware that the results may be wrong for your own aggregation "
                f"functions that rely on zero values."
            )
            # Drop the aggregations that would be skewed by the missing zero counts.
            skewed = {"mean", "std", "min"}
            self.aggregation_functions: Dict[str, Callable[[List], Any]] = {
                name: func
                for name, func in self.aggregation_functions.items()
                if name not in skewed
            }

        self.labels = labels

    def _collect(self, doc: Document) -> Dict[str, int]:
        """Return a mapping from label to the number of field entries with that label."""
        entries = getattr(doc, self.field)
        counts: Dict[str, int]
        if self.labels == "INFERRED":
            # Labels are discovered from the data; absent labels simply never appear.
            counts = defaultdict(int)
        else:
            # Known label set: start every label at zero so absent labels are reported.
            counts = {label: 0 for label in self.labels}
        for entry in entries:
            counts[getattr(entry, self.label_attribute)] += 1
        return dict(counts)
139 changes: 1 addition & 138 deletions tests/unit/test_statistics.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,7 @@
from pytorch_ie.documents import TextBasedDocument, TokenBasedDocument

from pie_datasets import DatasetDict
from pie_datasets.statistics import (
DummyCollector,
FieldLengthCollector,
LabelCountCollector,
SpanLengthCollector,
SubFieldLengthCollector,
TokenCountCollector,
)
from pie_datasets.statistics import SpanLengthCollector
from tests import FIXTURES_ROOT


Expand All @@ -30,115 +23,6 @@ class Conll2003Document(TextBasedDocument):


def test_statistics(dataset):
statistic = DummyCollector()
values = statistic(dataset)
assert values == {"train": {"sum": 3}, "test": {"sum": 3}, "validation": {"sum": 3}}

statistic = LabelCountCollector(field="entities", labels=["LOC", "PER", "ORG", "MISC"])
values = statistic(dataset)
assert values == {
"train": {
"LOC": {
"mean": 0.3333333333333333,
"std": 0.4714045207910317,
"min": 0,
"max": 1,
"len": 3,
"sum": 1,
},
"PER": {
"mean": 0.3333333333333333,
"std": 0.4714045207910317,
"min": 0,
"max": 1,
"len": 3,
"sum": 1,
},
"ORG": {
"mean": 0.3333333333333333,
"std": 0.4714045207910317,
"min": 0,
"max": 1,
"len": 3,
"sum": 1,
},
"MISC": {
"mean": 0.6666666666666666,
"std": 0.9428090415820634,
"min": 0,
"max": 2,
"len": 3,
"sum": 2,
},
},
"validation": {
"LOC": {
"mean": 0.3333333333333333,
"std": 0.4714045207910317,
"min": 0,
"max": 1,
"len": 3,
"sum": 1,
},
"PER": {
"mean": 0.3333333333333333,
"std": 0.4714045207910317,
"min": 0,
"max": 1,
"len": 3,
"sum": 1,
},
"ORG": {"mean": 1.0, "std": 0.816496580927726, "min": 0, "max": 2, "len": 3, "sum": 3},
"MISC": {
"mean": 0.3333333333333333,
"std": 0.4714045207910317,
"min": 0,
"max": 1,
"len": 3,
"sum": 1,
},
},
"test": {
"LOC": {"mean": 1.0, "std": 0.816496580927726, "min": 0, "max": 2, "len": 3, "sum": 3},
"PER": {
"mean": 0.6666666666666666,
"std": 0.4714045207910317,
"min": 0,
"max": 1,
"len": 3,
"sum": 2,
},
"ORG": {"mean": 0.0, "std": 0.0, "min": 0, "max": 0, "len": 3, "sum": 0},
"MISC": {"mean": 0.0, "std": 0.0, "min": 0, "max": 0, "len": 3, "sum": 0},
},
}

statistic = LabelCountCollector(field="entities", labels="INFERRED")
values = statistic(dataset)
assert values == {
"train": {
"ORG": {"max": 1, "len": 1, "sum": 1},
"MISC": {"max": 2, "len": 1, "sum": 2},
"PER": {"max": 1, "len": 1, "sum": 1},
"LOC": {"max": 1, "len": 1, "sum": 1},
},
"validation": {
"ORG": {"max": 2, "len": 2, "sum": 3},
"LOC": {"max": 1, "len": 1, "sum": 1},
"MISC": {"max": 1, "len": 1, "sum": 1},
"PER": {"max": 1, "len": 1, "sum": 1},
},
"test": {"LOC": {"max": 2, "len": 2, "sum": 3}, "PER": {"max": 1, "len": 2, "sum": 2}},
}

statistic = FieldLengthCollector(field="text")
values = statistic(dataset)
assert values == {
"test": {"max": 57, "mean": 36.0, "min": 11, "std": 18.991226044325487},
"train": {"max": 48, "mean": 27.333333333333332, "min": 15, "std": 14.70449666674185},
"validation": {"max": 187, "mean": 89.66666666666667, "min": 17, "std": 71.5603863103665},
}

statistic = SpanLengthCollector(layer="entities")
values = statistic(dataset)
assert values == {
Expand Down Expand Up @@ -177,29 +61,8 @@ def test_statistics(dataset):
},
}

# this is not super useful, we just collect the lengths of the labels, but it is enough to test the code
statistic = SubFieldLengthCollector(field="entities", subfield="label")
values = statistic(dataset)
assert values == {
"test": {"max": 3, "mean": 3.0, "min": 3, "std": 0.0},
"train": {"max": 4, "mean": 3.4, "min": 3, "std": 0.4898979485566356},
"validation": {"max": 4, "mean": 3.1666666666666665, "min": 3, "std": 0.3726779962499649},
}


def test_statistics_with_tokenize(dataset):
    # Token counts with a bert tokenizer, special tokens excluded from the count.
    statistic = TokenCountCollector(
        text_field="text",
        tokenizer="bert-base-uncased",
        tokenizer_kwargs=dict(add_special_tokens=False),
    )
    values = statistic(dataset)
    expected = {
        "test": {"max": 12, "mean": 9.333333333333334, "min": 4, "std": 3.7712361663282534},
        "train": {"max": 9, "mean": 5.666666666666667, "min": 2, "std": 2.8674417556808756},
        "validation": {"max": 38, "mean": 18.333333333333332, "min": 6, "std": 14.055445761538678},
    }
    assert values == expected

@dataclasses.dataclass
class TokenDocumentWithLabeledEntities(TokenBasedDocument):
entities: AnnotationList[LabeledSpan] = annotation_field(target="tokens")
Expand Down
Loading