From 09e8cacc649a85b1555d0719b7d9e4dee02f5fbe Mon Sep 17 00:00:00 2001 From: Arne Binder Date: Fri, 8 Nov 2024 15:30:54 +0100 Subject: [PATCH 1/3] allow all aggregation methods for SpanLengthCollector when using labels=INFERRED --- src/pie_modules/metrics/span_length_collector.py | 13 ------------- 1 file changed, 13 deletions(-) diff --git a/src/pie_modules/metrics/span_length_collector.py b/src/pie_modules/metrics/span_length_collector.py index 729c68a59..355902037 100644 --- a/src/pie_modules/metrics/span_length_collector.py +++ b/src/pie_modules/metrics/span_length_collector.py @@ -38,19 +38,6 @@ def __init__( self.layer = layer if isinstance(labels, str) and labels != "INFERRED": raise ValueError("labels must be a list of strings or 'INFERRED'") - if labels == "INFERRED": - logger.warning( - f"Inferring labels with {self.__class__.__name__} from data produces wrong results " - f"for certain aggregation functions (e.g. 'mean', 'std', 'min') because zero values " - f"are not included in the calculation. We remove these aggregation functions from " - f"this collector, but be aware that the results may be wrong for your own aggregation " - f"functions that rely on zero values." - ) - self.aggregation_functions: Dict[str, Callable[[List], Any]] = { - name: func - for name, func in self.aggregation_functions.items() - if name not in ["mean", "std", "min"] - } self.labels = labels self.label_field = label_attribute self.tokenize = tokenize From 6a05f28ec4eaa1189d40c30b551cb3209231b338 Mon Sep 17 00:00:00 2001 From: Arne Binder Date: Fri, 8 Nov 2024 15:43:03 +0100 Subject: [PATCH 2/3] fix test --- tests/metrics/test_span_length_collector.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/tests/metrics/test_span_length_collector.py b/tests/metrics/test_span_length_collector.py index c55c864b2..615737b6d 100644 --- a/tests/metrics/test_span_length_collector.py +++ b/tests/metrics/test_span_length_collector.py @@ -62,7 +62,10 @@ def test_span_length_collector(documents): statistic = SpanLengthCollector(layer="entities", labels="INFERRED") values = statistic(documents) - assert values == {"org": {"len": 4, "max": 8}, "per": {"len": 3, "max": 10}} + assert values == { + "org": {"len": 4, "max": 8, "mean": 3.5, "min": 1, "std": 2.8722813232690143}, + "per": {"len": 3, "max": 10, "mean": 6.666666666666667, "min": 2, "std": 3.39934634239519}, + } def test_span_length_collector_wrong_label_value(): From 37756faeafbd71c45e983bcfc12c309bff5097d7 Mon Sep 17 00:00:00 2001 From: Arne Binder Date: Fri, 8 Nov 2024 15:49:16 +0100 Subject: [PATCH 3/3] check collected values --- tests/metrics/test_span_length_collector.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/tests/metrics/test_span_length_collector.py b/tests/metrics/test_span_length_collector.py index 615737b6d..b52da21c0 100644 --- a/tests/metrics/test_span_length_collector.py +++ b/tests/metrics/test_span_length_collector.py @@ -52,6 +52,7 @@ def test_documents(documents): def test_span_length_collector(documents): statistic = SpanLengthCollector(layer="entities") values = statistic(documents) + assert statistic._values == [[8, 1, 2, 1], [10, 4, 8]] assert values == { "len": 7, "max": 10, @@ -62,6 +63,10 @@ def test_span_length_collector(documents): statistic = SpanLengthCollector(layer="entities", labels="INFERRED") values = statistic(documents) + assert [dict(v) for v in statistic._values] == [ + {"per": [8, 2], "org": [1, 1]}, + {"per": [10], "org": [4, 8]}, + ] assert values == { "org": {"len": 4, "max": 8, "mean": 3.5, "min": 1, "std": 2.8722813232690143}, "per": {"len": 3, "max": 10, "mean": 6.666666666666667, "min": 2, "std": 3.39934634239519},