Skip to content

Commit

Permalink
Merge pull request #141 from ArneBinder/span_length_collector/allow_a…
Browse files Browse the repository at this point in the history
…ll_aggr_for_inferred

allow all aggregation methods for `SpanLengthCollector` when using `labels=INFERRED`
  • Loading branch information
ArneBinder authored Nov 8, 2024
2 parents 6b64aa5 + 37756fa commit fb55896
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 14 deletions.
13 changes: 0 additions & 13 deletions src/pie_modules/metrics/span_length_collector.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,19 +38,6 @@ def __init__(
self.layer = layer
if isinstance(labels, str) and labels != "INFERRED":
raise ValueError("labels must be a list of strings or 'INFERRED'")
if labels == "INFERRED":
logger.warning(
f"Inferring labels with {self.__class__.__name__} from data produces wrong results "
f"for certain aggregation functions (e.g. 'mean', 'std', 'min') because zero values "
f"are not included in the calculation. We remove these aggregation functions from "
f"this collector, but be aware that the results may be wrong for your own aggregation "
f"functions that rely on zero values."
)
self.aggregation_functions: Dict[str, Callable[[List], Any]] = {
name: func
for name, func in self.aggregation_functions.items()
if name not in ["mean", "std", "min"]
}
self.labels = labels
self.label_field = label_attribute
self.tokenize = tokenize
Expand Down
10 changes: 9 additions & 1 deletion tests/metrics/test_span_length_collector.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ def test_documents(documents):
def test_span_length_collector(documents):
statistic = SpanLengthCollector(layer="entities")
values = statistic(documents)
assert statistic._values == [[8, 1, 2, 1], [10, 4, 8]]
assert values == {
"len": 7,
"max": 10,
Expand All @@ -62,7 +63,14 @@ def test_span_length_collector(documents):

statistic = SpanLengthCollector(layer="entities", labels="INFERRED")
values = statistic(documents)
assert values == {"org": {"len": 4, "max": 8}, "per": {"len": 3, "max": 10}}
assert [dict(v) for v in statistic._values] == [
{"per": [8, 2], "org": [1, 1]},
{"per": [10], "org": [4, 8]},
]
assert values == {
"org": {"len": 4, "max": 8, "mean": 3.5, "min": 1, "std": 2.8722813232690143},
"per": {"len": 3, "max": 10, "mean": 6.666666666666667, "min": 2, "std": 3.39934634239519},
}


def test_span_length_collector_wrong_label_value():
Expand Down

0 comments on commit fb55896

Please sign in to comment.