From a88f44fb05f73e76e4b58b029662053444519d1a Mon Sep 17 00:00:00 2001
From: Arne Binder
Date: Tue, 8 Oct 2024 19:27:16 +0200
Subject: [PATCH] add metric "micro/f1_without_tn" to RETextClassificationWithIndicesTaskModule

---
 .../re_text_classification_with_indices.py    |  49 ++++++--
 ...est_re_text_classification_with_indices.py | 116 ++++++++++--------
 2 files changed, 103 insertions(+), 62 deletions(-)

diff --git a/src/pie_modules/taskmodules/re_text_classification_with_indices.py b/src/pie_modules/taskmodules/re_text_classification_with_indices.py
index 7483c1ca9..838100e5f 100644
--- a/src/pie_modules/taskmodules/re_text_classification_with_indices.py
+++ b/src/pie_modules/taskmodules/re_text_classification_with_indices.py
@@ -9,6 +9,7 @@
 import logging
 from collections import defaultdict
+from functools import partial
 from typing import (
     Any,
     Dict,
@@ -109,6 +110,15 @@ def _get_labels(model_output: ModelTargetType) -> LongTensor:
     return model_output["labels"]
 
 
+def _get_labels_together_remove_none_label(
+    predictions: ModelTargetType, targets: ModelTargetType, none_idx: int
+) -> Tuple[LongTensor, LongTensor]:
+    mask_not_both_none = (predictions["labels"] != none_idx) | (targets["labels"] != none_idx)
+    predictions_not_none = predictions["labels"][mask_not_both_none]
+    targets_not_none = targets["labels"][mask_not_both_none]
+    return predictions_not_none, targets_not_none
+
+
 def inner_span_distance(start_end: Tuple[int, int], other_start_end: Tuple[int, int]) -> int:
     dist_start_other_end = abs(start_end[0] - other_start_end[1])
     dist_end_other_start = abs(start_end[1] - other_start_end[0])
@@ -1066,7 +1076,7 @@ def collate(
 
         return inputs, {"labels": targets}
 
-    def configure_model_metric(self, stage: str) -> Metric:
+    def configure_model_metric(self, stage: str) -> MetricCollection:
         if self.label_to_id is None:
             raise ValueError(
                 "The taskmodule has not been prepared yet, so label_to_id is not known. "
@@ -1079,17 +1089,32 @@
             "num_classes": len(labels),
             "task": "multilabel" if self.multi_label else "multiclass",
         }
-        return WrappedMetricWithPrepareFunction(
-            metric=MetricCollection(
-                {
-                    "micro/f1": F1Score(average="micro", **common_metric_kwargs),
-                    "macro/f1": F1Score(average="macro", **common_metric_kwargs),
-                    "f1_per_label": ClasswiseWrapper(
-                        F1Score(average=None, **common_metric_kwargs),
-                        labels=labels,
-                        postfix="/f1",
-                    ),
-                }
-            ),
-            prepare_function=_get_labels,
+        return MetricCollection(
+            {
+                "with_tn": WrappedMetricWithPrepareFunction(
+                    metric=MetricCollection(
+                        {
+                            "micro/f1": F1Score(average="micro", **common_metric_kwargs),
+                            "macro/f1": F1Score(average="macro", **common_metric_kwargs),
+                            "f1_per_label": ClasswiseWrapper(
+                                F1Score(average=None, **common_metric_kwargs),
+                                labels=labels,
+                                postfix="/f1",
+                            ),
+                        }
+                    ),
+                    prepare_function=_get_labels,
+                ),
+                # We cannot easily calculate the macro F1 here, because
+                # F1Score with average="macro" would still include the none_label.
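+                # Instead, the prepare function below first drops every pair in
+                # which both the prediction and the target are the none_label.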
+ "micro/f1_without_tn": WrappedMetricWithPrepareFunction( + metric=F1Score(average="micro", **common_metric_kwargs), + prepare_together_function=partial( + _get_labels_together_remove_none_label, + none_idx=self.label_to_id[self.none_label], + ), + ), + } ) diff --git a/tests/taskmodules/test_re_text_classification_with_indices.py b/tests/taskmodules/test_re_text_classification_with_indices.py index 38086092f..435c43980 100644 --- a/tests/taskmodules/test_re_text_classification_with_indices.py +++ b/tests/taskmodules/test_re_text_classification_with_indices.py @@ -1717,18 +1717,22 @@ def test_configure_model_metric(documents, taskmodule): assert isinstance(metric, (Metric, MetricCollection)) state = get_metric_state(metric) assert state == { - "f1_per_label/tp": [0, 0, 0, 0], - "f1_per_label/fp": [0, 0, 0, 0], - "f1_per_label/tn": [0, 0, 0, 0], - "f1_per_label/fn": [0, 0, 0, 0], - "macro/f1/fn": [0, 0, 0, 0], - "macro/f1/fp": [0, 0, 0, 0], - "macro/f1/tn": [0, 0, 0, 0], - "macro/f1/tp": [0, 0, 0, 0], - "micro/f1/fn": [0], - "micro/f1/fp": [0], - "micro/f1/tn": [0], - "micro/f1/tp": [0], + "micro/f1_without_tn/tp": [0], + "micro/f1_without_tn/fp": [0], + "micro/f1_without_tn/tn": [0], + "micro/f1_without_tn/fn": [0], + "with_tn/f1_per_label/tp": [0, 0, 0, 0], + "with_tn/f1_per_label/fp": [0, 0, 0, 0], + "with_tn/f1_per_label/tn": [0, 0, 0, 0], + "with_tn/f1_per_label/fn": [0, 0, 0, 0], + "with_tn/macro/f1/tp": [0, 0, 0, 0], + "with_tn/macro/f1/fp": [0, 0, 0, 0], + "with_tn/macro/f1/tn": [0, 0, 0, 0], + "with_tn/macro/f1/fn": [0, 0, 0, 0], + "with_tn/micro/f1/tp": [0], + "with_tn/micro/f1/fp": [0], + "with_tn/micro/f1/tn": [0], + "with_tn/micro/f1/fn": [0], } assert metric.compute() == { "no_relation/f1": tensor(0.0), @@ -1737,24 +1741,29 @@ def test_configure_model_metric(documents, taskmodule): "per:founder/f1": tensor(0.0), "macro/f1": tensor(0.0), "micro/f1": tensor(0.0), + "micro/f1_without_tn": tensor(0.0), } targets = batch[1] metric.update(targets, targets) state = get_metric_state(metric) assert state == { - "f1_per_label/fn": [0, 0, 0, 0], - "f1_per_label/fp": [0, 0, 0, 0], - "f1_per_label/tn": [7, 5, 4, 5], - "f1_per_label/tp": [0, 2, 3, 2], - "macro/f1/fn": [0, 0, 0, 0], - "macro/f1/fp": [0, 0, 0, 0], - "macro/f1/tn": [7, 5, 4, 5], - "macro/f1/tp": [0, 2, 3, 2], - "micro/f1/fn": [0], - "micro/f1/fp": [0], - "micro/f1/tn": [21], - "micro/f1/tp": [7], + "micro/f1_without_tn/tp": [7], + "micro/f1_without_tn/fp": [0], + "micro/f1_without_tn/tn": [21], + "micro/f1_without_tn/fn": [0], + "with_tn/f1_per_label/tp": [0, 2, 3, 2], + "with_tn/f1_per_label/fp": [0, 0, 0, 0], + "with_tn/f1_per_label/tn": [7, 5, 4, 5], + "with_tn/f1_per_label/fn": [0, 0, 0, 0], + "with_tn/macro/f1/tp": [0, 2, 3, 2], + "with_tn/macro/f1/fp": [0, 0, 0, 0], + "with_tn/macro/f1/tn": [7, 5, 4, 5], + "with_tn/macro/f1/fn": [0, 0, 0, 0], + "with_tn/micro/f1/tp": [7], + "with_tn/micro/f1/fp": [0], + "with_tn/micro/f1/tn": [21], + "with_tn/micro/f1/fn": [0], } assert metric.compute() == { "no_relation/f1": tensor(0.0), @@ -1763,37 +1772,46 @@ def test_configure_model_metric(documents, taskmodule): "per:founder/f1": tensor(1.0), "macro/f1": tensor(1.0), "micro/f1": tensor(1.0), + "micro/f1_without_tn": tensor(1.0), } metric.reset() - torch.testing.assert_close(targets["labels"], torch.tensor([2, 2, 3, 1, 2, 3, 1])) - # three matches - random_targets = {"labels": torch.tensor([1, 1, 3, 1, 2, 0, 0])} - metric.update(random_targets, targets) + modified_targets = {"labels": torch.tensor([2, 2, 3, 1, 2, 0, 1])} + # three 
+    # three positive matches and one true negative
+    random_predictions = {"labels": torch.tensor([1, 1, 3, 1, 2, 0, 0])}
+    metric.update(random_predictions, modified_targets)
     state = get_metric_state(metric)
     assert state == {
-        "f1_per_label/fn": [0, 1, 2, 1],
-        "f1_per_label/fp": [2, 2, 0, 0],
-        "f1_per_label/tn": [5, 3, 4, 5],
-        "f1_per_label/tp": [0, 1, 1, 1],
-        "macro/f1/fn": [0, 1, 2, 1],
-        "macro/f1/fp": [2, 2, 0, 0],
-        "macro/f1/tn": [5, 3, 4, 5],
-        "macro/f1/tp": [0, 1, 1, 1],
-        "micro/f1/fn": [4],
-        "micro/f1/fp": [4],
-        "micro/f1/tn": [17],
-        "micro/f1/tp": [3],
-    }
-    values = {k: v.item() for k, v in metric.compute().items()}
-    assert values == {
-        "macro/f1": 0.3916666507720947,
-        "micro/f1": 0.4285714328289032,
-        "no_relation/f1": 0.0,
-        "org:founded_by/f1": 0.4000000059604645,
-        "per:employee_of/f1": 0.5,
-        "per:founder/f1": 0.6666666865348816,
+        "micro/f1_without_tn/tp": [3],
+        "micro/f1_without_tn/fp": [3],
+        "micro/f1_without_tn/tn": [15],
+        "micro/f1_without_tn/fn": [3],
+        "with_tn/f1_per_label/tp": [1, 1, 1, 1],
+        "with_tn/f1_per_label/fp": [1, 2, 0, 0],
+        "with_tn/f1_per_label/tn": [5, 3, 4, 6],
+        "with_tn/f1_per_label/fn": [0, 1, 2, 0],
+        "with_tn/macro/f1/tp": [1, 1, 1, 1],
+        "with_tn/macro/f1/fp": [1, 2, 0, 0],
+        "with_tn/macro/f1/tn": [5, 3, 4, 6],
+        "with_tn/macro/f1/fn": [0, 1, 2, 0],
+        "with_tn/micro/f1/tp": [4],
+        "with_tn/micro/f1/fp": [3],
+        "with_tn/micro/f1/tn": [18],
+        "with_tn/micro/f1/fn": [3],
     }
+    # created with torch.set_printoptions(precision=6)
+    torch.testing.assert_close(
+        metric.compute(),
+        {
+            "no_relation/f1": tensor(0.666667),
+            "org:founded_by/f1": tensor(0.400000),
+            "per:employee_of/f1": tensor(0.500000),
+            "per:founder/f1": tensor(1.0),
+            "macro/f1": tensor(0.641667),
+            "micro/f1": tensor(0.571429),
+            "micro/f1_without_tn": tensor(0.500000),
+        },
+    )
 
     # ensure that the metric can be pickled
     pickle.dumps(metric)
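
For illustration, a minimal standalone sketch (not part of the patch) of what
"micro/f1_without_tn" computes, re-using the prediction and target tensors from
the last update above; it assumes torchmetrics' functional multiclass F1 API:

import torch
from torchmetrics.functional.classification import multiclass_f1_score

predictions = torch.tensor([1, 1, 3, 1, 2, 0, 0])
targets = torch.tensor([2, 2, 3, 1, 2, 0, 1])
none_idx = 0  # index of the none label ("no_relation")

# drop pairs where both the prediction and the target are the none label
mask = (predictions != none_idx) | (targets != none_idx)

# micro F1 over all 7 pairs: 4/7 = 0.571429 ("with_tn/micro/f1")
print(multiclass_f1_score(predictions, targets, num_classes=4, average="micro"))
# micro F1 over the 6 remaining pairs: 3/6 = 0.5 ("micro/f1_without_tn")
print(multiclass_f1_score(predictions[mask], targets[mask], num_classes=4, average="micro"))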