diff --git a/src/pie_modules/taskmodules/metrics/wrapped_metric_with_prepare_function.py b/src/pie_modules/taskmodules/metrics/wrapped_metric_with_prepare_function.py
index b30daa45f..86acdba94 100644
--- a/src/pie_modules/taskmodules/metrics/wrapped_metric_with_prepare_function.py
+++ b/src/pie_modules/taskmodules/metrics/wrapped_metric_with_prepare_function.py
@@ -1,5 +1,6 @@
 import logging
-from typing import Any, Callable, Dict, Generic, List, TypeVar, Union
+from collections.abc import Collection
+from typing import Any, Callable, Dict, Generic, List, Optional, Tuple, TypeVar, Union

 from torch import Tensor
 from torchmetrics import Metric, MetricCollection
@@ -16,8 +17,11 @@ class WrappedMetricWithPrepareFunction(WrapperMetric, Generic[T]):

     Args:
         metric: The metric to wrap. It should be a subclass of torchmetrics.Metric.
-        prepare_function: A function that prepares the input for the metric. It is called with
-            the predictions as well as the targets.
+        prepare_function: A function that prepares the input for the metric. If provided, it is
+            called with the predictions as well as the targets (separately).
+        prepare_together_function: A function that prepares the predictions and the targets together
+            and should return them as a tuple. If provided, it is called with both the predictions
+            and the targets as arguments.
         prepare_does_unbatch: If True, the prepare_function is expected to return an iterable of
             individual inputs. This can be used to un-batch the input before passing it to the
             wrapped metric.
@@ -26,49 +30,65 @@ class WrappedMetricWithPrepareFunction(WrapperMetric, Generic[T]):
     def __init__(
         self,
         metric: Union[Metric, MetricCollection],
-        prepare_function: Callable[[T], Any],
+        prepare_function: Optional[Callable[[T], Any]] = None,
+        prepare_together_function: Optional[Callable[[T, T], Tuple[Any, Any]]] = None,
         prepare_does_unbatch: bool = False,
         **kwargs,
     ) -> None:
         super().__init__(**kwargs)
         self.metric = metric
         self.prepare_function = prepare_function
+        self.prepare_both_function = prepare_together_function
         self.prepare_does_unbatch = prepare_does_unbatch

     def forward(self, prediction: T, target: T) -> Any:
-        prediction_prepared = self.prepare_function(prediction)
-        target_prepared = self.prepare_function(target)
+        if self.prepare_function is not None:
+            prediction = self.prepare_function(prediction)
+            target = self.prepare_function(target)
+        if self.prepare_both_function is not None:
+            prediction, target = self.prepare_both_function(prediction, target)
         if self.prepare_does_unbatch:
-            if len(prediction_prepared) != len(target_prepared):
+            if not isinstance(prediction, Collection) or not isinstance(target, Collection):
                 raise ValueError(
-                    f"Number of prepared predictions ({len(prediction_prepared)}) and targets "
-                    f"({len(target_prepared)}) do not match."
+                    "Both prediction and target need to be iterable and sized when prepare_does_unbatch=True."
                 )
-            if len(prediction_prepared) == 0:
+            if len(prediction) != len(target):
+                raise ValueError(
+                    f"Number of prepared predictions ({len(prediction)}) and targets "
+                    f"({len(target)}) do not match."
+                )
+            if len(prediction) == 0:
                 raise ValueError("Empty batch.")
             results = []
-            for prediction_str, target_str in zip(prediction_prepared, target_prepared):
+            for prediction_str, target_str in zip(prediction, target):
                 current_result = self.metric(prediction_str, target_str)
                 results.append(current_result)
             return results
         else:
-            return self.metric(prediction_prepared, target_prepared)
+            return self.metric(prediction, target)

     def update(self, prediction: T, target: T) -> None:
-        prediction_prepared = self.prepare_function(prediction)
-        target_prepared = self.prepare_function(target)
+        if self.prepare_function is not None:
+            prediction = self.prepare_function(prediction)
+            target = self.prepare_function(target)
+        if self.prepare_both_function is not None:
+            prediction, target = self.prepare_both_function(prediction, target)
         if self.prepare_does_unbatch:
-            if len(prediction_prepared) != len(target_prepared):
+            if not isinstance(prediction, Collection) or not isinstance(target, Collection):
+                raise ValueError(
+                    "Both prediction and target need to be iterable and sized when prepare_does_unbatch=True."
+                )
+            if len(prediction) != len(target):
                 raise ValueError(
-                    f"Number of prepared predictions ({len(prediction_prepared)}) and targets "
-                    f"({len(target_prepared)}) do not match."
+                    f"Number of prepared predictions ({len(prediction)}) and targets "
+                    f"({len(target)}) do not match."
                 )
-            if len(prediction_prepared) == 0:
+            if len(prediction) == 0:
                 raise ValueError("Empty batch.")
-            for prediction_str, target_str in zip(prediction_prepared, target_prepared):
+            for prediction_str, target_str in zip(prediction, target):
                 self.metric.update(prediction_str, target_str)
         else:
-            self.metric.update(prediction_prepared, target_prepared)
+            self.metric.update(prediction, target)

     def compute(self) -> Any:
         return self.metric.compute()
diff --git a/tests/taskmodules/metrics/test_wrapped_metric_with_prepare_function.py b/tests/taskmodules/metrics/test_wrapped_metric_with_prepare_function.py
index 6caa5da0c..e06130b09 100644
--- a/tests/taskmodules/metrics/test_wrapped_metric_with_prepare_function.py
+++ b/tests/taskmodules/metrics/test_wrapped_metric_with_prepare_function.py
@@ -1,3 +1,5 @@
+from functools import partial
+
 import pytest
 from torchmetrics import Metric

@@ -48,6 +50,52 @@ def test_metric():
     assert metric.compute() == 0.0


+def split_both_and_remove_where_both_match(
+    preds: str, targets: str, match: str
+) -> tuple[list[str], list[str]]:
+    preds = preds.split()
+    targets = targets.split()
+    not_both_none_indices = [
+        i for i, (p, t) in enumerate(zip(preds, targets)) if p != match or t != match
+    ]
+    preds = [preds[i] for i in not_both_none_indices]
+    targets = [targets[i] for i in not_both_none_indices]
+    return preds, targets
+
+
+def test_wrapped_metric_with_prepare_both_function():
+    metric = WrappedMetricWithPrepareFunction(
+        metric=TestMetric(),
+        prepare_together_function=partial(split_both_and_remove_where_both_match, match="none"),
+        prepare_does_unbatch=True,
+    )
+
+    assert metric is not None
+    assert metric.prepare_both_function is not None
+
+    assert metric.compute() == 0.0
+
+    # none is removed from both, remaining is the same
+    metric.reset()
+    metric(prediction="abc none", target="abc none")
+    assert metric.compute() == 1.0
+
+    # none is removed from both, remaining is different
+    metric.reset()
+    metric(prediction="abc none", target="def none")
+    assert metric.compute() == 0.0
+
+    # none is not removed from both, remaining is partially the same
+    metric.reset()
+    metric(prediction="abc def", target="abc none")
+    assert metric.compute() == 0.5
+
+    # none is not removed from both, remaining is different
+    metric.reset()
+    metric(prediction="abc def", target="def none")
+    assert metric.compute() == 0.0
+
+
 @pytest.fixture(scope="module")
 def wrapped_metric_with_unbatch_function():
     # just split the strings to unbatch the inputs
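
For reference, a minimal usage sketch (not part of the patch) of the new `prepare_together_function` argument. `ExactMatch`, `split_and_drop_shared_match`, and the module import path are illustrative assumptions for this sketch; the actual tests rely on the `TestMetric` helper defined in the test module:

```python
from functools import partial

import torch
from torchmetrics import Metric

# assumed import path; the class is defined in this module of the patch
from pie_modules.taskmodules.metrics.wrapped_metric_with_prepare_function import (
    WrappedMetricWithPrepareFunction,
)


class ExactMatch(Metric):
    """Toy metric: fraction of prepared prediction/target pairs that are equal."""

    def __init__(self):
        super().__init__()
        self.add_state("correct", default=torch.tensor(0.0), dist_reduce_fx="sum")
        self.add_state("total", default=torch.tensor(0.0), dist_reduce_fx="sum")

    def update(self, prediction: str, target: str) -> None:
        self.correct += float(prediction == target)
        self.total += 1.0

    def compute(self) -> torch.Tensor:
        return self.correct / self.total


def split_and_drop_shared_match(preds: str, targets: str, match: str):
    # Same idea as split_both_and_remove_where_both_match above: tokenize both
    # strings and drop positions where prediction AND target equal `match`.
    preds, targets = preds.split(), targets.split()
    keep = [i for i, (p, t) in enumerate(zip(preds, targets)) if p != match or t != match]
    return [preds[i] for i in keep], [targets[i] for i in keep]


metric = WrappedMetricWithPrepareFunction(
    metric=ExactMatch(),
    prepare_together_function=partial(split_and_drop_shared_match, match="none"),
    prepare_does_unbatch=True,
)
metric(prediction="abc none", target="abc none")
print(metric.compute())  # tensor(1.) -- the shared "none" token is dropped before scoring
```

Because `prepare_does_unbatch=True`, the wrapped metric receives the prepared token pairs one by one, so `ExactMatch` scores a single pair ("abc", "abc") here and reports 1.0.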