From 75eae85b7317b6a27de0100aeaeff068f6e1c6a1 Mon Sep 17 00:00:00 2001 From: Zezhi Shao <864453277@qq.com> Date: Wed, 11 Dec 2024 13:07:55 +0800 Subject: [PATCH 1/2] =?UTF-8?q?docs:=20=E2=9C=8F=EF=B8=8F=20fix=20lint?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- basicts/metrics/corr.py | 2 +- basicts/metrics/r_square.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/basicts/metrics/corr.py b/basicts/metrics/corr.py index 95cbdcd..4e10ee4 100644 --- a/basicts/metrics/corr.py +++ b/basicts/metrics/corr.py @@ -47,4 +47,4 @@ def masked_corr(prediction: torch.Tensor, target: torch.Tensor, null_val: float loss = loss * mask # Apply the mask to the loss loss = torch.nan_to_num(loss) # Replace any NaNs in the loss with zero - return torch.mean(loss) \ No newline at end of file + return torch.mean(loss) diff --git a/basicts/metrics/r_square.py b/basicts/metrics/r_square.py index 5034440..2d3ce51 100644 --- a/basicts/metrics/r_square.py +++ b/basicts/metrics/r_square.py @@ -36,9 +36,9 @@ def masked_r2(prediction: torch.Tensor, target: torch.Tensor, null_val: float = ss_res = torch.sum(torch.pow((target - prediction), 2), dim=1) # 残差平方和 ss_tot = torch.sum(torch.pow(target - torch.mean(target, dim=1, keepdim=True), 2), dim=1) # 总平方和 - + # 计算 R^2 loss = 1 - (ss_res / (ss_tot + eps)) - + loss = torch.nan_to_num(loss) # Replace any NaNs in the loss with zero return torch.mean(loss) From 59a0f95d8ac3639b649cd6de09eb2dc037bf4afa Mon Sep 17 00:00:00 2001 From: Zezhi Shao <864453277@qq.com> Date: Wed, 11 Dec 2024 13:13:29 +0800 Subject: [PATCH 2/2] =?UTF-8?q?docs:=20=E2=9C=8F=EF=B8=8F=20fix=20isort?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- basicts/metrics/__init__.py | 6 +++--- basicts/metrics/smape.py | 3 ++- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/basicts/metrics/__init__.py b/basicts/metrics/__init__.py index 4d88ba8..66e465b 100644 --- a/basicts/metrics/__init__.py +++ b/basicts/metrics/__init__.py @@ -1,11 +1,11 @@ +from .corr import masked_corr from .mae import masked_mae from .mape import masked_mape from .mse import masked_mse +from .r_square import masked_r2 from .rmse import masked_rmse -from .wape import masked_wape from .smape import masked_smape -from .r_square import masked_r2 -from .corr import masked_corr +from .wape import masked_wape ALL_METRICS = { 'MAE': masked_mae, diff --git a/basicts/metrics/smape.py b/basicts/metrics/smape.py index 48e4166..ebc2308 100644 --- a/basicts/metrics/smape.py +++ b/basicts/metrics/smape.py @@ -1,5 +1,6 @@ -import torch import numpy as np +import torch + def masked_smape(prediction: torch.Tensor, target: torch.Tensor, null_val: float = np.nan) -> torch.Tensor: """