Commit

Merge pull request #129 from ArneBinder/wrapped_metric/prepare_together_function

add parameter `prepare_together_function` to `WrappedMetricWithPrepareFunction`
ArneBinder authored Oct 8, 2024
2 parents 3c31d85 + 475fd0d commit 218fc87
Showing 2 changed files with 88 additions and 20 deletions.
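
In short: WrappedMetricWithPrepareFunction previously only took a prepare_function that is applied to predictions and targets separately; this commit adds prepare_together_function, which receives both at once and returns them as a (predictions, targets) tuple. A minimal usage sketch of the new parameter follows (the ExactMatch metric and the drop_padding_pairs helper are illustrative assumptions, not part of this commit, and the import of WrappedMetricWithPrepareFunction is left as a comment because the diff view does not show the changed file's path):

import torch
from torchmetrics import Metric

# from <module changed in this commit> import WrappedMetricWithPrepareFunction


class ExactMatch(Metric):
    # Hypothetical toy metric for this sketch: fraction of exactly matching pairs.
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.add_state("correct", default=torch.tensor(0.0), dist_reduce_fx="sum")
        self.add_state("total", default=torch.tensor(0.0), dist_reduce_fx="sum")

    def update(self, prediction, target):
        self.correct += float(prediction == target)
        self.total += 1.0

    def compute(self):
        return self.correct / self.total


def drop_padding_pairs(prediction: str, target: str) -> tuple[list[str], list[str]]:
    # Hypothetical joint prepare function: tokenize both strings and drop positions
    # where prediction and target are both the placeholder "none".
    preds, targets = prediction.split(), target.split()
    keep = [i for i, (p, t) in enumerate(zip(preds, targets)) if p != "none" or t != "none"]
    return [preds[i] for i in keep], [targets[i] for i in keep]


metric = WrappedMetricWithPrepareFunction(
    metric=ExactMatch(),
    prepare_together_function=drop_padding_pairs,  # new parameter added in this PR
    prepare_does_unbatch=True,  # the prepared token lists are iterated pairwise
)
metric(prediction="abc none", target="abc none")
assert metric.compute() == 1.0  # the shared "none" is dropped, the rest matches
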
@@ -1,5 +1,6 @@
 import logging
-from typing import Any, Callable, Dict, Generic, List, TypeVar, Union
+from collections.abc import Collection
+from typing import Any, Callable, Dict, Generic, List, Optional, Tuple, TypeVar, Union

 from torch import Tensor
 from torchmetrics import Metric, MetricCollection
@@ -16,8 +17,11 @@ class WrappedMetricWithPrepareFunction(WrapperMetric, Generic[T]):
     Args:
         metric: The metric to wrap. It should be a subclass of torchmetrics.Metric.
-        prepare_function: A function that prepares the input for the metric. It is called with
-            the predictions as well as the targets.
+        prepare_function: A function that prepares the input for the metric. If provided, it is called with
+            the predictions as well as the targets (separately).
+        prepare_together_function: A function that prepares both the predictions and the targets together and
+            should return them as a tuple. If provided, it is called with the predictions and the targets as
+            arguments.
         prepare_does_unbatch: If True, the prepare_function is expected to return an iterable of
             individual inputs. This can be used to un-batch the input before passing it to the
             wrapped metric.
@@ -26,49 +30,65 @@ class WrappedMetricWithPrepareFunction(WrapperMetric, Generic[T]):
     def __init__(
         self,
         metric: Union[Metric, MetricCollection],
-        prepare_function: Callable[[T], Any],
+        prepare_function: Optional[Callable[[T], Any]] = None,
+        prepare_together_function: Optional[Callable[[T, T], Tuple[Any, Any]]] = None,
         prepare_does_unbatch: bool = False,
         **kwargs,
     ) -> None:
         super().__init__(**kwargs)
         self.metric = metric
         self.prepare_function = prepare_function
+        self.prepare_both_function = prepare_together_function
         self.prepare_does_unbatch = prepare_does_unbatch

     def forward(self, prediction: T, target: T) -> Any:
-        prediction_prepared = self.prepare_function(prediction)
-        target_prepared = self.prepare_function(target)
+        if self.prepare_function is not None:
+            prediction = self.prepare_function(prediction)
+            target = self.prepare_function(target)
+        if self.prepare_both_function is not None:
+            prediction, target = self.prepare_both_function(prediction, target)
         if self.prepare_does_unbatch:
-            if len(prediction_prepared) != len(target_prepared):
+            if not isinstance(prediction, Collection) or not isinstance(target, Collection):
                 raise ValueError(
-                    f"Number of prepared predictions ({len(prediction_prepared)}) and targets "
-                    f"({len(target_prepared)}) do not match."
+                    "Both prediction and target need to be iterable and sized when prepare_does_unbatch=True."
                 )
-            if len(prediction_prepared) == 0:
+            if len(prediction) != len(target):
+                raise ValueError(
+                    f"Number of prepared predictions ({len(prediction)}) and targets "
+                    f"({len(target)}) do not match."
+                )
+            if len(prediction) == 0:
                 raise ValueError("Empty batch.")
             results = []
-            for prediction_str, target_str in zip(prediction_prepared, target_prepared):
+            for prediction_str, target_str in zip(prediction, target):
                 current_result = self.metric(prediction_str, target_str)
                 results.append(current_result)
             return results
         else:
-            return self.metric(prediction_prepared, target_prepared)
+            return self.metric(prediction, target)

     def update(self, prediction: T, target: T) -> None:
-        prediction_prepared = self.prepare_function(prediction)
-        target_prepared = self.prepare_function(target)
+        if self.prepare_function is not None:
+            prediction = self.prepare_function(prediction)
+            target = self.prepare_function(target)
+        if self.prepare_both_function is not None:
+            prediction, target = self.prepare_both_function(prediction, target)
         if self.prepare_does_unbatch:
-            if len(prediction_prepared) != len(target_prepared):
+            if not isinstance(prediction, Collection) or not isinstance(target, Collection):
+                raise ValueError(
+                    "Both prediction and target need to be iterable and sized when prepare_does_unbatch=True."
+                )
+            if len(prediction) != len(target):
                 raise ValueError(
-                    f"Number of prepared predictions ({len(prediction_prepared)}) and targets "
-                    f"({len(target_prepared)}) do not match."
+                    f"Number of prepared predictions ({len(prediction)}) and targets "
+                    f"({len(target)}) do not match."
                 )
-            if len(prediction_prepared) == 0:
+            if len(prediction) == 0:
                 raise ValueError("Empty batch.")
-            for prediction_str, target_str in zip(prediction_prepared, target_prepared):
+            for prediction_str, target_str in zip(prediction, target):
                 self.metric.update(prediction_str, target_str)
         else:
-            self.metric.update(prediction_prepared, target_prepared)
+            self.metric.update(prediction, target)

     def compute(self) -> Any:
         return self.metric.compute()
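
Note that the pre-existing behaviour is preserved: when only prepare_function is given, it is still applied to predictions and targets separately, and with prepare_does_unbatch=True the two prepared sequences are zipped so the wrapped metric is updated one pair at a time. A rough sketch of that path, reusing the illustrative ExactMatch metric from the sketch above and assuming str.split as the per-input prepare function (the module-scoped test fixture further below does the same according to its comment):

unbatching_metric = WrappedMetricWithPrepareFunction(
    metric=ExactMatch(),          # illustrative toy metric, see the sketch above
    prepare_function=str.split,   # applied to prediction and target separately
    prepare_does_unbatch=True,    # zip the two token lists, update per token pair
)
unbatching_metric(prediction="a b c", target="a x c")
# ExactMatch.update is called three times: ("a", "a"), ("b", "x"), ("c", "c"),
# so compute() returns 2/3.

The second changed file, shown next, adds the tests for the new parameter.
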
@@ -1,3 +1,5 @@
+from functools import partial
+
 import pytest
 from torchmetrics import Metric

@@ -48,6 +50,52 @@ def test_metric():
     assert metric.compute() == 0.0


+def split_both_and_remove_where_both_match(
+    preds: str, targets: str, match: str
+) -> tuple[list[str], list[str]]:
+    preds = preds.split()
+    targets = targets.split()
+    not_both_none_indices = [
+        i for i, (p, t) in enumerate(zip(preds, targets)) if p != match or t != match
+    ]
+    preds = [preds[i] for i in not_both_none_indices]
+    targets = [targets[i] for i in not_both_none_indices]
+    return preds, targets
+
+
+def test_wrapped_metric_with_prepare_both_function():
+    metric = WrappedMetricWithPrepareFunction(
+        metric=TestMetric(),
+        prepare_together_function=partial(split_both_and_remove_where_both_match, match="none"),
+        prepare_does_unbatch=True,
+    )
+
+    assert metric is not None
+    assert metric.prepare_both_function is not None
+
+    assert metric.compute() == 0.0
+
+    # none is removed from both, remaining is the same
+    metric.reset()
+    metric(prediction="abc none", target="abc none")
+    assert metric.compute() == 1.0
+
+    # none is removed from both, remaining is different
+    metric.reset()
+    metric(prediction="abc none", target="def none")
+    assert metric.compute() == 0.0
+
+    # none is not removed from both, remaining is partially the same
+    metric.reset()
+    metric(prediction="abc def", target="abc none")
+    assert metric.compute() == 0.5
+
+    # none is not removed from both, remaining is different
+    metric.reset()
+    metric(prediction="abc def", target="def none")
+    assert metric.compute() == 0.0
+
+
 @pytest.fixture(scope="module")
 def wrapped_metric_with_unbatch_function():
     # just split the strings to unbatch the inputs
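
For reference, a short trace of the prepare function that the new test binds with functools.partial; the tuples follow directly from split_both_and_remove_where_both_match as defined in the diff above, and the trailing comments restate the compute() values asserted in the test:

from functools import partial

prepare = partial(split_both_and_remove_where_both_match, match="none")

# positions where both sides are "none" are dropped before the metric sees them
assert prepare("abc none", "abc none") == (["abc"], ["abc"])                # compute() == 1.0
assert prepare("abc none", "def none") == (["abc"], ["def"])                # compute() == 0.0
# a position is kept as soon as either side differs from "none"
assert prepare("abc def", "abc none") == (["abc", "def"], ["abc", "none"])  # compute() == 0.5
assert prepare("abc def", "def none") == (["abc", "def"], ["def", "none"])  # compute() == 0.0
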