diff --git a/tests/taskmodules/test_re_text_classification_with_indices.py b/tests/taskmodules/test_re_text_classification_with_indices.py index 435c43980..904046b43 100644 --- a/tests/taskmodules/test_re_text_classification_with_indices.py +++ b/tests/taskmodules/test_re_text_classification_with_indices.py @@ -1813,5 +1813,20 @@ def test_configure_model_metric(documents, taskmodule): }, ) + # no targets and no predictions + metric.reset() + no_targets = {"labels": torch.tensor([0, 0, 0])} + no_predictions = {"labels": torch.tensor([0, 0, 0])} + metric.update(no_targets, no_predictions) + state = get_metric_state(metric) + + # TODO: assert state + assert state + # TODO: assert metric.compute() + torch.testing.assert_close( + metric.compute(), + {}, + ) + # ensure that the metric can be pickled pickle.dumps(metric)