Skip to content

Commit

Permalink
Merge pull request #47 from ArneBinder/exclude_none_class_for_metrics
Browse files: browse the repository at this point in the history
Exclude none class for metrics (WIP)
  • Loading branch information
ArneBinder authored Jan 31, 2024
2 parents 5aee5b1 + 1ebb947 commit f1b9179
Show file tree
Hide file tree
Showing 5 changed files with 29 additions and 29 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -406,22 +406,19 @@ def create_annotations_from_output(
yield self.span_annotation, span.copy()

def configure_model_metric(self, stage: str) -> Union[Metric, MetricCollection]:
common_metric_kwargs = {
"num_classes": len(self.label_to_id),
"task": "multiclass",
"ignore_index": self.label_to_id["O"],
}
token_scores = MetricCollection(
{
"token/macro/f1": WrappedMetricWithPrepareFunction(
metric=F1Score(
num_classes=len(self.label_to_id),
task="multiclass",
average="macro",
),
metric=F1Score(average="macro", **common_metric_kwargs),
prepare_function=partial(remove_label_pad_ids, label_pad_id=self.label_pad_id),
),
"token/micro/f1": WrappedMetricWithPrepareFunction(
metric=F1Score(
num_classes=len(self.label_to_id),
task="multiclass",
average="micro",
),
metric=F1Score(average="micro", **common_metric_kwargs),
prepare_function=partial(remove_label_pad_ids, label_pad_id=self.label_pad_id),
),
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -979,15 +979,18 @@ def configure_model_metric(self, stage: str) -> Metric:
)
# We iterate over len(label_to_id) because label_to_id also contains the none_label (in contrast to labels, which does not).
labels = [self.id_to_label[i] for i in range(len(self.label_to_id))]
num_classes = len(labels)
task = "multilabel" if self.multi_label else "multiclass"
common_metric_kwargs = {
"num_classes": len(labels),
"task": "multilabel" if self.multi_label else "multiclass",
"ignore_index": self.label_to_id[self.none_label],
}
return WrappedMetricWithPrepareFunction(
metric=MetricCollection(
{
"micro/f1": F1Score(num_classes=num_classes, task=task, average="micro"),
"macro/f1": F1Score(num_classes=num_classes, task=task, average="macro"),
"micro/f1": F1Score(average="micro", **common_metric_kwargs),
"macro/f1": F1Score(average="macro", **common_metric_kwargs),
"f1_per_label": ClasswiseWrapper(
F1Score(num_classes=num_classes, task=task, average=None),
F1Score(average=None, **common_metric_kwargs),
labels=labels,
postfix="/f1",
),
Expand Down
8 changes: 4 additions & 4 deletions tests/models/test_simple_token_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -292,8 +292,8 @@ def test_validation_step_and_on_epoch_end(batch, model, config):
"span/micro/f1": 0.13793103396892548,
"span/micro/precision": 0.3333333432674408,
"span/micro/recall": 0.08695652335882187,
"token/macro/f1": 0.04210526496171951,
"token/micro/f1": 0.06896551698446274,
"token/macro/f1": 0.08888889104127884,
"token/micro/f1": 0.13333334028720856,
}

model.on_validation_epoch_end()
Expand All @@ -319,8 +319,8 @@ def test_test_step_and_on_epoch_end(batch, model, config):
"span/micro/f1": 0.13793103396892548,
"span/micro/precision": 0.3333333432674408,
"span/micro/recall": 0.08695652335882187,
"token/macro/f1": 0.04210526496171951,
"token/micro/f1": 0.06896551698446274,
"token/macro/f1": 0.08888889104127884,
"token/micro/f1": 0.13333334028720856,
}

model.on_test_epoch_end()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -353,8 +353,8 @@ def test_validation_step_and_on_epoch_end(batch, model, config):
if config == {}:
torch.testing.assert_close(loss, torch.tensor(59.42658996582031))
assert metric_values == {
"token/macro/f1": 0.19285714626312256,
"token/micro/f1": 0.27586206793785095,
"token/macro/f1": 0.3919413983821869,
"token/micro/f1": 0.5333333611488342,
"span/PER/f1": 0.0833333358168602,
"span/PER/recall": 0.0476190485060215,
"span/PER/precision": 0.3333333432674408,
Expand All @@ -371,8 +371,8 @@ def test_validation_step_and_on_epoch_end(batch, model, config):
elif config == {"use_crf": False}:
torch.testing.assert_close(loss, torch.tensor(1.6708829402923584))
assert metric_values == {
"token/macro/f1": 0.08615384995937347,
"token/micro/f1": 0.13793103396892548,
"token/macro/f1": 0.14374999701976776,
"token/micro/f1": 0.2666666805744171,
"span/PER/f1": 0.0,
"span/PER/recall": 0.0,
"span/PER/precision": 0.0,
Expand Down Expand Up @@ -401,8 +401,8 @@ def test_test_step_and_on_epoch_end(batch, model, config):
if config == {}:
torch.testing.assert_close(loss, torch.tensor(59.42658996582031))
assert metric_values == {
"token/macro/f1": 0.19285714626312256,
"token/micro/f1": 0.27586206793785095,
"token/macro/f1": 0.3919413983821869,
"token/micro/f1": 0.5333333611488342,
"span/ORG/f1": 0.0,
"span/ORG/recall": 0.0,
"span/ORG/precision": 0.0,
Expand All @@ -419,8 +419,8 @@ def test_test_step_and_on_epoch_end(batch, model, config):
elif config == {"use_crf": False}:
torch.testing.assert_close(loss, torch.tensor(1.6708829402923584))
assert metric_values == {
"token/macro/f1": 0.08615384995937347,
"token/micro/f1": 0.13793103396892548,
"token/macro/f1": 0.14374999701976776,
"token/micro/f1": 0.2666666805744171,
"span/ORG/f1": 0.0,
"span/ORG/recall": 0.0,
"span/ORG/precision": 0.0,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -787,8 +787,8 @@ def test_configure_model_metric(documents):
values = metric.compute()
values_converted = {k: v.item() for k, v in values.items()}
assert values_converted == {
"token/macro/f1": 0.5434783101081848,
"token/micro/f1": 0.5249999761581421,
"token/macro/f1": 0.6349206566810608,
"token/micro/f1": 0.625,
"span/LOC/recall": 0.0476190485060215,
"span/LOC/precision": 0.5,
"span/LOC/f1": 0.08695652335882187,
Expand Down

0 comments on commit f1b9179

Please sign in to comment.