Skip to content

Commit

Permalink
docs: ✏️ update the loss function of STWave
Browse files Browse the repository at this point in the history
  • Loading branch information
zezhishao committed Dec 4, 2023
1 parent 13822ca commit 88e0bf2
Show file tree
Hide file tree
Showing 9 changed files with 29 additions and 24 deletions.
2 changes: 1 addition & 1 deletion baselines/STWave/METR-LA.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,10 @@
from easydict import EasyDict
from basicts.runners import SimpleTimeSeriesForecastingRunner
from basicts.data import TimeSeriesForecastingDataset
from basicts.losses import stwave_masked_mae
from basicts.utils import load_adj

from .arch import STWave
from .loss import stwave_masked_mae

def laplacian(W):
"""Return the Laplacian of the weight matrix."""
Expand Down
2 changes: 1 addition & 1 deletion baselines/STWave/PEMS-BAY.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,10 @@
from easydict import EasyDict
from basicts.runners import SimpleTimeSeriesForecastingRunner
from basicts.data import TimeSeriesForecastingDataset
from basicts.losses import stwave_masked_mae
from basicts.utils import load_adj

from .arch import STWave
from .loss import stwave_masked_mae

def laplacian(W):
"""Return the Laplacian of the weight matrix."""
Expand Down
2 changes: 1 addition & 1 deletion baselines/STWave/PEMS03.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,10 @@
from easydict import EasyDict
from basicts.runners import SimpleTimeSeriesForecastingRunner
from basicts.data import TimeSeriesForecastingDataset
from basicts.losses import stwave_masked_mae
from basicts.utils import load_adj

from .arch import STWave
from .loss import stwave_masked_mae

def laplacian(W):
"""Return the Laplacian of the weight matrix."""
Expand Down
2 changes: 1 addition & 1 deletion baselines/STWave/PEMS04.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,10 @@
from easydict import EasyDict
from basicts.runners import SimpleTimeSeriesForecastingRunner
from basicts.data import TimeSeriesForecastingDataset
from basicts.losses import stwave_masked_mae
from basicts.utils import load_adj

from .arch import STWave
from .loss import stwave_masked_mae

def laplacian(W):
"""Return the Laplacian of the weight matrix."""
Expand Down
2 changes: 1 addition & 1 deletion baselines/STWave/PEMS07.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,10 @@
from easydict import EasyDict
from basicts.runners import SimpleTimeSeriesForecastingRunner
from basicts.data import TimeSeriesForecastingDataset
from basicts.losses import stwave_masked_mae
from basicts.utils import load_adj

from .arch import STWave
from .loss import stwave_masked_mae

def laplacian(W):
"""Return the Laplacian of the weight matrix."""
Expand Down
2 changes: 1 addition & 1 deletion baselines/STWave/PEMS08.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,10 @@
from easydict import EasyDict
from basicts.runners import SimpleTimeSeriesForecastingRunner
from basicts.data import TimeSeriesForecastingDataset
from basicts.losses import stwave_masked_mae
from basicts.utils import load_adj

from .arch import STWave
from .loss import stwave_masked_mae

def laplacian(W):
"""Return the Laplacian of the weight matrix."""
Expand Down
21 changes: 21 additions & 0 deletions baselines/STWave/loss.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
import torch
import numpy as np

from basicts.losses import masked_mae


def stwave_masked_mae(preds: list, labels: torch.Tensor, null_val: float = np.nan) -> torch.Tensor:
"""Masked mean absolute error.
Args:
preds (torch.Tensor): predicted values
labels (torch.Tensor): labels
null_val (float, optional): null value. Defaults to np.nan.
Returns:
torch.Tensor: masked mean absolute error
"""
lloss = masked_mae(preds[...,1:2], preds[...,2:])
loss = masked_mae(preds[...,:1], labels)

return loss + lloss
4 changes: 2 additions & 2 deletions basicts/losses/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
from .losses import l1_loss, l2_loss, masked_mae, masked_mape, masked_rmse, masked_mse, stwave_masked_mae
from .losses import l1_loss, l2_loss, masked_mae, masked_mape, masked_rmse, masked_mse

__all__ = ["l1_loss", "l2_loss", "masked_mae", "masked_mape", "masked_rmse", "masked_mse", "stwave_masked_mae"]
__all__ = ["l1_loss", "l2_loss", "masked_mae", "masked_mape", "masked_rmse", "masked_mse"]
16 changes: 0 additions & 16 deletions basicts/losses/losses.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,22 +44,6 @@ def masked_mae(preds: torch.Tensor, labels: torch.Tensor, null_val: float = np.n
loss = torch.where(torch.isnan(loss), torch.zeros_like(loss), loss)
return torch.mean(loss)

def stwave_masked_mae(preds: list, labels: torch.Tensor, null_val: float = np.nan) -> torch.Tensor:
"""Masked mean absolute error.
Args:
preds (torch.Tensor): predicted values
labels (torch.Tensor): labels
null_val (float, optional): null value. Defaults to np.nan.
Returns:
torch.Tensor: masked mean absolute error
"""
lloss = masked_mae(preds[...,1:2], preds[...,2:])
loss = masked_mae(preds[...,:1], labels)

return loss + lloss


def masked_mse(preds: torch.Tensor, labels: torch.Tensor, null_val: float = np.nan) -> torch.Tensor:
"""Masked mean squared error.
Expand Down

0 comments on commit 88e0bf2

Please sign in to comment.