diff --git a/pylops/utils/metrics.py b/pylops/utils/metrics.py index e1e7a55c..6393dec2 100644 --- a/pylops/utils/metrics.py +++ b/pylops/utils/metrics.py @@ -5,10 +5,13 @@ "psnr", ] +from typing import Optional + import numpy as np +import numpy.typing as npt -def mae(xref, xcmp): +def mae(xref: npt.ArrayLike, xcmp: npt.ArrayLike) -> float: """Mean Absolute Error (MAE) Compute Mean Absolute Error between two vectors @@ -30,7 +33,7 @@ def mae(xref, xcmp): return mae -def mse(xref, xcmp): +def mse(xref: npt.ArrayLike, xcmp: npt.ArrayLike) -> float: """Mean Square Error (MSE) Compute Mean Square Error between two vectors @@ -52,7 +55,7 @@ def mse(xref, xcmp): return mse -def snr(xref, xcmp): +def snr(xref: npt.ArrayLike, xcmp: npt.ArrayLike) -> float: """Signal to Noise Ratio (SNR) Compute Signal to Noise Ratio between two vectors @@ -75,7 +78,12 @@ def snr(xref, xcmp): return snr -def psnr(xref, xcmp, xmax=None, xmin=0.0): +def psnr( + xref: npt.ArrayLike, + xcmp: npt.ArrayLike, + xmax: Optional[float] = None, + xmin: Optional[float] = 0.0, +) -> float: """Peak Signal to Noise Ratio (PSNR) Compute Peak Signal to Noise Ratio between two vectors