diff --git a/src/pie_modules/taskmodules/metrics/wrapped_metric_with_prepare_function.py b/src/pie_modules/taskmodules/metrics/wrapped_metric_with_prepare_function.py index 86acdba94..7daceb9dd 100644 --- a/src/pie_modules/taskmodules/metrics/wrapped_metric_with_prepare_function.py +++ b/src/pie_modules/taskmodules/metrics/wrapped_metric_with_prepare_function.py @@ -1,5 +1,5 @@ import logging -from collections.abc import Collection +from collections.abc import Collection, Sized from typing import Any, Callable, Dict, Generic, List, Optional, Tuple, TypeVar, Union from torch import Tensor @@ -9,6 +9,7 @@ logger = logging.getLogger(__name__) T = TypeVar("T") +T2 = TypeVar("T2") class WrappedMetricWithPrepareFunction(WrapperMetric, Generic[T]): @@ -41,6 +42,22 @@ def __init__( self.prepare_both_function = prepare_together_function self.prepare_does_unbatch = prepare_does_unbatch + def _is_empty_batch(self, prediction: T2, target: T2) -> bool: + if isinstance(prediction, Sized) and isinstance(target, Sized): + pred_len = len(prediction) + target_len = len(target) + else: + raise ValueError( + "Both prediction and target need to be sized when prepare_does_unbatch=False." + ) + if pred_len != target_len: + raise ValueError( + f"Number of elements in prediction ({pred_len}) and target ({target_len}) do not match." + ) + if pred_len == 0: + return True + return False + def forward(self, prediction: T, target: T) -> Any: if self.prepare_function is not None: prediction = self.prepare_function(prediction) @@ -65,7 +82,10 @@ def forward(self, prediction: T, target: T) -> Any: results.append(current_result) return results else: - return self.metric(prediction, target) + if not self._is_empty_batch(prediction, target): + return self.metric(prediction, target) + else: + return None def update(self, prediction: T, target: T) -> None: if self.prepare_function is not None: @@ -88,7 +108,8 @@ def update(self, prediction: T, target: T) -> None: for prediction_str, target_str in zip(prediction, target): self.metric.update(prediction_str, target_str) else: - self.metric.update(prediction, target) + if not self._is_empty_batch(prediction, target): + self.metric.update(prediction, target) def compute(self) -> Any: return self.metric.compute() diff --git a/tests/taskmodules/test_re_text_classification_with_indices.py b/tests/taskmodules/test_re_text_classification_with_indices.py index 435c43980..e5d2a03ee 100644 --- a/tests/taskmodules/test_re_text_classification_with_indices.py +++ b/tests/taskmodules/test_re_text_classification_with_indices.py @@ -1813,5 +1813,43 @@ def test_configure_model_metric(documents, taskmodule): }, ) + # no targets and no predictions + metric.reset() + no_targets = {"labels": torch.tensor([0, 0, 0])} + no_predictions = {"labels": torch.tensor([0, 0, 0])} + metric.update(no_targets, no_predictions) + state = get_metric_state(metric) + + assert state == { + "micro/f1_without_tn/tp": [0], + "micro/f1_without_tn/fp": [0], + "micro/f1_without_tn/tn": [0], + "micro/f1_without_tn/fn": [0], + "with_tn/f1_per_label/tp": [3, 0, 0, 0], + "with_tn/f1_per_label/fp": [0, 0, 0, 0], + "with_tn/f1_per_label/tn": [0, 3, 3, 3], + "with_tn/f1_per_label/fn": [0, 0, 0, 0], + "with_tn/macro/f1/tp": [3, 0, 0, 0], + "with_tn/macro/f1/fp": [0, 0, 0, 0], + "with_tn/macro/f1/tn": [0, 3, 3, 3], + "with_tn/macro/f1/fn": [0, 0, 0, 0], + "with_tn/micro/f1/tp": [3], + "with_tn/micro/f1/fp": [0], + "with_tn/micro/f1/tn": [9], + "with_tn/micro/f1/fn": [0], + } + torch.testing.assert_close( + metric.compute(), + { + "micro/f1_without_tn": tensor(0.0), + "no_relation/f1": tensor(1.0), + "org:founded_by/f1": tensor(0.0), + "per:employee_of/f1": tensor(0.0), + "per:founder/f1": tensor(0.0), + "macro/f1": tensor(1.0), + "micro/f1": tensor(1.0), + }, + ) + # ensure that the metric can be pickled pickle.dumps(metric)