Skip to content

Commit

Permalink
Merge pull request #130 from ArneBinder/RE/micro_f1_without_tn
Browse files Browse the repository at this point in the history
add metric `micro/f1_without_tn` to `RETextClassificationWithIndicesTaskmodule`
  • Loading branch information
ArneBinder authored Oct 8, 2024
2 parents 218fc87 + 30f7e2b commit f02c9f7
Show file tree
Hide file tree
Showing 2 changed files with 104 additions and 65 deletions.
53 changes: 37 additions & 16 deletions src/pie_modules/taskmodules/re_text_classification_with_indices.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
"""

import logging
from collections import defaultdict
from functools import partial
from typing import (
Any,
Dict,
Expand All @@ -25,7 +25,6 @@
)

import numpy as np
import pandas as pd
import torch
from pytorch_ie.annotations import (
BinaryRelation,
Expand All @@ -50,7 +49,7 @@
from pytorch_ie.utils.span import has_overlap, is_contained_in
from pytorch_ie.utils.window import get_window_around_slice
from torch import LongTensor
from torchmetrics import ClasswiseWrapper, F1Score, Metric, MetricCollection
from torchmetrics import ClasswiseWrapper, F1Score, MetricCollection
from transformers import AutoTokenizer
from transformers.file_utils import PaddingStrategy
from transformers.tokenization_utils_base import TruncationStrategy
Expand Down Expand Up @@ -109,6 +108,15 @@ def _get_labels(model_output: ModelTargetType) -> LongTensor:
return model_output["labels"]


def _get_labels_together_remove_none_label(
predictions: ModelTargetType, targets: ModelTargetType, none_idx: int
) -> Tuple[LongTensor, LongTensor]:
mask_not_both_none = (predictions["labels"] != none_idx) | (targets["labels"] != none_idx)
predictions_not_none = predictions["labels"][mask_not_both_none]
targets_not_none = targets["labels"][mask_not_both_none]
return predictions_not_none, targets_not_none


def inner_span_distance(start_end: Tuple[int, int], other_start_end: Tuple[int, int]) -> int:
dist_start_other_end = abs(start_end[0] - other_start_end[1])
dist_end_other_start = abs(start_end[1] - other_start_end[0])
Expand Down Expand Up @@ -1066,7 +1074,7 @@ def collate(

return inputs, {"labels": targets}

def configure_model_metric(self, stage: str) -> Metric:
def configure_model_metric(self, stage: str) -> MetricCollection:
if self.label_to_id is None:
raise ValueError(
"The taskmodule has not been prepared yet, so label_to_id is not known. "
Expand All @@ -1079,17 +1087,30 @@ def configure_model_metric(self, stage: str) -> Metric:
"num_classes": len(labels),
"task": "multilabel" if self.multi_label else "multiclass",
}
return WrappedMetricWithPrepareFunction(
metric=MetricCollection(
{
"micro/f1": F1Score(average="micro", **common_metric_kwargs),
"macro/f1": F1Score(average="macro", **common_metric_kwargs),
"f1_per_label": ClasswiseWrapper(
F1Score(average=None, **common_metric_kwargs),
labels=labels,
postfix="/f1",
return MetricCollection(
{
"with_tn": WrappedMetricWithPrepareFunction(
metric=MetricCollection(
{
"micro/f1": F1Score(average="micro", **common_metric_kwargs),
"macro/f1": F1Score(average="macro", **common_metric_kwargs),
"f1_per_label": ClasswiseWrapper(
F1Score(average=None, **common_metric_kwargs),
labels=labels,
postfix="/f1",
),
}
),
}
),
prepare_function=_get_labels,
prepare_function=_get_labels,
),
# We can not easily calculate the macro f1 here, because
# F1Score with average="macro" would still include the none_label.
"micro/f1_without_tn": WrappedMetricWithPrepareFunction(
metric=F1Score(average="micro", **common_metric_kwargs),
prepare_together_function=partial(
_get_labels_together_remove_none_label,
none_idx=self.label_to_id[self.none_label],
),
),
}
)
116 changes: 67 additions & 49 deletions tests/taskmodules/test_re_text_classification_with_indices.py
Original file line number Diff line number Diff line change
Expand Up @@ -1717,18 +1717,22 @@ def test_configure_model_metric(documents, taskmodule):
assert isinstance(metric, (Metric, MetricCollection))
state = get_metric_state(metric)
assert state == {
"f1_per_label/tp": [0, 0, 0, 0],
"f1_per_label/fp": [0, 0, 0, 0],
"f1_per_label/tn": [0, 0, 0, 0],
"f1_per_label/fn": [0, 0, 0, 0],
"macro/f1/fn": [0, 0, 0, 0],
"macro/f1/fp": [0, 0, 0, 0],
"macro/f1/tn": [0, 0, 0, 0],
"macro/f1/tp": [0, 0, 0, 0],
"micro/f1/fn": [0],
"micro/f1/fp": [0],
"micro/f1/tn": [0],
"micro/f1/tp": [0],
"micro/f1_without_tn/tp": [0],
"micro/f1_without_tn/fp": [0],
"micro/f1_without_tn/tn": [0],
"micro/f1_without_tn/fn": [0],
"with_tn/f1_per_label/tp": [0, 0, 0, 0],
"with_tn/f1_per_label/fp": [0, 0, 0, 0],
"with_tn/f1_per_label/tn": [0, 0, 0, 0],
"with_tn/f1_per_label/fn": [0, 0, 0, 0],
"with_tn/macro/f1/tp": [0, 0, 0, 0],
"with_tn/macro/f1/fp": [0, 0, 0, 0],
"with_tn/macro/f1/tn": [0, 0, 0, 0],
"with_tn/macro/f1/fn": [0, 0, 0, 0],
"with_tn/micro/f1/tp": [0],
"with_tn/micro/f1/fp": [0],
"with_tn/micro/f1/tn": [0],
"with_tn/micro/f1/fn": [0],
}
assert metric.compute() == {
"no_relation/f1": tensor(0.0),
Expand All @@ -1737,24 +1741,29 @@ def test_configure_model_metric(documents, taskmodule):
"per:founder/f1": tensor(0.0),
"macro/f1": tensor(0.0),
"micro/f1": tensor(0.0),
"micro/f1_without_tn": tensor(0.0),
}

targets = batch[1]
metric.update(targets, targets)
state = get_metric_state(metric)
assert state == {
"f1_per_label/fn": [0, 0, 0, 0],
"f1_per_label/fp": [0, 0, 0, 0],
"f1_per_label/tn": [7, 5, 4, 5],
"f1_per_label/tp": [0, 2, 3, 2],
"macro/f1/fn": [0, 0, 0, 0],
"macro/f1/fp": [0, 0, 0, 0],
"macro/f1/tn": [7, 5, 4, 5],
"macro/f1/tp": [0, 2, 3, 2],
"micro/f1/fn": [0],
"micro/f1/fp": [0],
"micro/f1/tn": [21],
"micro/f1/tp": [7],
"micro/f1_without_tn/tp": [7],
"micro/f1_without_tn/fp": [0],
"micro/f1_without_tn/tn": [21],
"micro/f1_without_tn/fn": [0],
"with_tn/f1_per_label/tp": [0, 2, 3, 2],
"with_tn/f1_per_label/fp": [0, 0, 0, 0],
"with_tn/f1_per_label/tn": [7, 5, 4, 5],
"with_tn/f1_per_label/fn": [0, 0, 0, 0],
"with_tn/macro/f1/tp": [0, 2, 3, 2],
"with_tn/macro/f1/fp": [0, 0, 0, 0],
"with_tn/macro/f1/tn": [7, 5, 4, 5],
"with_tn/macro/f1/fn": [0, 0, 0, 0],
"with_tn/micro/f1/tp": [7],
"with_tn/micro/f1/fp": [0],
"with_tn/micro/f1/tn": [21],
"with_tn/micro/f1/fn": [0],
}
assert metric.compute() == {
"no_relation/f1": tensor(0.0),
Expand All @@ -1763,37 +1772,46 @@ def test_configure_model_metric(documents, taskmodule):
"per:founder/f1": tensor(1.0),
"macro/f1": tensor(1.0),
"micro/f1": tensor(1.0),
"micro/f1_without_tn": tensor(1.0),
}

metric.reset()
torch.testing.assert_close(targets["labels"], torch.tensor([2, 2, 3, 1, 2, 3, 1]))
# three matches
random_targets = {"labels": torch.tensor([1, 1, 3, 1, 2, 0, 0])}
metric.update(random_targets, targets)
modified_targets = {"labels": torch.tensor([2, 2, 3, 1, 2, 0, 1])}
# three positive matches and one true negative
random_predictions = {"labels": torch.tensor([1, 1, 3, 1, 2, 0, 0])}
metric.update(random_predictions, modified_targets)
state = get_metric_state(metric)
assert state == {
"f1_per_label/fn": [0, 1, 2, 1],
"f1_per_label/fp": [2, 2, 0, 0],
"f1_per_label/tn": [5, 3, 4, 5],
"f1_per_label/tp": [0, 1, 1, 1],
"macro/f1/fn": [0, 1, 2, 1],
"macro/f1/fp": [2, 2, 0, 0],
"macro/f1/tn": [5, 3, 4, 5],
"macro/f1/tp": [0, 1, 1, 1],
"micro/f1/fn": [4],
"micro/f1/fp": [4],
"micro/f1/tn": [17],
"micro/f1/tp": [3],
}
values = {k: v.item() for k, v in metric.compute().items()}
assert values == {
"macro/f1": 0.3916666507720947,
"micro/f1": 0.4285714328289032,
"no_relation/f1": 0.0,
"org:founded_by/f1": 0.4000000059604645,
"per:employee_of/f1": 0.5,
"per:founder/f1": 0.6666666865348816,
"micro/f1_without_tn/tp": [3],
"micro/f1_without_tn/fp": [3],
"micro/f1_without_tn/tn": [15],
"micro/f1_without_tn/fn": [3],
"with_tn/f1_per_label/tp": [1, 1, 1, 1],
"with_tn/f1_per_label/fp": [1, 2, 0, 0],
"with_tn/f1_per_label/tn": [5, 3, 4, 6],
"with_tn/f1_per_label/fn": [0, 1, 2, 0],
"with_tn/macro/f1/tp": [1, 1, 1, 1],
"with_tn/macro/f1/fp": [1, 2, 0, 0],
"with_tn/macro/f1/tn": [5, 3, 4, 6],
"with_tn/macro/f1/fn": [0, 1, 2, 0],
"with_tn/micro/f1/tp": [4],
"with_tn/micro/f1/fp": [3],
"with_tn/micro/f1/tn": [18],
"with_tn/micro/f1/fn": [3],
}
# created with torch.set_printoptions(precision=6)
torch.testing.assert_close(
metric.compute(),
{
"no_relation/f1": tensor(0.666667),
"org:founded_by/f1": tensor(0.400000),
"per:employee_of/f1": tensor(0.500000),
"per:founder/f1": tensor(1.0),
"macro/f1": tensor(0.641667),
"micro/f1": tensor(0.571429),
"micro/f1_without_tn": tensor(0.500000),
},
)

# ensure that the metric can be pickled
pickle.dumps(metric)

0 comments on commit f02c9f7

Please sign in to comment.