From aa384dd7ace3c125ec4a35b8bfebe3872b759c5d Mon Sep 17 00:00:00 2001
From: Arne Binder
Date: Tue, 23 Jan 2024 18:42:22 +0100
Subject: [PATCH 1/2] exclude "O" for token based metrics from
 LabeledSpanExtractionByTokenClassificationTaskModule

---
 ...d_span_extraction_by_token_classification.py | 17 +++++++----------
 .../models/test_simple_token_classification.py  |  8 ++++----
 ...assification_with_seq2seq_encoder_and_crf.py | 16 ++++++++--------
 ...d_span_extraction_by_token_classification.py |  4 ++--
 4 files changed, 21 insertions(+), 24 deletions(-)

diff --git a/src/pie_modules/taskmodules/labeled_span_extraction_by_token_classification.py b/src/pie_modules/taskmodules/labeled_span_extraction_by_token_classification.py
index 4414d3333..ecb5f1e7d 100644
--- a/src/pie_modules/taskmodules/labeled_span_extraction_by_token_classification.py
+++ b/src/pie_modules/taskmodules/labeled_span_extraction_by_token_classification.py
@@ -406,22 +406,19 @@ def create_annotations_from_output(
         yield self.span_annotation, span.copy()
 
     def configure_model_metric(self, stage: str) -> Union[Metric, MetricCollection]:
+        common_metric_kwargs = {
+            "num_classes": len(self.label_to_id),
+            "task": "multiclass",
+            "ignore_index": self.label_to_id["O"],
+        }
         token_scores = MetricCollection(
             {
                 "token/macro/f1": WrappedMetricWithPrepareFunction(
-                    metric=F1Score(
-                        num_classes=len(self.label_to_id),
-                        task="multiclass",
-                        average="macro",
-                    ),
+                    metric=F1Score(average="macro", **common_metric_kwargs),
                     prepare_function=partial(remove_label_pad_ids, label_pad_id=self.label_pad_id),
                 ),
                 "token/micro/f1": WrappedMetricWithPrepareFunction(
-                    metric=F1Score(
-                        num_classes=len(self.label_to_id),
-                        task="multiclass",
-                        average="micro",
-                    ),
+                    metric=F1Score(average="micro", **common_metric_kwargs),
                     prepare_function=partial(remove_label_pad_ids, label_pad_id=self.label_pad_id),
                 ),
             }
diff --git a/tests/models/test_simple_token_classification.py b/tests/models/test_simple_token_classification.py
index dafd4e3dd..af1a926e2 100644
--- a/tests/models/test_simple_token_classification.py
+++ b/tests/models/test_simple_token_classification.py
@@ -292,8 +292,8 @@ def test_validation_step_and_on_epoch_end(batch, model, config):
         "span/micro/f1": 0.13793103396892548,
         "span/micro/precision": 0.3333333432674408,
         "span/micro/recall": 0.08695652335882187,
-        "token/macro/f1": 0.04210526496171951,
-        "token/micro/f1": 0.06896551698446274,
+        "token/macro/f1": 0.08888889104127884,
+        "token/micro/f1": 0.13333334028720856,
     }
 
     model.on_validation_epoch_end()
@@ -319,8 +319,8 @@ def test_test_step_and_on_epoch_end(batch, model, config):
         "span/micro/f1": 0.13793103396892548,
         "span/micro/precision": 0.3333333432674408,
         "span/micro/recall": 0.08695652335882187,
-        "token/macro/f1": 0.04210526496171951,
-        "token/micro/f1": 0.06896551698446274,
+        "token/macro/f1": 0.08888889104127884,
+        "token/micro/f1": 0.13333334028720856,
     }
 
     model.on_test_epoch_end()
diff --git a/tests/models/test_token_classification_with_seq2seq_encoder_and_crf.py b/tests/models/test_token_classification_with_seq2seq_encoder_and_crf.py
index 2dff99f92..1de53727c 100644
--- a/tests/models/test_token_classification_with_seq2seq_encoder_and_crf.py
+++ b/tests/models/test_token_classification_with_seq2seq_encoder_and_crf.py
@@ -353,8 +353,8 @@ def test_validation_step_and_on_epoch_end(batch, model, config):
     if config == {}:
         torch.testing.assert_close(loss, torch.tensor(59.42658996582031))
         assert metric_values == {
-            "token/macro/f1": 0.19285714626312256,
-            "token/micro/f1": 0.27586206793785095,
+            "token/macro/f1": 0.3919413983821869,
+            "token/micro/f1": 0.5333333611488342,
             "span/PER/f1": 0.0833333358168602,
             "span/PER/recall": 0.0476190485060215,
             "span/PER/precision": 0.3333333432674408,
@@ -371,8 +371,8 @@ def test_validation_step_and_on_epoch_end(batch, model, config):
     elif config == {"use_crf": False}:
         torch.testing.assert_close(loss, torch.tensor(1.6708829402923584))
         assert metric_values == {
-            "token/macro/f1": 0.08615384995937347,
-            "token/micro/f1": 0.13793103396892548,
+            "token/macro/f1": 0.14374999701976776,
+            "token/micro/f1": 0.2666666805744171,
             "span/PER/f1": 0.0,
             "span/PER/recall": 0.0,
             "span/PER/precision": 0.0,
@@ -401,8 +401,8 @@ def test_test_step_and_on_epoch_end(batch, model, config):
     if config == {}:
         torch.testing.assert_close(loss, torch.tensor(59.42658996582031))
         assert metric_values == {
-            "token/macro/f1": 0.19285714626312256,
-            "token/micro/f1": 0.27586206793785095,
+            "token/macro/f1": 0.3919413983821869,
+            "token/micro/f1": 0.5333333611488342,
             "span/ORG/f1": 0.0,
             "span/ORG/recall": 0.0,
             "span/ORG/precision": 0.0,
@@ -419,8 +419,8 @@ def test_test_step_and_on_epoch_end(batch, model, config):
    elif config == {"use_crf": False}:
         torch.testing.assert_close(loss, torch.tensor(1.6708829402923584))
         assert metric_values == {
-            "token/macro/f1": 0.08615384995937347,
-            "token/micro/f1": 0.13793103396892548,
+            "token/macro/f1": 0.14374999701976776,
+            "token/micro/f1": 0.2666666805744171,
             "span/ORG/f1": 0.0,
             "span/ORG/recall": 0.0,
             "span/ORG/precision": 0.0,
diff --git a/tests/taskmodules/test_labeled_span_extraction_by_token_classification.py b/tests/taskmodules/test_labeled_span_extraction_by_token_classification.py
index 99b7a5760..7c8da2e23 100644
--- a/tests/taskmodules/test_labeled_span_extraction_by_token_classification.py
+++ b/tests/taskmodules/test_labeled_span_extraction_by_token_classification.py
@@ -787,8 +787,8 @@ def test_configure_model_metric(documents):
     values = metric.compute()
     values_converted = {k: v.item() for k, v in values.items()}
     assert values_converted == {
-        "token/macro/f1": 0.5434783101081848,
-        "token/micro/f1": 0.5249999761581421,
+        "token/macro/f1": 0.6349206566810608,
+        "token/micro/f1": 0.625,
         "span/LOC/recall": 0.0476190485060215,
         "span/LOC/precision": 0.5,
         "span/LOC/f1": 0.08695652335882187,

From 1ebb9472a496dfad9c94765b26a0f1d795a1db07 Mon Sep 17 00:00:00 2001
From: Arne Binder
Date: Tue, 23 Jan 2024 18:43:01 +0100
Subject: [PATCH 2/2] exclude none_label for metrics from
 RETextClassificationWithIndicesTaskModule

---
 .../re_text_classification_with_indices.py | 13 ++++++++-----
 1 file changed, 8 insertions(+), 5 deletions(-)

diff --git a/src/pie_modules/taskmodules/re_text_classification_with_indices.py b/src/pie_modules/taskmodules/re_text_classification_with_indices.py
index ad29c23b1..2d9df03b5 100644
--- a/src/pie_modules/taskmodules/re_text_classification_with_indices.py
+++ b/src/pie_modules/taskmodules/re_text_classification_with_indices.py
@@ -979,15 +979,18 @@ def configure_model_metric(self, stage: str) -> Metric:
         )
         # we use the length of label_to_id because that contains the none_label (in contrast to labels)
         labels = [self.id_to_label[i] for i in range(len(self.label_to_id))]
-        num_classes = len(labels)
-        task = "multilabel" if self.multi_label else "multiclass"
+        common_metric_kwargs = {
+            "num_classes": len(labels),
+            "task": "multilabel" if self.multi_label else "multiclass",
+            "ignore_index": self.label_to_id[self.none_label],
+        }
         return WrappedMetricWithPrepareFunction(
             metric=MetricCollection(
                 {
-                    "micro/f1": F1Score(num_classes=num_classes, task=task, average="micro"),
-                    "macro/f1": F1Score(num_classes=num_classes, task=task, average="macro"),
+                    "micro/f1": F1Score(average="micro", **common_metric_kwargs),
+                    "macro/f1": F1Score(average="macro", **common_metric_kwargs),
                     "f1_per_label": ClasswiseWrapper(
-                        F1Score(num_classes=num_classes, task=task, average=None),
+                        F1Score(average=None, **common_metric_kwargs),
                         labels=labels,
                         postfix="/f1",
                     ),
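
Background on the metric change: both commits pass torchmetrics' `ignore_index`
argument to `F1Score`, which discards every position whose target equals the
given id before the F1 statistics are accumulated, so the typically dominant
"O" class (respectively the none_label) no longer contributes to the scores;
this is what shifts all of the expected values in the tests above. A minimal
sketch of the effect, assuming a hypothetical three-label tag set
(0 = "O", 1 = "B-PER", 2 = "I-PER") that is not taken from these patches:

    # Minimal sketch of the ignore_index behaviour, with a hypothetical
    # id mapping: 0 = "O", 1 = "B-PER", 2 = "I-PER".
    import torch
    from torchmetrics import F1Score

    preds = torch.tensor([0, 1, 2, 0, 1])
    target = torch.tensor([0, 1, 1, 0, 1])

    # Without ignore_index, the two correctly predicted "O" tokens count
    # as hits: micro F1 = 4/5 = 0.8.
    f1_with_o = F1Score(task="multiclass", num_classes=3, average="micro")

    # With ignore_index=0, positions whose target is "O" are dropped first,
    # so only the three entity tokens are scored: micro F1 = 2/3.
    f1_without_o = F1Score(
        task="multiclass", num_classes=3, average="micro", ignore_index=0
    )

    print(f1_with_o(preds, target))     # tensor(0.8000)
    print(f1_without_o(preds, target))  # tensor(0.6667)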