From e25276b14cad98cb8d38a9dd39728b3967382ea5 Mon Sep 17 00:00:00 2001
From: Arne Binder
Date: Mon, 14 Oct 2024 17:25:23 +0200
Subject: [PATCH 1/2] add test for metrics without valid input

---
 .../test_re_text_classification_with_indices.py | 15 +++++++++++++++
 1 file changed, 15 insertions(+)

diff --git a/tests/taskmodules/test_re_text_classification_with_indices.py b/tests/taskmodules/test_re_text_classification_with_indices.py
index 435c43980..904046b43 100644
--- a/tests/taskmodules/test_re_text_classification_with_indices.py
+++ b/tests/taskmodules/test_re_text_classification_with_indices.py
@@ -1813,5 +1813,20 @@ def test_configure_model_metric(documents, taskmodule):
         },
     )
 
+    # no targets and no predictions
+    metric.reset()
+    no_targets = {"labels": torch.tensor([0, 0, 0])}
+    no_predictions = {"labels": torch.tensor([0, 0, 0])}
+    metric.update(no_targets, no_predictions)
+    state = get_metric_state(metric)
+
+    # TODO: assert state
+    assert state
+    # TODO: assert metric.compute()
+    torch.testing.assert_close(
+        metric.compute(),
+        {},
+    )
+
     # ensure that the metric can be pickled
     pickle.dumps(metric)

From 92cef1b54938f2bcab8c82801a92e11fe897c0d0 Mon Sep 17 00:00:00 2001
From: Arne Binder
Date: Mon, 14 Oct 2024 18:24:55 +0200
Subject: [PATCH 2/2] check for empty batch via _is_empty_batch() if not
 prepare_does_unbatch

---
 .../wrapped_metric_with_prepare_function.py   | 27 ++++++++++++++--
 ...est_re_text_classification_with_indices.py | 31 ++++++++++++++++---
 2 files changed, 51 insertions(+), 7 deletions(-)

diff --git a/src/pie_modules/taskmodules/metrics/wrapped_metric_with_prepare_function.py b/src/pie_modules/taskmodules/metrics/wrapped_metric_with_prepare_function.py
index 86acdba94..7daceb9dd 100644
--- a/src/pie_modules/taskmodules/metrics/wrapped_metric_with_prepare_function.py
+++ b/src/pie_modules/taskmodules/metrics/wrapped_metric_with_prepare_function.py
@@ -1,5 +1,5 @@
 import logging
-from collections.abc import Collection
+from collections.abc import Collection, Sized
 from typing import Any, Callable, Dict, Generic, List, Optional, Tuple, TypeVar, Union
 
 from torch import Tensor
@@ -9,6 +9,7 @@
 logger = logging.getLogger(__name__)
 
 T = TypeVar("T")
+T2 = TypeVar("T2")
 
 
 class WrappedMetricWithPrepareFunction(WrapperMetric, Generic[T]):
@@ -41,6 +42,22 @@ def __init__(
         self.prepare_both_function = prepare_together_function
         self.prepare_does_unbatch = prepare_does_unbatch
 
+    def _is_empty_batch(self, prediction: T2, target: T2) -> bool:
+        if isinstance(prediction, Sized) and isinstance(target, Sized):
+            pred_len = len(prediction)
+            target_len = len(target)
+        else:
+            raise ValueError(
+                "Both prediction and target need to be sized when prepare_does_unbatch=False."
+            )
+        if pred_len != target_len:
+            raise ValueError(
+                f"Number of elements in prediction ({pred_len}) and target ({target_len}) do not match."
+            )
+        if pred_len == 0:
+            return True
+        return False
+
     def forward(self, prediction: T, target: T) -> Any:
         if self.prepare_function is not None:
             prediction = self.prepare_function(prediction)
@@ -65,7 +82,10 @@ def forward(self, prediction: T, target: T) -> Any:
                 results.append(current_result)
             return results
         else:
-            return self.metric(prediction, target)
+            if not self._is_empty_batch(prediction, target):
+                return self.metric(prediction, target)
+            else:
+                return None
 
     def update(self, prediction: T, target: T) -> None:
         if self.prepare_function is not None:
@@ -88,7 +108,8 @@ def update(self, prediction: T, target: T) -> None:
             for prediction_str, target_str in zip(prediction, target):
                 self.metric.update(prediction_str, target_str)
         else:
-            self.metric.update(prediction, target)
+            if not self._is_empty_batch(prediction, target):
+                self.metric.update(prediction, target)
 
     def compute(self) -> Any:
         return self.metric.compute()
diff --git a/tests/taskmodules/test_re_text_classification_with_indices.py b/tests/taskmodules/test_re_text_classification_with_indices.py
index 904046b43..e5d2a03ee 100644
--- a/tests/taskmodules/test_re_text_classification_with_indices.py
+++ b/tests/taskmodules/test_re_text_classification_with_indices.py
@@ -1820,12 +1820,35 @@ def test_configure_model_metric(documents, taskmodule):
     metric.update(no_targets, no_predictions)
     state = get_metric_state(metric)
 
-    # TODO: assert state
-    assert state
-    # TODO: assert metric.compute()
+    assert state == {
+        "micro/f1_without_tn/tp": [0],
+        "micro/f1_without_tn/fp": [0],
+        "micro/f1_without_tn/tn": [0],
+        "micro/f1_without_tn/fn": [0],
+        "with_tn/f1_per_label/tp": [3, 0, 0, 0],
+        "with_tn/f1_per_label/fp": [0, 0, 0, 0],
+        "with_tn/f1_per_label/tn": [0, 3, 3, 3],
+        "with_tn/f1_per_label/fn": [0, 0, 0, 0],
+        "with_tn/macro/f1/tp": [3, 0, 0, 0],
+        "with_tn/macro/f1/fp": [0, 0, 0, 0],
+        "with_tn/macro/f1/tn": [0, 3, 3, 3],
+        "with_tn/macro/f1/fn": [0, 0, 0, 0],
+        "with_tn/micro/f1/tp": [3],
+        "with_tn/micro/f1/fp": [0],
+        "with_tn/micro/f1/tn": [9],
+        "with_tn/micro/f1/fn": [0],
+    }
     torch.testing.assert_close(
         metric.compute(),
-        {},
+        {
+            "micro/f1_without_tn": tensor(0.0),
+            "no_relation/f1": tensor(1.0),
+            "org:founded_by/f1": tensor(0.0),
+            "per:employee_of/f1": tensor(0.0),
+            "per:founder/f1": tensor(0.0),
+            "macro/f1": tensor(1.0),
+            "micro/f1": tensor(1.0),
+        },
     )
 
     # ensure that the metric can be pickled
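
Note on PATCH 2/2: the new _is_empty_batch() guard is small enough to exercise
on its own. Below is a minimal standalone sketch of its behavior (an
illustration only, not part of the patch; the free function is_empty_batch is
a stand-in for the private method):

from collections.abc import Sized

import torch


def is_empty_batch(prediction, target) -> bool:
    # Mirrors the guard from the patch: both inputs must be Sized, their
    # lengths must match, and a zero-length pair counts as an empty batch.
    if not (isinstance(prediction, Sized) and isinstance(target, Sized)):
        raise ValueError(
            "Both prediction and target need to be sized when prepare_does_unbatch=False."
        )
    pred_len, target_len = len(prediction), len(target)
    if pred_len != target_len:
        raise ValueError(
            f"Number of elements in prediction ({pred_len}) and target ({target_len}) do not match."
        )
    return pred_len == 0


empty = torch.tensor([], dtype=torch.long)
assert is_empty_batch(empty, empty)  # callers skip metric.update() in this case
assert not is_empty_batch(torch.tensor([1, 0]), torch.tensor([1, 1]))
try:
    is_empty_batch(torch.tensor([1]), torch.tensor([1, 0]))  # mismatched lengths
except ValueError as err:
    print(err)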
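
Note on the test fixture: the inputs {"labels": torch.tensor([0, 0, 0])} in
PATCH 1/2 are not literally empty; they only contain the negative class. The
asserted state shows that the micro/f1_without_tn branch receives nothing to
count, which suggests its prepare step filters out negative labels before the
wrapped metric is updated. A hypothetical sketch of such a filter (the actual
prepare function lives in the taskmodule and is not shown in this diff):

import torch

NEGATIVE_IDX = 0  # assumption: label index 0 is the negative "no_relation" class


def prepare_remove_negatives(labels: torch.Tensor) -> torch.Tensor:
    # Drop all entries carrying the negative label.
    return labels[labels != NEGATIVE_IDX]


batch = torch.tensor([0, 0, 0])
filtered = prepare_remove_negatives(batch)
print(len(filtered))  # 0 -> _is_empty_batch() returns True and update() is skipped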