diff --git a/src/pie_modules/taskmodules/re_text_classification_with_indices.py b/src/pie_modules/taskmodules/re_text_classification_with_indices.py
index ad29c23b1..2d9df03b5 100644
--- a/src/pie_modules/taskmodules/re_text_classification_with_indices.py
+++ b/src/pie_modules/taskmodules/re_text_classification_with_indices.py
@@ -979,15 +979,18 @@ def configure_model_metric(self, stage: str) -> Metric:
         )
         # we use the length of label_to_id because that contains the none_label (in contrast to labels)
         labels = [self.id_to_label[i] for i in range(len(self.label_to_id))]
-        num_classes = len(labels)
-        task = "multilabel" if self.multi_label else "multiclass"
+        common_metric_kwargs = {
+            "num_classes": len(labels),
+            "task": "multilabel" if self.multi_label else "multiclass",
+            "ignore_index": self.label_to_id[self.none_label],
+        }
         return WrappedMetricWithPrepareFunction(
             metric=MetricCollection(
                 {
-                    "micro/f1": F1Score(num_classes=num_classes, task=task, average="micro"),
-                    "macro/f1": F1Score(num_classes=num_classes, task=task, average="macro"),
+                    "micro/f1": F1Score(average="micro", **common_metric_kwargs),
+                    "macro/f1": F1Score(average="macro", **common_metric_kwargs),
                     "f1_per_label": ClasswiseWrapper(
-                        F1Score(num_classes=num_classes, task=task, average=None),
+                        F1Score(average=None, **common_metric_kwargs),
                         labels=labels,
                         postfix="/f1",
                     ),
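
Behavioral note: the functional change here is the new "ignore_index" entry, which tells torchmetrics to drop samples whose target equals the none_label id, so the (typically dominant) no-relation class no longer inflates micro/macro F1. A minimal sketch of that semantics follows, assuming torchmetrics' documented ignore_index behavior; the label ids and tensors are made up for illustration and are not part of this PR.

import torch
from torchmetrics import F1Score

# Hypothetical setup: 3 classes where id 0 plays the role of the none_label.
none_label_id = 0

f1_ignoring_none = F1Score(
    task="multiclass", num_classes=3, average="micro", ignore_index=none_label_id
)
f1_plain = F1Score(task="multiclass", num_classes=3, average="micro")

preds = torch.tensor([0, 0, 1, 2, 2])
target = torch.tensor([0, 1, 1, 2, 0])

# With ignore_index, only the pairs whose target is 1 or 2 are counted
# (3 of 5 samples); without it, all 5 pairs enter the computation.
print(f1_ignoring_none(preds, target))  # tensor(0.6667) -> 2 of 3 kept samples correct
print(f1_plain(preds, target))          # tensor(0.6000) -> 3 of 5 samples correct

Bundling num_classes, task, and ignore_index into common_metric_kwargs also keeps the three F1Score instances in sync, so a future change (e.g. to the ignored label) only needs to touch one place.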