Skip to content

Commit

Permalink
'fix-bug'
Browse files Browse the repository at this point in the history
  • Loading branch information
superarthurlx committed Jan 2, 2025
2 parents f4ef802 + 489884e commit f4bd9f3
Show file tree
Hide file tree
Showing 3 changed files with 3 additions and 1 deletion.
2 changes: 1 addition & 1 deletion basicts/metrics/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@
from .rmse import masked_rmse
from .wape import masked_wape
from .smape import masked_smape
from .r_square import masked_r2
from .corr import masked_corr
from r_square import masked_r2

ALL_METRICS = {
'MAE': masked_mae,
Expand Down
1 change: 1 addition & 0 deletions basicts/metrics/r_square.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,4 +39,5 @@ def masked_r2(prediction: torch.Tensor, target: torch.Tensor, null_val: float =

# 计算 R^2
loss = 1 - (ss_res / (ss_tot + eps))
loss = torch.nan_to_num(loss) # Replace any NaNs in the loss with zero
return torch.mean(loss)
1 change: 1 addition & 0 deletions basicts/metrics/smape.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import numpy as np
import torch


def masked_smape(prediction: torch.Tensor, target: torch.Tensor, null_val: float = np.nan) -> torch.Tensor:
"""
Calculate the Masked Symmetric Mean Absolute Percentage Error (SMAPE) between predicted and target values,
Expand Down

0 comments on commit f4bd9f3

Please sign in to comment.