Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Newmetric: NRMSE #2442

Open
wants to merge 37 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 17 commits
Commits
Show all changes
37 commits
Select commit Hold shift + click to select a range
aa1457f
init files
SkafteNicki Mar 9, 2024
58b7577
docs
SkafteNicki Mar 9, 2024
7933f61
requirements for testing
SkafteNicki Mar 9, 2024
ba8848a
changelog
SkafteNicki Mar 9, 2024
0bf7f56
add class interface
SkafteNicki Mar 9, 2024
b6c7011
add tests
SkafteNicki Mar 9, 2024
50e65a5
Merge branch 'master' into newmetric/nrmse
Borda Mar 28, 2024
1100cda
Merge branch 'master' into newmetric/nrmse
Borda Apr 10, 2024
8b962d0
Merge branch 'master' into newmetric/nrmse
SkafteNicki Apr 24, 2024
a95809b
Merge branch 'master' into newmetric/nrmse
Borda Apr 24, 2024
9a40d9b
Merge branch 'master' into newmetric/nrmse
SkafteNicki May 31, 2024
642dd27
Update NRMSE computation with normalization options
SkafteNicki May 31, 2024
b5a2e42
Merge branch 'master' into newmetric/nrmse
SkafteNicki May 31, 2024
84604ad
try fixing docs
SkafteNicki May 31, 2024
0271d91
fix mypy
SkafteNicki May 31, 2024
d279774
fix naming of file
SkafteNicki May 31, 2024
c58fe84
Merge branch 'master' into newmetric/nrmse
Borda Jun 2, 2024
d9b6783
Merge branch 'master' into newmetric/nrmse
Borda Jun 4, 2024
42ffbea
Apply suggestions from code review
Borda Jun 5, 2024
2163188
Merge branch 'master' into newmetric/nrmse
Borda Jul 16, 2024
b5a0e22
Merge branch 'master' into newmetric/nrmse
Borda Jul 16, 2024
5be2218
Merge branch 'master' into newmetric/nrmse
Borda Jul 22, 2024
f194e0f
Merge branch 'master' into newmetric/nrmse
Borda Jul 24, 2024
a863b07
Merge branch 'master' into newmetric/nrmse
Borda Aug 5, 2024
8ec09c9
Merge branch 'master' into newmetric/nrmse
Borda Sep 9, 2024
0693545
add l2 option
SkafteNicki Oct 11, 2024
5267245
added tests for argument error validation
SkafteNicki Oct 11, 2024
c9c8350
Merge branch 'master' into newmetric/nrmse
SkafteNicki Oct 11, 2024
b967aa0
fix doctest
SkafteNicki Oct 11, 2024
4d57398
fix plotting code + test
SkafteNicki Oct 11, 2024
2d26828
fix part of tests
SkafteNicki Oct 11, 2024
6f7c821
fix implementation
SkafteNicki Oct 11, 2024
83e2f77
fix doctest
SkafteNicki Oct 11, 2024
c2714b1
Merge branch 'master' into newmetric/nrmse
SkafteNicki Oct 11, 2024
01d19b6
skip failing tests
SkafteNicki Oct 11, 2024
fae8898
Merge branch 'newmetric/nrmse' of https://github.com/Lightning-AI/tor…
SkafteNicki Oct 11, 2024
b7a116d
fix ddp testing
SkafteNicki Oct 11, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Added `input_format` argument to segmentation metrics ([#2572](https://github.com/Lightning-AI/torchmetrics/pull/2572))


- Added `NormalizedRootMeanSquaredError` metric to regression subpackage ([#2442](https://github.com/Lightning-AI/torchmetrics/pull/2442))


### Changed

-
Expand Down
1 change: 1 addition & 0 deletions docs/source/links.rst
Original file line number Diff line number Diff line change
Expand Up @@ -170,4 +170,5 @@
.. _FLORES-200: https://arxiv.org/abs/2207.04672
.. _averaging curve objects: https://scikit-learn.org/stable/auto_examples/model_selection/plot_roc.html
.. _SCC: https://www.ingentaconnect.com/content/tandf/tres/1998/00000019/00000004/art00013
.. _Normalized Root Mean Squared Error: https://onlinelibrary.wiley.com/doi/abs/10.1111/1365-2478.12109
.. _Generalized Dice Score: https://arxiv.org/abs/1707.03237
21 changes: 21 additions & 0 deletions docs/source/regression/normalized_root_mean_squared_error.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
.. customcarditem::
:header: Normalized Root Mean Squared Error (NRMSE)
:image: https://pl-flash-data.s3.amazonaws.com/assets/thumbnails/tabular_classification.svg
:tags: Regression

.. include:: ../links.rst

##########################################
Normalized Root Mean Squared Error (NRMSE)
##########################################

Module Interface
________________

.. autoclass:: torchmetrics.NormalizedRootMeanSquaredError
:exclude-members: update, compute

Functional Interface
____________________

.. autofunction:: torchmetrics.functional.normalized_root_mean_squared_error
1 change: 1 addition & 0 deletions requirements/_devel.txt
Original file line number Diff line number Diff line change
Expand Up @@ -20,3 +20,4 @@
-r classification_test.txt
-r nominal_test.txt
-r segmentation_test.txt
-r regression_test.txt
1 change: 1 addition & 0 deletions requirements/regression_test.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
permetrics==2.0.0
36 changes: 19 additions & 17 deletions src/torchmetrics/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,7 @@
MeanSquaredError,
MeanSquaredLogError,
MinkowskiDistance,
NormalizedRootMeanSquaredError,
PearsonCorrCoef,
R2Score,
RelativeSquaredError,
Expand Down Expand Up @@ -151,25 +152,23 @@
)

__all__ = [
"functional",
"Accuracy",
"AUROC",
"Accuracy",
"AveragePrecision",
"BLEUScore",
"BootStrapper",
"CHRFScore",
"CalibrationError",
"CatMetric",
"ClasswiseWrapper",
"CharErrorRate",
"CHRFScore",
"ConcordanceCorrCoef",
"ClasswiseWrapper",
"CohenKappa",
"ConcordanceCorrCoef",
"ConfusionMatrix",
"CosineSimilarity",
"CramersV",
"CriticalSuccessIndex",
"Dice",
"TweedieDevianceScore",
"ErrorRelativeGlobalDimensionlessSynthesis",
"ExactMatch",
"ExplainedVariance",
Expand All @@ -180,8 +179,8 @@
"HammingDistance",
"HingeLoss",
"JaccardIndex",
"KendallRankCorrCoef",
"KLDivergence",
"KendallRankCorrCoef",
"LogCoshError",
"MatchErrorRate",
"MatthewsCorrCoef",
Expand All @@ -194,23 +193,25 @@
"Metric",
"MetricCollection",
"MetricTracker",
"MinkowskiDistance",
"MinMaxMetric",
"MinMetric",
"MinkowskiDistance",
"ModifiedPanopticQuality",
"MultiScaleStructuralSimilarityIndexMeasure",
"MultioutputWrapper",
"MultitaskWrapper",
"MultiScaleStructuralSimilarityIndexMeasure",
"NormalizedRootMeanSquaredError",
"PanopticQuality",
"PeakSignalNoiseRatio",
"PearsonCorrCoef",
"PearsonsContingencyCoefficient",
"PermutationInvariantTraining",
"Perplexity",
"Precision",
"PrecisionAtFixedRecall",
"PrecisionRecallCurve",
"PeakSignalNoiseRatio",
"R2Score",
"ROC",
"Recall",
"RecallAtFixedPrecision",
"RelativeAverageSpectralError",
Expand All @@ -221,37 +222,38 @@
"RetrievalMRR",
"RetrievalNormalizedDCG",
"RetrievalPrecision",
"RetrievalRecall",
"RetrievalRPrecision",
"RetrievalPrecisionRecallCurve",
"RetrievalRPrecision",
"RetrievalRecall",
"RetrievalRecallAtFixedPrecision",
"ROC",
"RootMeanSquaredErrorUsingSlidingWindow",
"RunningMean",
"RunningSum",
"SQuAD",
"SacreBLEUScore",
"SignalDistortionRatio",
"ScaleInvariantSignalDistortionRatio",
"ScaleInvariantSignalNoiseRatio",
"SensitivityAtSpecificity",
"SignalDistortionRatio",
"SignalNoiseRatio",
"SpearmanCorrCoef",
"Specificity",
"SpecificityAtSensitivity",
"SensitivityAtSpecificity",
"SpectralAngleMapper",
"SpectralDistortionIndex",
"SQuAD",
"StructuralSimilarityIndexMeasure",
"StatScores",
"StructuralSimilarityIndexMeasure",
"SumMetric",
"SymmetricMeanAbsolutePercentageError",
"TheilsU",
"TotalVariation",
"TranslationEditRate",
"TschuprowsT",
"TweedieDevianceScore",
"UniversalImageQualityIndex",
"WeightedMeanAbsolutePercentageError",
"WordErrorRate",
"WordInfoLost",
"WordInfoPreserved",
"functional",
]
24 changes: 13 additions & 11 deletions src/torchmetrics/functional/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,7 @@
mean_squared_error,
mean_squared_log_error,
minkowski_distance,
normalized_root_mean_squared_error,
pearson_corrcoef,
r2_score,
relative_squared_error,
Expand Down Expand Up @@ -146,14 +147,13 @@
"calibration_error",
"char_error_rate",
"chrf_score",
"concordance_corrcoef",
"cohen_kappa",
"concordance_corrcoef",
"confusion_matrix",
"cosine_similarity",
"cramers_v",
"cramers_v_matrix",
"critical_success_index",
"tweedie_deviance_score",
"dice",
"error_relative_global_dimensionless_synthesis",
"exact_match",
Expand All @@ -177,63 +177,65 @@
"mean_squared_log_error",
"minkowski_distance",
"multiscale_structural_similarity_index_measure",
"normalized_root_mean_squared_error",
"pairwise_cosine_similarity",
"pairwise_euclidean_distance",
"pairwise_linear_similarity",
"pairwise_manhattan_distance",
"pairwise_minkowski_distance",
"panoptic_quality",
"peak_signal_noise_ratio",
"pearson_corrcoef",
"pearsons_contingency_coefficient",
"pearsons_contingency_coefficient_matrix",
"permutation_invariant_training",
"perplexity",
"pit_permutate",
"precision",
"precision_at_fixed_recall",
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

shall this be another PR?

"precision_recall_curve",
"peak_signal_noise_ratio",
"r2_score",
"recall",
"recall_at_fixed_precision",
"relative_average_spectral_error",
"relative_squared_error",
"retrieval_average_precision",
"retrieval_fall_out",
"retrieval_hit_rate",
"retrieval_normalized_dcg",
"retrieval_precision",
"retrieval_precision_recall_curve",
"retrieval_r_precision",
"retrieval_recall",
"retrieval_reciprocal_rank",
"retrieval_precision_recall_curve",
"roc",
"root_mean_squared_error_using_sliding_window",
"rouge_score",
"sacre_bleu_score",
"signal_distortion_ratio",
"scale_invariant_signal_distortion_ratio",
"scale_invariant_signal_noise_ratio",
"sensitivity_at_specificity",
"signal_distortion_ratio",
"signal_noise_ratio",
"spearman_corrcoef",
"specificity",
"specificity_at_sensitivity",
"spectral_angle_mapper",
"spectral_distortion_index",
"squad",
"structural_similarity_index_measure",
"stat_scores",
"structural_similarity_index_measure",
"symmetric_mean_absolute_percentage_error",
"theils_u",
"theils_u_matrix",
"total_variation",
"translation_edit_rate",
"tschuprows_t",
"tschuprows_t_matrix",
"tweedie_deviance_score",
"universal_image_quality_index",
"spectral_angle_mapper",
"weighted_mean_absolute_percentage_error",
"word_error_rate",
"word_information_lost",
"word_information_preserved",
"precision_at_fixed_recall",
"recall_at_fixed_precision",
"sensitivity_at_specificity",
"specificity_at_sensitivity",
]
8 changes: 5 additions & 3 deletions src/torchmetrics/functional/regression/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from torchmetrics.functional.regression.mape import mean_absolute_percentage_error
from torchmetrics.functional.regression.minkowski import minkowski_distance
from torchmetrics.functional.regression.mse import mean_squared_error
from torchmetrics.functional.regression.nrmse import normalized_root_mean_squared_error
from torchmetrics.functional.regression.pearson import pearson_corrcoef
from torchmetrics.functional.regression.r2 import r2_score
from torchmetrics.functional.regression.rse import relative_squared_error
Expand All @@ -39,13 +40,14 @@
"kendall_rank_corrcoef",
"kl_divergence",
"log_cosh_error",
"mean_squared_log_error",
"mean_absolute_error",
"mean_squared_error",
"pearson_corrcoef",
"mean_absolute_percentage_error",
"mean_absolute_percentage_error",
"mean_squared_error",
"mean_squared_log_error",
"minkowski_distance",
"normalized_root_mean_squared_error",
"pearson_corrcoef",
"r2_score",
"relative_squared_error",
"spearman_corrcoef",
Expand Down
100 changes: 100 additions & 0 deletions src/torchmetrics/functional/regression/nrmse.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
# Copyright The Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Tuple, Union

import torch
from torch import Tensor
from typing_extensions import Literal

from torchmetrics.functional.regression.mse import _mean_squared_error_update


def _normalized_root_mean_squared_error_update(
preds: Tensor, target: Tensor, num_outputs: int, normalization: Literal["mean", "range", "std"] = "mean"
) -> Tuple[Tensor, int, Tensor]:
"""Updates and returns the sum of squared errors and the number of observations for NRMSE computation.

Args:
preds: Predicted tensor
target: Ground truth tensor
num_outputs: Number of outputs in multioutput setting
normalization: type of normalization to be applied. Choose from "mean", "range", "std"

"""
sum_squared_error, num_obs = _mean_squared_error_update(preds, target, num_outputs)

target = target.view(-1) if num_outputs == 1 else target
if normalization == "mean":
SkafteNicki marked this conversation as resolved.
Show resolved Hide resolved
denom = torch.mean(target, dim=0)
elif normalization == "range":
denom = torch.max(target, dim=0).values - torch.min(target, dim=0).values
elif normalization == "std":
denom = torch.std(target, correction=0, dim=0)
else:
raise ValueError(f"Argument `normalization` should be either 'mean', 'range' or 'std', but got {normalization}")
return sum_squared_error, num_obs, denom


def _normalized_root_mean_squared_error_compute(
sum_squared_error: Tensor, num_obs: Union[int, Tensor], denom: Tensor
) -> Tensor:
"""Calculates RMSE and normalizes it."""
rmse = torch.sqrt(sum_squared_error / num_obs)
return rmse / denom


def normalized_root_mean_squared_error(
preds: Tensor,
target: Tensor,
normalization: Literal["mean", "range", "std"] = "mean",
num_outputs: int = 1,
) -> Tensor:
"""Calculates the `Normalized Root Mean Squared Error`_ (NRMSE) also know as scatter index.

Args:
preds: estimated labels
target: ground truth labels
normalization: type of normalization to be applied. Choose from "mean", "range", "std" which corresponds to
normalizing the RMSE by the mean of the target, the range of the target or the standard deviation of the
target.
num_outputs: Number of outputs in multioutput setting

Return:
Tensor with the NRMSE score

Example:
>>> import torch
>>> from torchmetrics.functional.regression import normalized_root_mean_squared_error
>>> preds = torch.tensor([0., 1, 2, 3])
>>> target = torch.tensor([0., 1, 2, 2])
>>> normalized_root_mean_squared_error(preds, target, normalization="mean")
tensor(0.4000)
>>> normalized_root_mean_squared_error(preds, target, normalization="range")
tensor(0.2500)
>>> normalized_root_mean_squared_error(preds, target, normalization="std")
tensor(0.6030)

Example (multioutput):
>>> import torch
>>> from torchmetrics.functional.regression import normalized_root_mean_squared_error
>>> preds = torch.tensor([[0., 1], [2, 3], [4, 5], [6, 7]])
>>> target = torch.tensor([[0., 1], [3, 3], [4, 5], [8, 9]])
>>> normalized_root_mean_squared_error(preds, target, normalization="mean", num_outputs=2)
tensor([0.2981, 0.2222])

"""
sum_squared_error, num_obs, denom = _normalized_root_mean_squared_error_update(
preds, target, num_outputs=num_outputs, normalization=normalization
)
return _normalized_root_mean_squared_error_compute(sum_squared_error, num_obs, denom)
Loading
Loading