diff --git a/src/pie_datasets/statistics/__init__.py b/src/pie_datasets/statistics/__init__.py new file mode 100644 index 00000000..27081d5b --- /dev/null +++ b/src/pie_datasets/statistics/__init__.py @@ -0,0 +1 @@ +from .span_length_collector import SpanLengthCollector diff --git a/src/pie_datasets/statistics.py b/src/pie_datasets/statistics/span_length_collector.py similarity index 53% rename from src/pie_datasets/statistics.py rename to src/pie_datasets/statistics/span_length_collector.py index 0a1ae0e4..9c553564 100644 --- a/src/pie_datasets/statistics.py +++ b/src/pie_datasets/statistics/span_length_collector.py @@ -13,70 +13,6 @@ logger = logging.getLogger(__name__) -class TokenCountCollector(DocumentStatistic): - """Collects the token count of a field when tokenizing its content with a Huggingface - tokenizer. - - The content of the field should be a string. - """ - - def __init__( - self, - tokenizer: Union[str, PreTrainedTokenizer], - text_field: str = "text", - tokenizer_kwargs: Optional[Dict[str, Any]] = None, - document_type: Optional[Type[Document]] = None, - **kwargs, - ): - if document_type is None and text_field == "text": - document_type = TextBasedDocument - super().__init__(document_type=document_type, **kwargs) - self.tokenizer = ( - AutoTokenizer.from_pretrained(tokenizer) if isinstance(tokenizer, str) else tokenizer - ) - self.tokenizer_kwargs = tokenizer_kwargs or {} - self.text_field = text_field - - def _collect(self, doc: Document) -> int: - text = getattr(doc, self.text_field) - encodings = self.tokenizer(text, **self.tokenizer_kwargs) - tokens = encodings.tokens() - return len(tokens) - - -class FieldLengthCollector(DocumentStatistic): - """Collects the length of a field, e.g. to collect the number the characters in the input text. - - The field should be a list of sized elements. - """ - - def __init__(self, field: str, **kwargs): - super().__init__(**kwargs) - self.field = field - - def _collect(self, doc: Document) -> int: - field_obj = getattr(doc, self.field) - return len(field_obj) - - -class SubFieldLengthCollector(DocumentStatistic): - """Collects the length of a subfield in a field, e.g. to collect the number of arguments of - N-ary relations.""" - - def __init__(self, field: str, subfield: str, **kwargs): - super().__init__(**kwargs) - self.field = field - self.subfield = subfield - - def _collect(self, doc: Document) -> List[int]: - field_obj = getattr(doc, self.field) - lengths = [] - for entry in field_obj: - subfield_obj = getattr(entry, self.subfield) - lengths.append(len(subfield_obj)) - return lengths - - class SpanLengthCollector(DocumentStatistic): """Collects the lengths of Span annotations. If labels are provided, the lengths collected per label. @@ -184,64 +120,3 @@ def _collect(self, doc: Document) -> Union[List[int], Dict[str, List[int]]]: values[label].append(length) return values if self.labels is not None else values["ALL"] - - -class DummyCollector(DocumentStatistic): - """A dummy collector that always returns 1, e.g. to count the number of documents. - - Can be used to count the number of documents. - """ - - DEFAULT_AGGREGATION_FUNCTIONS = ["sum"] - - def _collect(self, doc: Document) -> int: - return 1 - - -class LabelCountCollector(DocumentStatistic): - """Collects the number of field entries per label, e.g. to collect the number of entities per - type. - - The field should be a list of elements with a label attribute. - - Important: To make correct use of the result data, missing values need to be filled with 0, e.g.: - {("ORG",): [2, 3], ("LOC",): [2]} -> {("ORG",): [2, 3], ("LOC",): [2, 0]} - """ - - DEFAULT_AGGREGATION_FUNCTIONS = ["mean", "std", "min", "max", "len", "sum"] - - def __init__( - self, field: str, labels: Union[List[str], str], label_attribute: str = "label", **kwargs - ): - super().__init__(**kwargs) - self.field = field - self.label_attribute = label_attribute - if not (isinstance(labels, list) or labels == "INFERRED"): - raise ValueError("labels must be a list of strings or 'INFERRED'") - if labels == "INFERRED": - logger.warning( - f"Inferring labels with {self.__class__.__name__} from data produces wrong results " - f"for certain aggregation functions (e.g. 'mean', 'std', 'min') because zero values " - f"are not included in the calculation. We remove these aggregation functions from " - f"this collector, but be aware that the results may be wrong for your own aggregation " - f"functions that rely on zero values." - ) - self.aggregation_functions: Dict[str, Callable[[List], Any]] = { - name: func - for name, func in self.aggregation_functions.items() - if name not in ["mean", "std", "min"] - } - - self.labels = labels - - def _collect(self, doc: Document) -> Dict[str, int]: - field_obj = getattr(doc, self.field) - counts: Dict[str, int] - if self.labels == "INFERRED": - counts = defaultdict(int) - else: - counts = {label: 0 for label in self.labels} - for elem in field_obj: - label = getattr(elem, self.label_attribute) - counts[label] += 1 - return dict(counts) diff --git a/tests/unit/test_statistics.py b/tests/unit/test_statistics.py index d201dd65..1129adfb 100644 --- a/tests/unit/test_statistics.py +++ b/tests/unit/test_statistics.py @@ -6,14 +6,7 @@ from pytorch_ie.documents import TextBasedDocument, TokenBasedDocument from pie_datasets import DatasetDict -from pie_datasets.statistics import ( - DummyCollector, - FieldLengthCollector, - LabelCountCollector, - SpanLengthCollector, - SubFieldLengthCollector, - TokenCountCollector, -) +from pie_datasets.statistics import SpanLengthCollector from tests import FIXTURES_ROOT @@ -30,115 +23,6 @@ class Conll2003Document(TextBasedDocument): def test_statistics(dataset): - statistic = DummyCollector() - values = statistic(dataset) - assert values == {"train": {"sum": 3}, "test": {"sum": 3}, "validation": {"sum": 3}} - - statistic = LabelCountCollector(field="entities", labels=["LOC", "PER", "ORG", "MISC"]) - values = statistic(dataset) - assert values == { - "train": { - "LOC": { - "mean": 0.3333333333333333, - "std": 0.4714045207910317, - "min": 0, - "max": 1, - "len": 3, - "sum": 1, - }, - "PER": { - "mean": 0.3333333333333333, - "std": 0.4714045207910317, - "min": 0, - "max": 1, - "len": 3, - "sum": 1, - }, - "ORG": { - "mean": 0.3333333333333333, - "std": 0.4714045207910317, - "min": 0, - "max": 1, - "len": 3, - "sum": 1, - }, - "MISC": { - "mean": 0.6666666666666666, - "std": 0.9428090415820634, - "min": 0, - "max": 2, - "len": 3, - "sum": 2, - }, - }, - "validation": { - "LOC": { - "mean": 0.3333333333333333, - "std": 0.4714045207910317, - "min": 0, - "max": 1, - "len": 3, - "sum": 1, - }, - "PER": { - "mean": 0.3333333333333333, - "std": 0.4714045207910317, - "min": 0, - "max": 1, - "len": 3, - "sum": 1, - }, - "ORG": {"mean": 1.0, "std": 0.816496580927726, "min": 0, "max": 2, "len": 3, "sum": 3}, - "MISC": { - "mean": 0.3333333333333333, - "std": 0.4714045207910317, - "min": 0, - "max": 1, - "len": 3, - "sum": 1, - }, - }, - "test": { - "LOC": {"mean": 1.0, "std": 0.816496580927726, "min": 0, "max": 2, "len": 3, "sum": 3}, - "PER": { - "mean": 0.6666666666666666, - "std": 0.4714045207910317, - "min": 0, - "max": 1, - "len": 3, - "sum": 2, - }, - "ORG": {"mean": 0.0, "std": 0.0, "min": 0, "max": 0, "len": 3, "sum": 0}, - "MISC": {"mean": 0.0, "std": 0.0, "min": 0, "max": 0, "len": 3, "sum": 0}, - }, - } - - statistic = LabelCountCollector(field="entities", labels="INFERRED") - values = statistic(dataset) - assert values == { - "train": { - "ORG": {"max": 1, "len": 1, "sum": 1}, - "MISC": {"max": 2, "len": 1, "sum": 2}, - "PER": {"max": 1, "len": 1, "sum": 1}, - "LOC": {"max": 1, "len": 1, "sum": 1}, - }, - "validation": { - "ORG": {"max": 2, "len": 2, "sum": 3}, - "LOC": {"max": 1, "len": 1, "sum": 1}, - "MISC": {"max": 1, "len": 1, "sum": 1}, - "PER": {"max": 1, "len": 1, "sum": 1}, - }, - "test": {"LOC": {"max": 2, "len": 2, "sum": 3}, "PER": {"max": 1, "len": 2, "sum": 2}}, - } - - statistic = FieldLengthCollector(field="text") - values = statistic(dataset) - assert values == { - "test": {"max": 57, "mean": 36.0, "min": 11, "std": 18.991226044325487}, - "train": {"max": 48, "mean": 27.333333333333332, "min": 15, "std": 14.70449666674185}, - "validation": {"max": 187, "mean": 89.66666666666667, "min": 17, "std": 71.5603863103665}, - } - statistic = SpanLengthCollector(layer="entities") values = statistic(dataset) assert values == { @@ -177,29 +61,8 @@ def test_statistics(dataset): }, } - # this is not super useful, we just collect the lengths of the labels, but it is enough to test the code - statistic = SubFieldLengthCollector(field="entities", subfield="label") - values = statistic(dataset) - assert values == { - "test": {"max": 3, "mean": 3.0, "min": 3, "std": 0.0}, - "train": {"max": 4, "mean": 3.4, "min": 3, "std": 0.4898979485566356}, - "validation": {"max": 4, "mean": 3.1666666666666665, "min": 3, "std": 0.3726779962499649}, - } - def test_statistics_with_tokenize(dataset): - statistic = TokenCountCollector( - text_field="text", - tokenizer="bert-base-uncased", - tokenizer_kwargs=dict(add_special_tokens=False), - ) - values = statistic(dataset) - assert values == { - "test": {"max": 12, "mean": 9.333333333333334, "min": 4, "std": 3.7712361663282534}, - "train": {"max": 9, "mean": 5.666666666666667, "min": 2, "std": 2.8674417556808756}, - "validation": {"max": 38, "mean": 18.333333333333332, "min": 6, "std": 14.055445761538678}, - } - @dataclasses.dataclass class TokenDocumentWithLabeledEntities(TokenBasedDocument): entities: AnnotationList[LabeledSpan] = annotation_field(target="tokens")