check for empty batch via _is_empty_batch() if not prepare_does_unbatch
ArneBinder committed Oct 14, 2024
1 parent e25276b commit 92cef1b
Showing 2 changed files with 51 additions and 7 deletions.
@@ -1,5 +1,5 @@
 import logging
-from collections.abc import Collection
+from collections.abc import Collection, Sized
 from typing import Any, Callable, Dict, Generic, List, Optional, Tuple, TypeVar, Union
 
 from torch import Tensor
@@ -9,6 +9,7 @@
 logger = logging.getLogger(__name__)
 
 T = TypeVar("T")
+T2 = TypeVar("T2")
 
 
 class WrappedMetricWithPrepareFunction(WrapperMetric, Generic[T]):
@@ -41,6 +42,22 @@ def __init__(
         self.prepare_both_function = prepare_together_function
         self.prepare_does_unbatch = prepare_does_unbatch
 
+    def _is_empty_batch(self, prediction: T2, target: T2) -> bool:
+        if isinstance(prediction, Sized) and isinstance(target, Sized):
+            pred_len = len(prediction)
+            target_len = len(target)
+        else:
+            raise ValueError(
+                "Both prediction and target need to be sized when prepare_does_unbatch=False."
+            )
+        if pred_len != target_len:
+            raise ValueError(
+                f"Number of elements in prediction ({pred_len}) and target ({target_len}) do not match."
+            )
+        if pred_len == 0:
+            return True
+        return False
+
     def forward(self, prediction: T, target: T) -> Any:
         if self.prepare_function is not None:
             prediction = self.prepare_function(prediction)
@@ -65,7 +82,10 @@ def forward(self, prediction: T, target: T) -> Any:
                 results.append(current_result)
             return results
         else:
-            return self.metric(prediction, target)
+            if not self._is_empty_batch(prediction, target):
+                return self.metric(prediction, target)
+            else:
+                return None
 
     def update(self, prediction: T, target: T) -> None:
         if self.prepare_function is not None:
@@ -88,7 +108,8 @@ def update(self, prediction: T, target: T) -> None:
             for prediction_str, target_str in zip(prediction, target):
                 self.metric.update(prediction_str, target_str)
         else:
-            self.metric.update(prediction, target)
+            if not self._is_empty_batch(prediction, target):
+                self.metric.update(prediction, target)
 
     def compute(self) -> Any:
         return self.metric.compute()
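
As a reading aid, here is a minimal standalone sketch of the contract the new guard enforces. The free function is_empty_batch and its example inputs are hypothetical stand-ins for illustration; the commit implements the same three checks as the method _is_empty_batch shown above.

from collections.abc import Sized
from typing import Any


def is_empty_batch(prediction: Any, target: Any) -> bool:
    # Both inputs must support len(); the batched (non-unbatching) code path
    # cannot otherwise decide whether there is anything to score.
    if not (isinstance(prediction, Sized) and isinstance(target, Sized)):
        raise ValueError(
            "Both prediction and target need to be sized when prepare_does_unbatch=False."
        )
    # Mismatched lengths indicate an upstream bug, so fail loudly.
    if len(prediction) != len(target):
        raise ValueError(
            f"Number of elements in prediction ({len(prediction)}) "
            f"and target ({len(target)}) do not match."
        )
    # A zero-length batch is legal; the caller simply skips the metric call.
    return len(prediction) == 0


assert is_empty_batch([], []) is True  # empty batch: wrapped metric is skipped
assert is_empty_batch(["a"], ["b"]) is False  # non-empty batch: metric is called

Skipping the wrapped metric on an empty batch (forward returns None, update becomes a no-op) keeps evaluation loops alive when a batch contains nothing to score, while length mismatches still surface as hard errors.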
31 changes: 27 additions & 4 deletions tests/taskmodules/test_re_text_classification_with_indices.py
@@ -1820,12 +1820,35 @@ def test_configure_model_metric(documents, taskmodule):
     metric.update(no_targets, no_predictions)
     state = get_metric_state(metric)
 
-    # TODO: assert state
-    assert state
-    # TODO: assert metric.compute()
+    assert state == {
+        "micro/f1_without_tn/tp": [0],
+        "micro/f1_without_tn/fp": [0],
+        "micro/f1_without_tn/tn": [0],
+        "micro/f1_without_tn/fn": [0],
+        "with_tn/f1_per_label/tp": [3, 0, 0, 0],
+        "with_tn/f1_per_label/fp": [0, 0, 0, 0],
+        "with_tn/f1_per_label/tn": [0, 3, 3, 3],
+        "with_tn/f1_per_label/fn": [0, 0, 0, 0],
+        "with_tn/macro/f1/tp": [3, 0, 0, 0],
+        "with_tn/macro/f1/fp": [0, 0, 0, 0],
+        "with_tn/macro/f1/tn": [0, 3, 3, 3],
+        "with_tn/macro/f1/fn": [0, 0, 0, 0],
+        "with_tn/micro/f1/tp": [3],
+        "with_tn/micro/f1/fp": [0],
+        "with_tn/micro/f1/tn": [9],
+        "with_tn/micro/f1/fn": [0],
+    }
     torch.testing.assert_close(
         metric.compute(),
-        {},
+        {
+            "micro/f1_without_tn": tensor(0.0),
+            "no_relation/f1": tensor(1.0),
+            "org:founded_by/f1": tensor(0.0),
+            "per:employee_of/f1": tensor(0.0),
+            "per:founder/f1": tensor(0.0),
+            "macro/f1": tensor(1.0),
+            "micro/f1": tensor(1.0),
+        },
     )
 
     # ensure that the metric can be pickled
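
The concrete state above replaces the earlier TODO placeholders: with the guard in place, an update whose prepared inputs come out empty is no longer forwarded to the wrapped metric at all. A small sketch of that effect, using the hypothetical is_empty_batch from above and a made-up CountingMetric (neither is part of the project's API):

class CountingMetric:
    # Hypothetical torchmetrics-style stand-in that only counts update calls.
    def __init__(self) -> None:
        self.num_updates = 0

    def update(self, prediction, target) -> None:
        self.num_updates += 1


metric = CountingMetric()
prediction: list = []
target: list = []

# The committed update() logic, reduced to its guarded else-branch:
if not is_empty_batch(prediction, target):
    metric.update(prediction, target)

assert metric.num_updates == 0  # the empty batch never reached the metric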
