Merge pull request #132 from ArneBinder/RE/fix_micro_f1_without_tn
fix `micro/f1_without_tn` in RE taskmodule
ArneBinder authored Oct 14, 2024
2 parents e8c7520 + 92cef1b commit 45d8db0
Showing 2 changed files with 62 additions and 3 deletions.
@@ -1,5 +1,5 @@
 import logging
-from collections.abc import Collection
+from collections.abc import Collection, Sized
 from typing import Any, Callable, Dict, Generic, List, Optional, Tuple, TypeVar, Union

 from torch import Tensor
@@ -9,6 +9,7 @@
 logger = logging.getLogger(__name__)

 T = TypeVar("T")
+T2 = TypeVar("T2")


 class WrappedMetricWithPrepareFunction(WrapperMetric, Generic[T]):
@@ -41,6 +42,22 @@ def __init__(
         self.prepare_both_function = prepare_together_function
         self.prepare_does_unbatch = prepare_does_unbatch

+    def _is_empty_batch(self, prediction: T2, target: T2) -> bool:
+        if isinstance(prediction, Sized) and isinstance(target, Sized):
+            pred_len = len(prediction)
+            target_len = len(target)
+        else:
+            raise ValueError(
+                "Both prediction and target need to be sized when prepare_does_unbatch=False."
+            )
+        if pred_len != target_len:
+            raise ValueError(
+                f"Number of elements in prediction ({pred_len}) and target ({target_len}) do not match."
+            )
+        if pred_len == 0:
+            return True
+        return False
+
     def forward(self, prediction: T, target: T) -> Any:
         if self.prepare_function is not None:
             prediction = self.prepare_function(prediction)
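The new `_is_empty_batch` helper is the core of the fix: it validates that both inputs are sized and of equal length, and flags the pair as skippable only when both are empty. A standalone sketch of the same contract (the free-standing function name is illustrative, not part of the diff):

from collections.abc import Sized
from typing import Any

def is_empty_batch(prediction: Any, target: Any) -> bool:
    # Mirror of _is_empty_batch above: both inputs must be Sized and of
    # equal length; an empty pair means "skip the wrapped metric".
    if not (isinstance(prediction, Sized) and isinstance(target, Sized)):
        raise ValueError("Both prediction and target need to be sized.")
    if len(prediction) != len(target):
        raise ValueError(
            f"Number of elements in prediction ({len(prediction)}) "
            f"and target ({len(target)}) do not match."
        )
    return len(prediction) == 0

assert is_empty_batch([], [])        # empty pair: skip the wrapped metric
assert not is_empty_batch([1], [0])  # non-empty: pass through
# is_empty_batch([1], []) raises ValueError (length mismatch)
# is_empty_batch(1, 2) raises ValueError (inputs are not Sized)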
@@ -65,7 +82,10 @@ def forward(self, prediction: T, target: T) -> Any:
                 results.append(current_result)
             return results
         else:
-            return self.metric(prediction, target)
+            if not self._is_empty_batch(prediction, target):
+                return self.metric(prediction, target)
+            else:
+                return None

     def update(self, prediction: T, target: T) -> None:
         if self.prepare_function is not None:
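With this change, `forward` returns `None` instead of invoking the wrapped metric when a batch is empty, so any caller consuming per-batch results should tolerate a `None`. A hypothetical calling pattern (not from the repo):

# `metric` stands for a configured WrappedMetricWithPrepareFunction
result = metric(prediction, target)  # dispatches to forward()
if result is not None:
    print("per-batch result:", result)
else:
    print("empty batch: wrapped metric was not called")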
@@ -88,7 +108,8 @@ def update(self, prediction: T, target: T) -> None:
             for prediction_str, target_str in zip(prediction, target):
                 self.metric.update(prediction_str, target_str)
         else:
-            self.metric.update(prediction, target)
+            if not self._is_empty_batch(prediction, target):
+                self.metric.update(prediction, target)

     def compute(self) -> Any:
         return self.metric.compute()
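`update` gets the same guard, turning empty batches into a no-op so they cannot distort the accumulated state. A toy illustration of the effect (everything here is made up for the sketch and does not use the repo's classes):

class CountingMetric:
    # stand-in for the wrapped metric: counts how many batches it saw
    def __init__(self) -> None:
        self.n_batches = 0

    def update(self, prediction, target) -> None:
        self.n_batches += 1

    def compute(self) -> int:
        return self.n_batches

inner = CountingMetric()
batches = [([1, 0], [1, 1]), ([], []), ([2], [2])]
for prediction, target in batches:
    if len(prediction) == 0 and len(target) == 0:
        continue  # the fixed update() skips empty batches like this one
    inner.update(prediction, target)

print(inner.compute())  # 2 -- the empty batch never reached the metric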
38 changes: 38 additions & 0 deletions tests/taskmodules/test_re_text_classification_with_indices.py
@@ -1813,5 +1813,43 @@ def test_configure_model_metric(documents, taskmodule):
         },
     )

+    # no targets and no predictions
+    metric.reset()
+    no_targets = {"labels": torch.tensor([0, 0, 0])}
+    no_predictions = {"labels": torch.tensor([0, 0, 0])}
+    metric.update(no_targets, no_predictions)
+    state = get_metric_state(metric)
+
+    assert state == {
+        "micro/f1_without_tn/tp": [0],
+        "micro/f1_without_tn/fp": [0],
+        "micro/f1_without_tn/tn": [0],
+        "micro/f1_without_tn/fn": [0],
+        "with_tn/f1_per_label/tp": [3, 0, 0, 0],
+        "with_tn/f1_per_label/fp": [0, 0, 0, 0],
+        "with_tn/f1_per_label/tn": [0, 3, 3, 3],
+        "with_tn/f1_per_label/fn": [0, 0, 0, 0],
+        "with_tn/macro/f1/tp": [3, 0, 0, 0],
+        "with_tn/macro/f1/fp": [0, 0, 0, 0],
+        "with_tn/macro/f1/tn": [0, 3, 3, 3],
+        "with_tn/macro/f1/fn": [0, 0, 0, 0],
+        "with_tn/micro/f1/tp": [3],
+        "with_tn/micro/f1/fp": [0],
+        "with_tn/micro/f1/tn": [9],
+        "with_tn/micro/f1/fn": [0],
+    }
+    torch.testing.assert_close(
+        metric.compute(),
+        {
+            "micro/f1_without_tn": tensor(0.0),
+            "no_relation/f1": tensor(1.0),
+            "org:founded_by/f1": tensor(0.0),
+            "per:employee_of/f1": tensor(0.0),
+            "per:founder/f1": tensor(0.0),
+            "macro/f1": tensor(1.0),
+            "micro/f1": tensor(1.0),
+        },
+    )
+
     # ensure that the metric can be pickled
     pickle.dumps(metric)
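The expected values follow from how F1 treats the counters. All labels are 0, which the computed result suggests is the no_relation index: the without-tn branch sees no positives at all (tp = fp = fn = 0) and its micro F1 falls back to 0.0, while the with-tn branch counts the three matching pairs as true positives for label 0 and reaches 1.0. A quick arithmetic check, assuming the usual definition of F1 with the degenerate 0/0 case mapped to 0.0 (plain Python, not the repo's implementation):

def f1(tp: int, fp: int, fn: int) -> float:
    # tn never enters F1; define the degenerate 0/0 case as 0.0
    if tp + fp + fn == 0:
        return 0.0
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    if precision + recall == 0:
        return 0.0
    return 2 * precision * recall / (precision + recall)

print(f1(tp=0, fp=0, fn=0))  # 0.0 -> micro/f1_without_tn
print(f1(tp=3, fp=0, fn=0))  # 1.0 -> with_tn micro/f1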
