Skip to content

Commit

Permalink
exclude none_label for metrics from RETextClassificationWithIndicesTa…
Browse files Browse the repository at this point in the history
…skModule
  • Loading branch information
ArneBinder committed Jan 31, 2024
1 parent aa384dd commit 1ebb947
Showing 1 changed file with 8 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -979,15 +979,18 @@ def configure_model_metric(self, stage: str) -> Metric:
)
# we use the length of label_to_id because that contains the none_label (in contrast to labels)
labels = [self.id_to_label[i] for i in range(len(self.label_to_id))]
num_classes = len(labels)
task = "multilabel" if self.multi_label else "multiclass"
common_metric_kwargs = {
"num_classes": len(labels),
"task": "multilabel" if self.multi_label else "multiclass",
"ignore_index": self.label_to_id[self.none_label],
}
return WrappedMetricWithPrepareFunction(
metric=MetricCollection(
{
"micro/f1": F1Score(num_classes=num_classes, task=task, average="micro"),
"macro/f1": F1Score(num_classes=num_classes, task=task, average="macro"),
"micro/f1": F1Score(average="micro", **common_metric_kwargs),
"macro/f1": F1Score(average="macro", **common_metric_kwargs),
"f1_per_label": ClasswiseWrapper(
F1Score(num_classes=num_classes, task=task, average=None),
F1Score(average=None, **common_metric_kwargs),
labels=labels,
postfix="/f1",
),
Expand Down

0 comments on commit 1ebb947

Please sign in to comment.