From 32290cae4640c56c058842d6e39e8cf124722937 Mon Sep 17 00:00:00 2001 From: Nicki Skafte Detlefsen Date: Mon, 16 Sep 2024 16:48:16 +0200 Subject: [PATCH] Fix corner case in `MatthewsCorrcoef` (#2743) * fix + test * changelog * Apply suggestions from code review --------- Co-authored-by: Jirka Borovec <6035284+Borda@users.noreply.github.com> --- CHANGELOG.md | 3 +++ .../functional/classification/matthews_corrcoef.py | 14 ++++++++------ .../classification/test_matthews_corrcoef.py | 6 ++++++ 3 files changed, 17 insertions(+), 6 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 4e49f86d671..1a9d9fc6764 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -63,6 +63,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Fixed flakiness in tests related to `torch.unique` with `dim=None` ([#2650](https://github.com/Lightning-AI/torchmetrics/pull/2650)) +- Fixed corner case in `MatthewsCorrCoef` ([#2743](https://github.com/Lightning-AI/torchmetrics/pull/2743)) + + ## [1.4.1] - 2024-08-02 ### Changed diff --git a/src/torchmetrics/functional/classification/matthews_corrcoef.py b/src/torchmetrics/functional/classification/matthews_corrcoef.py index 544414ee4a8..45e0238dae5 100644 --- a/src/torchmetrics/functional/classification/matthews_corrcoef.py +++ b/src/torchmetrics/functional/classification/matthews_corrcoef.py @@ -64,12 +64,14 @@ def _matthews_corrcoef_reduce(confmat: Tensor) -> Tensor: denom = cov_ypyp * cov_ytyt if denom == 0 and confmat.numel() == 4: - if tp == 0 or tn == 0: - a = tp + tn - - if fp == 0 or fn == 0: - b = fp + fn - + if fn == 0 and tn == 0: + a, b = tp, fp + elif fp == 0 and tn == 0: + a, b = tp, fn + elif tp == 0 and fn == 0: + a, b = tn, fp + elif tp == 0 and fp == 0: + a, b = tn, fn eps = torch.tensor(torch.finfo(torch.float32).eps, dtype=torch.float32, device=confmat.device) numerator = torch.sqrt(eps) * (a - b) denom = (tp + fp + eps) * (tp + fn + eps) * (tn + fp + eps) * (tn + fn + eps) diff --git a/tests/unittests/classification/test_matthews_corrcoef.py b/tests/unittests/classification/test_matthews_corrcoef.py index 03f649bc0ac..2f881604d09 100644 --- a/tests/unittests/classification/test_matthews_corrcoef.py +++ b/tests/unittests/classification/test_matthews_corrcoef.py @@ -331,6 +331,12 @@ def test_zero_case_in_multiclass(): torch.tensor([0, 0, 0, 0, 0, 1, 1, 1, 1, 1]), 0.0, ), + ( + binary_matthews_corrcoef, + torch.tensor([1, 1, 1, 1, 1, 0, 0, 0, 0, 0]), + torch.tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0]), + 0.0, + ), (binary_matthews_corrcoef, torch.zeros(10), torch.ones(10), -1.0), (binary_matthews_corrcoef, torch.ones(10), torch.zeros(10), -1.0), (