Commit

feat: 🎸 Compatible with BasicTS 0.3.5

zezhishao committed Dec 10, 2023
1 parent 293809a commit 9129716
Showing 29 changed files with 118 additions and 344 deletions.
17 changes: 8 additions & 9 deletions baselines/DGCRN/runner/dgcrn_runner.py
@@ -16,7 +16,7 @@ def forward(self, data: tuple, epoch: int = None, iter_num: int = None, train: b
train (bool, optional): if in the training process. Defaults to True.
Returns:
tuple: (prediction, real_value)
dict: keys that must be included: inputs, prediction, target
"""

# preprocess
@@ -35,12 +35,11 @@ def forward(self, data: tuple, epoch: int = None, iter_num: int = None, train: b

# customized curriculum learning
task_level = self.curriculum_learning(epoch)
prediction_data = self.model(history_data=history_data, future_data=future_data_4_dec, batch_seen=iter_num, epoch=epoch, train=train, task_level=task_level)
# feed forward
assert list(prediction_data.shape)[:3] == [batch_size, length, num_nodes], \
model_return = self.model(history_data=history_data, future_data=future_data_4_dec, batch_seen=iter_num, epoch=epoch, train=train, task_level=task_level)
# parse model return
if isinstance(model_return, torch.Tensor): model_return = {"prediction": model_return}
model_return["inputs"] = self.select_target_features(history_data)
model_return["target"] = self.select_target_features(future_data)
assert list(model_return["prediction"].shape)[:3] == [batch_size, length, num_nodes], \
"error shape of the output, edit the forward function to reshape it to [B, L, N, C]"
# post process
prediction = self.select_target_features(prediction_data)
real_value = self.select_target_features(future_data)
return prediction, real_value

return model_return
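This hunk captures the core of the BasicTS 0.3.5 migration: the runner's forward no longer returns a (prediction, real_value) tuple but a dict with at least the keys inputs, prediction, and target, and a model that still returns a bare tensor is wrapped on the fly. A minimal sketch of that parsing step, not the library's own helper; parse_model_return and the tensor arguments are hypothetical names used only for illustration:

import torch

def parse_model_return(raw_output, history_inputs, future_target):
    # Wrap a legacy tensor-only return so downstream code can rely on dict keys.
    if isinstance(raw_output, torch.Tensor):
        raw_output = {"prediction": raw_output}
    raw_output["inputs"] = history_inputs           # what the model was fed
    raw_output.setdefault("target", future_target)  # keep a model-supplied target if present
    return raw_output

A bare tensor from an older model thus becomes {"prediction": ..., "inputs": ..., "target": ...}, so losses and metrics can be driven by key names instead of positional unpacking.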
5 changes: 2 additions & 3 deletions baselines/DeepAR/arch/deepar.py
@@ -73,8 +73,6 @@ def forward(self, history_data: torch.Tensor, future_data: torch.Tensor, train:
for t in range(1, len_in + len_out):
if not (t > len_in and not train): # not in the decoding stage when inferencing
history_next = input_feat_full[:, t-1:t, :, 0:1]
else:
a = 1
embed_feat = self.input_embed(history_next)
covar_feat = covar_feat_full[:, t:t+1, :, :]
if self.use_ts_id:
@@ -98,4 +96,5 @@ def forward(self, history_data: torch.Tensor, future_data: torch.Tensor, train:
mus = torch.concat(mus, dim=1)
sigmas = torch.concat(sigmas, dim=1)
reals = input_feat_full[:, -preds.shape[1]:, :, :]
return preds, reals, mus, sigmas

return {"prediction": preds, "target": reals, "mus": mus, "sigmas": sigmas}
18 changes: 9 additions & 9 deletions baselines/DeepAR/loss/gaussian.py
@@ -2,28 +2,28 @@
import numpy as np


def gaussian_loss(prediction, real_value, mu, sigma, null_val = np.nan):
"""Masked gaussian loss. Kindly note that the gaussian loss is calculated based on mu, sigma, and real_value. The prediction is sampled from N(mu, sigma), and is not used in the loss calculation (it will be used in the metrics calculation).
def gaussian_loss(prediction, target, mus, sigmas, null_val = np.nan):
"""Masked gaussian loss. Kindly note that the gaussian loss is calculated based on mu, sigma, and target. The prediction is sampled from N(mu, sigma), and is not used in the loss calculation (it will be used in the metrics calculation).
Args:
prediction (torch.Tensor): prediction of model. [B, L, N, 1].
real_value (torch.Tensor): ground truth. [B, L, N, 1].
mu (torch.Tensor): the mean of gaussian distribution. [B, L, N, 1].
sigma (torch.Tensor): the std of gaussian distribution. [B, L, N, 1]
target (torch.Tensor): ground truth. [B, L, N, 1].
mus (torch.Tensor): the mean of gaussian distribution. [B, L, N, 1].
sigmas (torch.Tensor): the std of gaussian distribution. [B, L, N, 1]
null_val (optional): null value. Defaults to np.nan.
"""
# mask
if np.isnan(null_val):
mask = ~torch.isnan(real_value)
mask = ~torch.isnan(target)
else:
eps = 5e-5
mask = ~torch.isclose(real_value, torch.tensor(null_val).expand_as(real_value).to(real_value.device), atol=eps, rtol=0.)
mask = ~torch.isclose(target, torch.tensor(null_val).expand_as(target).to(target.device), atol=eps, rtol=0.)
mask = mask.float()
mask /= torch.mean((mask))
mask = torch.where(torch.isnan(mask), torch.zeros_like(mask), mask)

distribution = torch.distributions.Normal(mu, sigma)
likelihood = distribution.log_prob(real_value)
distribution = torch.distributions.Normal(mus, sigmas)
likelihood = distribution.log_prob(target)
likelihood = likelihood * mask
loss_g = -torch.mean(likelihood)
return loss_g
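For reference, the quantity gaussian_loss computes (ignoring the NaN guard on the mask) is the masked Gaussian negative log-likelihood; the notation below is mine, with $i$ ranging over all $[B, L, N, 1]$ entries:

$$ \mathcal{L} = -\,\frac{\sum_i m_i \,\log \mathcal{N}\!\left(y_i;\,\mu_i,\,\sigma_i^{2}\right)}{\sum_i m_i}, $$

where $m_i \in \{0,1\}$ masks out null values, $y$ is the target, and $\mu$, $\sigma$ are the per-point mean and standard deviation; the sampled prediction enters the metrics but not this loss.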
68 changes: 38 additions & 30 deletions baselines/DeepAR/runner/deepar_runner.py
@@ -1,4 +1,4 @@
from typing import List
from typing import Dict
import torch
from basicts.data.registry import SCALER_REGISTRY
from easytorch.utils.dist import master_only
@@ -42,22 +42,25 @@ def select_target_features(self, data: torch.Tensor) -> torch.Tensor:
data = data[:, :, :, self.target_features]
return data

def rescale_data(self, input_data: List[torch.Tensor]) -> List[torch.Tensor]:
def rescale_data(self, input_data: Dict) -> Dict:
"""Rescale data.
Args:
data (List[torch.Tensor]): list of data to be re-scaled.
data (Dict): Dict of data to be re-scaled.
Returns:
List[torch.Tensor]: list of re-scaled data.
Dict: Dict of re-scaled data.
"""
prediction, real_value, mus, sigmas = input_data

if self.if_rescale:
prediction = SCALER_REGISTRY.get(self.scaler["func"])(prediction, **self.scaler["args"])
real_value = SCALER_REGISTRY.get(self.scaler["func"])(real_value, **self.scaler["args"])
mus = SCALER_REGISTRY.get(self.scaler["func"])(mus, **self.scaler["args"])
sigmas = SCALER_REGISTRY.get(self.scaler["func"])(sigmas, **self.scaler["args"])
return [prediction, real_value, mus, sigmas]
input_data["inputs"] = SCALER_REGISTRY.get(self.scaler["func"])(input_data["inputs"], **self.scaler["args"])
input_data["prediction"] = SCALER_REGISTRY.get(self.scaler["func"])(input_data["prediction"], **self.scaler["args"])
input_data["target"] = SCALER_REGISTRY.get(self.scaler["func"])(input_data["target"], **self.scaler["args"])
if "mus" in input_data.keys():
input_data["mus"] = SCALER_REGISTRY.get(self.scaler["func"])(input_data["mus"], **self.scaler["args"])
if "sigmas" in input_data.keys():
input_data["sigmas"] = SCALER_REGISTRY.get(self.scaler["func"])(input_data["sigmas"], **self.scaler["args"])
return input_data

@torch.no_grad()
@master_only
@@ -69,22 +72,25 @@ def test(self):
"""

# test loop
prediction = []
real_value = []
prediction = []
target = []
inputs = []
for _, data in enumerate(self.test_data_loader):
forward_return = list(self.forward(data, epoch=None, iter_num=None, train=False))
forward_return = self.forward(data, epoch=None, iter_num=None, train=False)
if not self.if_evaluate_on_gpu:
forward_return[0], forward_return[1] = forward_return[0].detach().cpu(), forward_return[1].detach().cpu()
prediction.append(forward_return[0]) # preds = forward_return[0]
real_value.append(forward_return[1]) # testy = forward_return[1]
forward_return["prediction"] = forward_return["prediction"].detach().cpu()
forward_return["target"] = forward_return["target"].detach().cpu()
forward_return["inputs"] = forward_return["inputs"].detach().cpu()
prediction.append(forward_return["prediction"])
target.append(forward_return["target"])
inputs.append(forward_return["inputs"])
prediction = torch.cat(prediction, dim=0)
real_value = torch.cat(real_value, dim=0)
target = torch.cat(target, dim=0)
inputs = torch.cat(inputs, dim=0)
# re-scale data
if self.if_rescale:
prediction = SCALER_REGISTRY.get(self.scaler["func"])(prediction, **self.scaler["args"])[:, -self.output_seq_len:, :, :]
real_value = SCALER_REGISTRY.get(self.scaler["func"])(real_value, **self.scaler["args"])[:, -self.output_seq_len:, :, :]
returns_all = self.rescale_data({"prediction": prediction[:, -self.output_seq_len:, :, :], "target": target[:, -self.output_seq_len:, :, :], "inputs": inputs})
# evaluate
self.evaluate(prediction, real_value)
self.evaluate(returns_all)

def forward(self, data: tuple, epoch:int = None, iter_num: int = None, train:bool = True, **kwargs) -> tuple:
"""feed forward process for train, val, and test. Note that the outputs are NOT re-scaled.
@@ -96,21 +102,23 @@ def forward(self, data: tuple, epoch:int = None, iter_num: int = None, train:boo
train (bool, optional): if in the training process. Defaults to True.
Returns:
tuple: (prediction, real_value)
dict: keys that must be included: inputs, prediction, target
"""

# preprocess
future_data, history_data = data
history_data = self.to_running_device(history_data) # B, L, N, C
future_data = self.to_running_device(future_data) # B, L, N, C
batch_size, length, num_nodes, _ = future_data.shape


history_data = self.select_input_features(history_data)
future_data_4_dec = self.select_input_features(future_data)

# feed forward
pred_values, real_values, mus, sigmas = self.model(history_data=history_data, future_data=future_data_4_dec, train=train)
# post process
prediction = self.select_target_features(pred_values)
real_value = self.select_target_features(real_values)
return prediction, real_value, mus, sigmas
# model forward
model_return = self.model(history_data=history_data, future_data=future_data_4_dec, batch_seen=iter_num, epoch=epoch, train=train)

# parse model return
if isinstance(model_return, torch.Tensor): model_return = {"prediction": model_return}
model_return["inputs"] = self.select_target_features(history_data)
if "target" not in model_return:
model_return["target"] = self.select_target_features(future_data)
return model_return
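Taken together with gaussian.py above, the dict return lets the extra distribution parameters flow to the loss by key rather than by position. A hedged usage sketch: it assumes the runner binds loss arguments from the dict by name, and runner and batch are hypothetical placeholders:

forward_return = runner.forward(batch, epoch=None, iter_num=None, train=True)
forward_return = runner.rescale_data(forward_return)   # rescales "mus"/"sigmas" only when present
loss = gaussian_loss(forward_return["prediction"],
                     forward_return["target"],
                     forward_return["mus"],
                     forward_return["sigmas"])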
4 changes: 2 additions & 2 deletions baselines/GTS/METR-LA.py
@@ -6,10 +6,10 @@
from easydict import EasyDict
from basicts.data import TimeSeriesForecastingDataset
from basicts.utils.serialization import load_pkl
from basicts.runners import SimpleTimeSeriesForecastingRunner

from .arch import GTS
from .loss import gts_loss
from .runner import GTSRunner

CFG = EasyDict()

@@ -21,7 +21,7 @@

# ================= general ================= #
CFG.DESCRIPTION = "GTS model configuration"
CFG.RUNNER = GTSRunner
CFG.RUNNER = SimpleTimeSeriesForecastingRunner
CFG.DATASET_CLS = TimeSeriesForecastingDataset
CFG.DATASET_NAME = "METR-LA"
CFG.DATASET_TYPE = "Traffic speed"
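This RUNNER swap is repeated verbatim in the PEMS-BAY, PEMS03, PEMS04, PEMS07, and PEMS08 configs below: the bespoke GTSRunner is dropped in favor of the generic SimpleTimeSeriesForecastingRunner, with the graph-learning specifics carried by gts_loss and the model's dict return instead. A hedged fragment of the resulting config; only the RUNNER and DATASET_CLS lines are taken from the diff, and wiring the loss in via CFG.TRAIN.LOSS is an assumption about the surrounding (not shown) parts of the config:

CFG.RUNNER = SimpleTimeSeriesForecastingRunner
CFG.DATASET_CLS = TimeSeriesForecastingDataset
CFG.TRAIN = EasyDict()
CFG.TRAIN.LOSS = gts_loss   # assumed: the custom graph + regression loss is plugged in here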
4 changes: 2 additions & 2 deletions baselines/GTS/PEMS-BAY.py
@@ -6,10 +6,10 @@
from easydict import EasyDict
from basicts.data import TimeSeriesForecastingDataset
from basicts.utils.serialization import load_pkl
from basicts.runners import SimpleTimeSeriesForecastingRunner

from .arch import GTS
from .loss import gts_loss
from .runner import GTSRunner

CFG = EasyDict()

@@ -21,7 +21,7 @@

# ================= general ================= #
CFG.DESCRIPTION = "GTS model configuration"
CFG.RUNNER = GTSRunner
CFG.RUNNER = SimpleTimeSeriesForecastingRunner
CFG.DATASET_CLS = TimeSeriesForecastingDataset
CFG.DATASET_NAME = "PEMS-BAY"
CFG.DATASET_TYPE = "Traffic speed"
4 changes: 2 additions & 2 deletions baselines/GTS/PEMS03.py
@@ -6,10 +6,10 @@
from easydict import EasyDict
from basicts.data import TimeSeriesForecastingDataset
from basicts.utils.serialization import load_pkl
from basicts.runners import SimpleTimeSeriesForecastingRunner

from .arch import GTS
from .loss import gts_loss
from .runner import GTSRunner

CFG = EasyDict()

@@ -21,7 +21,7 @@

# ================= general ================= #
CFG.DESCRIPTION = "GTS model configuration"
CFG.RUNNER = GTSRunner
CFG.RUNNER = SimpleTimeSeriesForecastingRunner
CFG.DATASET_CLS = TimeSeriesForecastingDataset
CFG.DATASET_NAME = "PEMS03"
CFG.DATASET_TYPE = "Traffic flow"
4 changes: 2 additions & 2 deletions baselines/GTS/PEMS04.py
@@ -6,10 +6,10 @@
from easydict import EasyDict
from basicts.data import TimeSeriesForecastingDataset
from basicts.utils.serialization import load_pkl
from basicts.runners import SimpleTimeSeriesForecastingRunner

from .arch import GTS
from .loss import gts_loss
from .runner import GTSRunner

CFG = EasyDict()

@@ -21,7 +21,7 @@

# ================= general ================= #
CFG.DESCRIPTION = "GTS model configuration"
CFG.RUNNER = GTSRunner
CFG.RUNNER = SimpleTimeSeriesForecastingRunner
CFG.DATASET_CLS = TimeSeriesForecastingDataset
CFG.DATASET_NAME = "PEMS04"
CFG.DATASET_TYPE = "Traffic flow"
4 changes: 2 additions & 2 deletions baselines/GTS/PEMS07.py
@@ -6,10 +6,10 @@
from easydict import EasyDict
from basicts.data import TimeSeriesForecastingDataset
from basicts.utils.serialization import load_pkl
from basicts.runners import SimpleTimeSeriesForecastingRunner

from .arch import GTS
from .loss import gts_loss
from .runner import GTSRunner

CFG = EasyDict()

@@ -21,7 +21,7 @@

# ================= general ================= #
CFG.DESCRIPTION = "GTS model configuration"
CFG.RUNNER = GTSRunner
CFG.RUNNER = SimpleTimeSeriesForecastingRunner
CFG.DATASET_CLS = TimeSeriesForecastingDataset
CFG.DATASET_NAME = "PEMS07"
CFG.DATASET_TYPE = "Traffic flow"
4 changes: 2 additions & 2 deletions baselines/GTS/PEMS08.py
@@ -6,10 +6,10 @@
from easydict import EasyDict
from basicts.data import TimeSeriesForecastingDataset
from basicts.utils.serialization import load_pkl
from basicts.runners import SimpleTimeSeriesForecastingRunner

from .arch import GTS
from .loss import gts_loss
from .runner import GTSRunner

CFG = EasyDict()

@@ -21,7 +21,7 @@

# ================= general ================= #
CFG.DESCRIPTION = "GTS model configuration"
CFG.RUNNER = GTSRunner
CFG.RUNNER = SimpleTimeSeriesForecastingRunner
CFG.DATASET_CLS = TimeSeriesForecastingDataset
CFG.DATASET_NAME = "PEMS08"
CFG.DATASET_TYPE = "Traffic flow"
6 changes: 4 additions & 2 deletions baselines/GTS/arch/gts_arch.py
@@ -286,5 +286,7 @@ def forward(self, history_data, future_data=None, batch_seen=None, epoch=None, *
outputs = self.decoder(encoder_hidden_state, adj, labels, batches_seen=batch_seen)
if batch_seen == 0:
print("Total trainable parameters {}".format(count_parameters(self)))

return outputs.transpose(1, 0).unsqueeze(-1), x.softmax(-1)[:, 0].clone().reshape(self.num_nodes, -1), self.prior_adj
prediction = outputs.transpose(1, 0).unsqueeze(-1)
pred_adj = x.softmax(-1)[:, 0].clone().reshape(self.num_nodes, -1)
prior_adj = self.prior_adj
return {"prediction": prediction, "pred_adj": pred_adj, "prior_adj": prior_adj}
4 changes: 2 additions & 2 deletions baselines/GTS/loss/loss.py
@@ -3,14 +3,14 @@
from basicts.losses import masked_mae


def gts_loss(prediction, real_value, pred_adj, prior_adj, null_val = np.nan):
def gts_loss(prediction, target, pred_adj, prior_adj, null_val = np.nan):
# graph loss
prior_label = prior_adj.view(prior_adj.shape[0] * prior_adj.shape[1]).to(pred_adj.device)
pred_label = pred_adj.view(pred_adj.shape[0] * pred_adj.shape[1])
graph_loss_function = torch.nn.BCELoss()
loss_g = graph_loss_function(pred_label, prior_label)
# regression loss
loss_r = masked_mae(prediction, real_value, null_val=null_val)
loss_r = masked_mae(prediction, target, null_val=null_val)
# total loss
loss = loss_r + loss_g
return loss
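The GTS loss combines a BCE term on the learned graph against the prior adjacency with a masked-MAE regression term, and its extra arguments (pred_adj, prior_adj) are exactly the keys the GTS forward above now returns. A hedged composition sketch; it assumes the runner routes dict keys to loss arguments by name, and gts_model, hist, fut, and y_true are hypothetical placeholders:

out = gts_model(history_data=hist, future_data=fut, batch_seen=0, epoch=0)
loss = gts_loss(prediction=out["prediction"],
                target=y_true,                 # ground truth, [B, L, N, 1]
                pred_adj=out["pred_adj"],
                prior_adj=out["prior_adj"],
                null_val=0.0)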
1 change: 0 additions & 1 deletion baselines/GTS/runner/__init__.py

This file was deleted.

(The remaining changed files in this commit are not shown here.)
