Commit
Merge branch 'master' into fix_pesq_nan
Borda authored Oct 8, 2024
2 parents 7a67f8f + 1e468f6 commit a596b31
Showing 8 changed files with 51 additions and 18 deletions.
2 changes: 1 addition & 1 deletion CHANGELOG.md
@@ -44,7 +44,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Fixed

--
+- Fixed Pearson correlation coefficient overwriting its inputs ([#2765](https://github.com/Lightning-AI/torchmetrics/pull/2765))


## [1.4.2] - 2024-09-12
25 changes: 16 additions & 9 deletions src/torchmetrics/detection/mean_ap.py
@@ -124,19 +124,23 @@ class MeanAveragePrecision(Metric):
    - ``map_dict``: A dictionary containing the following key-values:

-        - map: (:class:`~torch.Tensor`), global mean average precision
-        - map_small: (:class:`~torch.Tensor`), mean average precision for small objects
-        - map_medium: (:class:`~torch.Tensor`), mean average precision for medium objects
-        - map_large: (:class:`~torch.Tensor`), mean average precision for large objects
+        - map: (:class:`~torch.Tensor`), global mean average precision, which by default is defined as mAP50-95,
+          i.e. the mean average precision for IoU thresholds 0.50, 0.55, 0.60, ..., 0.95, averaged over all classes
+          and areas. If the IoU thresholds are changed, this value is calculated with the new thresholds.
+        - map_small: (:class:`~torch.Tensor`), mean average precision for small objects (area < 32^2 pixels)
+        - map_medium: (:class:`~torch.Tensor`), mean average precision for medium objects (32^2 pixels < area <
+          96^2 pixels)
+        - map_large: (:class:`~torch.Tensor`), mean average precision for large objects (area > 96^2 pixels)
        - mar_{mdt[0]}: (:class:`~torch.Tensor`), mean average recall for `max_detection_thresholds[0]` (default 1)
          detection per image
        - mar_{mdt[1]}: (:class:`~torch.Tensor`), mean average recall for `max_detection_thresholds[1]` (default 10)
          detections per image
        - mar_{mdt[2]}: (:class:`~torch.Tensor`), mean average recall for `max_detection_thresholds[2]` (default 100)
          detections per image
-        - mar_small: (:class:`~torch.Tensor`), mean average recall for small objects
-        - mar_medium: (:class:`~torch.Tensor`), mean average recall for medium objects
-        - mar_large: (:class:`~torch.Tensor`), mean average recall for large objects
+        - mar_small: (:class:`~torch.Tensor`), mean average recall for small objects (area < 32^2 pixels)
+        - mar_medium: (:class:`~torch.Tensor`), mean average recall for medium objects (32^2 pixels < area < 96^2
+          pixels)
+        - mar_large: (:class:`~torch.Tensor`), mean average recall for large objects (area > 96^2 pixels)
        - map_50: (:class:`~torch.Tensor`) (-1 if 0.5 not in the list of iou thresholds), mean average precision at
          IoU=0.50
        - map_75: (:class:`~torch.Tensor`) (-1 if 0.75 not in the list of iou thresholds), mean average precision at
@@ -150,8 +154,11 @@ class MeanAveragePrecision(Metric):
    For an example of how to use this metric, check the `torchmetrics mAP example`_.

    .. note::
-        ``map`` score is calculated with @[ IoU=self.iou_thresholds | area=all | max_dets=max_detection_thresholds ].
-        Caution: If the initialization parameters are changed, dictionary keys for mAR can change as well.
+        ``map`` score is calculated with @[ IoU=self.iou_thresholds | area=all | max_dets=max_detection_thresholds ],
+        i.e. the mean average precision for IoU thresholds 0.50, 0.55, 0.60, ..., 0.95, averaged over all classes,
+        all areas and all max detections per image. If the IoU thresholds are changed, this value is calculated with
+        the new thresholds. Caution: if the initialization parameters are changed, the dictionary keys for mAR can
+        change as well.

    .. note::
        This metric supports, at the moment, two different backends for the evaluation. The default backend is
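To see the documented keys in action, here is a minimal usage sketch (adapted from the standard torchmetrics docstring example, not part of this diff; it assumes the default `pycocotools` backend is installed):

    import torch
    from torchmetrics.detection import MeanAveragePrecision

    preds = [{
        "boxes": torch.tensor([[258.0, 41.0, 606.0, 285.0]]),
        "scores": torch.tensor([0.536]),
        "labels": torch.tensor([0]),
    }]
    target = [{
        "boxes": torch.tensor([[214.0, 41.0, 562.0, 285.0]]),
        "labels": torch.tensor([0]),
    }]

    # with the default IoU thresholds 0.50, 0.55, ..., 0.95, the "map" key is mAP50-95
    metric = MeanAveragePrecision(iou_type="bbox")
    metric.update(preds, target)
    result = metric.compute()
    print(result["map"], result["map_50"], result["mar_100"])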
3 changes: 2 additions & 1 deletion src/torchmetrics/functional/image/rmse_sw.py
@@ -104,7 +104,8 @@ def _rmse_sw_compute(
"""
rmse = rmse_val_sum / total_images if rmse_val_sum is not None else None
if rmse_map is not None:
rmse_map /= total_images
# prevent overwrite the inputs
rmse_map = rmse_map / total_images
return rmse, rmse_map


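This one-line change matters because `/=` is an in-place operation: `rmse_map` is a reference to the metric's accumulated state, so dividing in place corrupts that state for every later call. A standalone sketch of the failure mode (illustrative names, not code from this repository):

    import torch

    state = torch.tensor([10.0, 20.0])  # stand-in for an accumulated metric state
    total_images = 2

    def compute_inplace(rmse_map: torch.Tensor) -> torch.Tensor:
        rmse_map /= total_images  # mutates the caller's tensor
        return rmse_map

    def compute_safe(rmse_map: torch.Tensor) -> torch.Tensor:
        return rmse_map / total_images  # allocates a new tensor; state untouched

    compute_safe(state)
    print(state)     # tensor([10., 20.]) -- unchanged
    compute_inplace(state)
    print(state)     # tensor([ 5., 10.]) -- the "state" was silently halved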
2 changes: 2 additions & 0 deletions src/torchmetrics/functional/regression/concordance.py
@@ -27,6 +27,8 @@ def _concordance_corrcoef_compute(
) -> Tensor:
    """Compute the final concordance correlation coefficient based on accumulated statistics."""
    pearson = _pearson_corrcoef_compute(var_x, var_y, corr_xy, nb)
+    # normalize locally; _pearson_corrcoef_compute no longer scales its inputs in place
+    var_x = var_x / (nb - 1)
+    var_y = var_y / (nb - 1)
    return 2.0 * pearson * var_x.sqrt() * var_y.sqrt() / (var_x + var_y + (mean_x - mean_y) ** 2)


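The two added lines compute the sample variances locally before they enter the concordance formula rho_c = 2*rho*sigma_x*sigma_y / (sigma_x^2 + sigma_y^2 + (mu_x - mu_y)^2). A quick standalone check of that formula against the public API (a sketch, not part of this commit):

    import torch
    from torchmetrics.functional import concordance_corrcoef

    target = torch.randn(100)
    preds = 0.8 * target + 0.1 * torch.randn(100)

    # evaluate the definition directly with sample statistics
    rho = torch.corrcoef(torch.stack([preds, target]))[0, 1]
    sx, sy = preds.std(), target.std()
    ccc_manual = 2 * rho * sx * sy / (sx**2 + sy**2 + (preds.mean() - target.mean()) ** 2)

    assert torch.isclose(concordance_corrcoef(preds, target), ccc_manual, atol=1e-4)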
7 changes: 4 additions & 3 deletions src/torchmetrics/functional/regression/pearson.py
@@ -92,9 +92,10 @@ def _pearson_corrcoef_compute(
        nb: number of observations

    """
-    var_x /= nb - 1
-    var_y /= nb - 1
-    corr_xy /= nb - 1
+    # prevent overwriting the inputs, which are references to the metric states
+    var_x = var_x / (nb - 1)
+    var_y = var_y / (nb - 1)
+    corr_xy = corr_xy / (nb - 1)
    # if var_x, var_y is float16 and on cpu, make it bfloat16 as sqrt is not supported for float16
    # on cpu; remove this after https://github.com/pytorch/pytorch/issues/54774 is fixed
    if var_x.dtype == torch.float16 and var_x.device == torch.device("cpu"):
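Because `var_x`, `var_y` and `corr_xy` are the metric's state tensors, the old in-place `/=` corrupted the accumulated state whenever `compute()` was called between updates: later updates then added unscaled contributions to already-scaled states. A minimal sketch of that failure mode (the assertion holds with this fix):

    import torch
    from torchmetrics.regression import PearsonCorrCoef

    target = torch.randn(100)
    preds = target + 0.1 * torch.randn(100)

    # one-shot reference value
    reference = PearsonCorrCoef()(preds, target)

    # the same data in two batches, with a compute() call in between
    metric = PearsonCorrCoef()
    metric.update(preds[:50], target[:50])
    metric.compute()  # with the old in-place `/=` this scaled the states down
    metric.update(preds[50:], target[50:])
    assert torch.isclose(metric.compute(), reference)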
5 changes: 2 additions & 3 deletions src/torchmetrics/regression/r2.py
@@ -38,8 +38,8 @@ class R2Score(Metric):
    where the parameter :math:`k` (the number of independent regressors) should be provided as the `adjusted` argument.

    The score is only properly defined when :math:`SS_{tot}\neq 0`, which can fail for near-constant targets. In this
-    case a score of 0 is returned. By definition the score is bounded between 0 and 1, where 1 corresponds to the
-    predictions exactly matching the targets.
+    case a score of 0 is returned. By definition the score is bounded between :math:`-\infty` and 1.0, with 1.0
+    indicating perfect prediction, 0 indicating a constant prediction, and negative values indicating predictions
+    worse than a constant prediction.
As input to ``forward`` and ``update`` the metric accepts the following input:
@@ -99,7 +99,6 @@ class R2Score(Metric):
    is_differentiable: bool = True
    higher_is_better: bool = True
    full_state_update: bool = False
-    plot_lower_bound: float = 0.0
    plot_upper_bound: float = 1.0

sum_squared_error: Tensor
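A short sketch (not part of this commit) confirming the corrected bounds with the public functional API:

    import torch
    from torchmetrics.functional import r2_score

    target = torch.tensor([1.0, 2.0, 3.0, 4.0])

    # predicting the target mean everywhere gives SS_res == SS_tot, hence R^2 == 0
    print(r2_score(torch.full((4,), 2.5), target))  # tensor(0.)

    # predictions worse than the mean give a negative score, here 1 - 20/5 = -3
    print(r2_score(torch.tensor([4.0, 3.0, 2.0, 1.0]), target))  # tensor(-3.)

The removed `plot_lower_bound` is consistent with this: a score that can reach :math:`-\infty` has no fixed lower plotting bound.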
3 changes: 2 additions & 1 deletion src/torchmetrics/regression/symmetric_mape.py
@@ -41,7 +41,7 @@ class SymmetricMeanAbsolutePercentageError(Metric):
As output of ``forward`` and ``compute`` the metric returns the following output:
-    - ``smape`` (:class:`~torch.Tensor`): A tensor with non-negative floating point smape value between 0 and 1
+    - ``smape`` (:class:`~torch.Tensor`): A tensor with non-negative floating point SMAPE value between 0 and 2
Args:
kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.
@@ -60,6 +60,7 @@ class SymmetricMeanAbsolutePercentageError(Metric):
    higher_is_better: bool = False
    full_state_update: bool = False
    plot_lower_bound: float = 0.0
+    plot_upper_bound: float = 2.0

sum_abs_per_error: Tensor
total: Tensor
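The corrected upper bound of 2 is attained when predictions and targets have opposite signs, since each term 2|p - t| / (|p| + |t|) is then exactly 2. A quick sketch (not part of this commit):

    import torch
    from torchmetrics.functional import symmetric_mean_absolute_percentage_error

    target = torch.tensor([1.0, 2.0])
    preds = -target  # opposite sign everywhere -> worst case

    # each element contributes 2*|p - t| / (|p| + |t|) = 2, so the mean is 2
    print(symmetric_mean_absolute_percentage_error(preds, target))  # tensor(2.)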
22 changes: 22 additions & 0 deletions tests/unittests/regression/test_pearson.py
@@ -164,3 +164,25 @@ def test_single_sample_update():
    metric(torch.tensor([7.0]), torch.tensor([8.0]))
    res2 = metric.compute()
    assert torch.allclose(res1, res2)
+
+
+def test_overwrite_reference_inputs():
+    """Test that the normalization does not overwrite the inputs.
+
+    The variables var_x, var_y and corr_xy are references to the metric states; if they were scaled in place, a
+    later update followed by compute would produce badly wrong values.
+    """
+    y = torch.randn(100)
+    y_pred = y + torch.randn(y.shape) / 5
+    # compute the correlation in one shot as the reference value
+    pearson = PearsonCorrCoef()
+    correlation = pearson(y, y_pred)
+
+    # accumulate the same data in chunks, calling compute after every update
+    pearson = PearsonCorrCoef()
+    for lower, upper in [(0, 33), (33, 66), (66, 99), (99, 100)]:
+        pearson.update(y[lower:upper], y_pred[lower:upper])
+        pearson.compute()
+
+    assert torch.isclose(pearson.compute(), correlation)
