From 9129716be983e4cfb72e3f299121e3a229636bab Mon Sep 17 00:00:00 2001 From: Zezhi Shao <864453277@qq.com> Date: Mon, 11 Dec 2023 03:40:28 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=F0=9F=8E=B8=20Compatible=20with=20Basi?= =?UTF-8?q?cTS=200.3.5?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- baselines/DGCRN/runner/dgcrn_runner.py | 17 ++-- baselines/DeepAR/arch/deepar.py | 5 +- baselines/DeepAR/loss/gaussian.py | 18 ++-- baselines/DeepAR/runner/deepar_runner.py | 68 +++++++------- baselines/GTS/METR-LA.py | 4 +- baselines/GTS/PEMS-BAY.py | 4 +- baselines/GTS/PEMS03.py | 4 +- baselines/GTS/PEMS04.py | 4 +- baselines/GTS/PEMS07.py | 4 +- baselines/GTS/PEMS08.py | 4 +- baselines/GTS/arch/gts_arch.py | 6 +- baselines/GTS/loss/loss.py | 4 +- baselines/GTS/runner/__init__.py | 1 - baselines/GTS/runner/gts_runner.py | 74 --------------- baselines/LightGBM/PEMS08.ipynb | 8 +- baselines/MTGNN/runner/mtgnn_runner.py | 18 ++-- baselines/MegaCRN/MegaCRN_METR-LA.py | 4 +- baselines/MegaCRN/arch/megacrn_arch.py | 2 +- baselines/MegaCRN/loss/loss.py | 6 +- baselines/MegaCRN/runner/__init__.py | 1 - baselines/MegaCRN/runner/megacrn_runner.py | 70 --------------- baselines/STEP/README.md | 1 + baselines/STEP/STEP_PEMS08.py | 2 +- baselines/STEP/step_arch/step.py | 7 +- baselines/STEP/step_loss/step_loss.py | 12 +-- baselines/STEP/step_runner/__init__.py | 3 +- baselines/STEP/step_runner/step_runner.py | 15 ++-- baselines/STEP/step_runner/tsformer_runner.py | 90 ------------------- baselines/STWave/loss.py | 6 +- 29 files changed, 118 insertions(+), 344 deletions(-) delete mode 100644 baselines/GTS/runner/__init__.py delete mode 100644 baselines/GTS/runner/gts_runner.py delete mode 100644 baselines/MegaCRN/runner/__init__.py delete mode 100644 baselines/MegaCRN/runner/megacrn_runner.py delete mode 100644 baselines/STEP/step_runner/tsformer_runner.py diff --git a/baselines/DGCRN/runner/dgcrn_runner.py b/baselines/DGCRN/runner/dgcrn_runner.py index a38c67b2..29bd4485 100644 --- a/baselines/DGCRN/runner/dgcrn_runner.py +++ b/baselines/DGCRN/runner/dgcrn_runner.py @@ -16,7 +16,7 @@ def forward(self, data: tuple, epoch: int = None, iter_num: int = None, train: b train (bool, optional): if in the training process. Defaults to True. Returns: - tuple: (prediction, real_value) + dict: keys that must be included: inputs, prediction, target """ # preprocess @@ -35,12 +35,11 @@ def forward(self, data: tuple, epoch: int = None, iter_num: int = None, train: b # customized curriculum learning task_level = self.curriculum_learning(epoch) - prediction_data = self.model(history_data=history_data, future_data=future_data_4_dec, batch_seen=iter_num, epoch=epoch, train=train, task_level=task_level) - # feed forward - assert list(prediction_data.shape)[:3] == [batch_size, length, num_nodes], \ + model_return = self.model(history_data=history_data, future_data=future_data_4_dec, batch_seen=iter_num, epoch=epoch, train=train, task_level=task_level) + # parse model return + if isinstance(model_return, torch.Tensor): model_return = {"prediction": model_return} + model_return["inputs"] = self.select_target_features(history_data) + model_return["target"] = self.select_target_features(future_data) + assert list(model_return["prediction"].shape)[:3] == [batch_size, length, num_nodes], \ "error shape of the output, edit the forward function to reshape it to [B, L, N, C]" - # post process - prediction = self.select_target_features(prediction_data) - real_value = self.select_target_features(future_data) - return prediction, real_value - \ No newline at end of file + return model_return diff --git a/baselines/DeepAR/arch/deepar.py b/baselines/DeepAR/arch/deepar.py index 147869c8..14a66cc5 100644 --- a/baselines/DeepAR/arch/deepar.py +++ b/baselines/DeepAR/arch/deepar.py @@ -73,8 +73,6 @@ def forward(self, history_data: torch.Tensor, future_data: torch.Tensor, train: for t in range(1, len_in + len_out): if not (t > len_in and not train): # not in the decoding stage when inferecing history_next = input_feat_full[:, t-1:t, :, 0:1] - else: - a = 1 embed_feat = self.input_embed(history_next) covar_feat = covar_feat_full[:, t:t+1, :, :] if self.use_ts_id: @@ -98,4 +96,5 @@ def forward(self, history_data: torch.Tensor, future_data: torch.Tensor, train: mus = torch.concat(mus, dim=1) sigmas = torch.concat(sigmas, dim=1) reals = input_feat_full[:, -preds.shape[1]:, :, :] - return preds, reals, mus, sigmas + + return {"prediction": preds, "target": reals, "mus": mus, "sigmas": sigmas} diff --git a/baselines/DeepAR/loss/gaussian.py b/baselines/DeepAR/loss/gaussian.py index c278f4b0..9d5c5b96 100644 --- a/baselines/DeepAR/loss/gaussian.py +++ b/baselines/DeepAR/loss/gaussian.py @@ -2,28 +2,28 @@ import numpy as np -def gaussian_loss(prediction, real_value, mu, sigma, null_val = np.nan): - """Masked gaussian loss. Kindly note that the gaussian loss is calculated based on mu, sigma, and real_value. The prediction is sampled from N(mu, sigma), and is not used in the loss calculation (it will be used in the metrics calculation). +def gaussian_loss(prediction, target, mus, sigmas, null_val = np.nan): + """Masked gaussian loss. Kindly note that the gaussian loss is calculated based on mu, sigma, and target. The prediction is sampled from N(mu, sigma), and is not used in the loss calculation (it will be used in the metrics calculation). Args: prediction (torch.Tensor): prediction of model. [B, L, N, 1]. - real_value (torch.Tensor): ground truth. [B, L, N, 1]. - mu (torch.Tensor): the mean of gaussian distribution. [B, L, N, 1]. - sigma (torch.Tensor): the std of gaussian distribution. [B, L, N, 1] + target (torch.Tensor): ground truth. [B, L, N, 1]. + mus (torch.Tensor): the mean of gaussian distribution. [B, L, N, 1]. + sigmas (torch.Tensor): the std of gaussian distribution. [B, L, N, 1] null_val (optional): null value. Defaults to np.nan. """ # mask if np.isnan(null_val): - mask = ~torch.isnan(real_value) + mask = ~torch.isnan(target) else: eps = 5e-5 - mask = ~torch.isclose(real_value, torch.tensor(null_val).expand_as(real_value).to(real_value.device), atol=eps, rtol=0.) + mask = ~torch.isclose(target, torch.tensor(null_val).expand_as(target).to(target.device), atol=eps, rtol=0.) mask = mask.float() mask /= torch.mean((mask)) mask = torch.where(torch.isnan(mask), torch.zeros_like(mask), mask) - distribution = torch.distributions.Normal(mu, sigma) - likelihood = distribution.log_prob(real_value) + distribution = torch.distributions.Normal(mus, sigmas) + likelihood = distribution.log_prob(target) likelihood = likelihood * mask loss_g = -torch.mean(likelihood) return loss_g diff --git a/baselines/DeepAR/runner/deepar_runner.py b/baselines/DeepAR/runner/deepar_runner.py index 7fc12c4d..275a30b5 100644 --- a/baselines/DeepAR/runner/deepar_runner.py +++ b/baselines/DeepAR/runner/deepar_runner.py @@ -1,4 +1,4 @@ -from typing import List +from typing import Dict import torch from basicts.data.registry import SCALER_REGISTRY from easytorch.utils.dist import master_only @@ -42,22 +42,25 @@ def select_target_features(self, data: torch.Tensor) -> torch.Tensor: data = data[:, :, :, self.target_features] return data - def rescale_data(self, input_data: List[torch.Tensor]) -> List[torch.Tensor]: + def rescale_data(self, input_data: Dict) -> Dict: """Rescale data. Args: - data (List[torch.Tensor]): list of data to be re-scaled. + data (Dict): Dict of data to be re-scaled. Returns: - List[torch.Tensor]: list of re-scaled data. + Dict: Dict re-scaled data. """ - prediction, real_value, mus, sigmas = input_data + if self.if_rescale: - prediction = SCALER_REGISTRY.get(self.scaler["func"])(prediction, **self.scaler["args"]) - real_value = SCALER_REGISTRY.get(self.scaler["func"])(real_value, **self.scaler["args"]) - mus = SCALER_REGISTRY.get(self.scaler["func"])(mus, **self.scaler["args"]) - sigmas = SCALER_REGISTRY.get(self.scaler["func"])(sigmas, **self.scaler["args"]) - return [prediction, real_value, mus, sigmas] + input_data["inputs"] = SCALER_REGISTRY.get(self.scaler["func"])(input_data["inputs"], **self.scaler["args"]) + input_data["prediction"] = SCALER_REGISTRY.get(self.scaler["func"])(input_data["prediction"], **self.scaler["args"]) + input_data["target"] = SCALER_REGISTRY.get(self.scaler["func"])(input_data["target"], **self.scaler["args"]) + if "mus" in input_data.keys(): + input_data["mus"] = SCALER_REGISTRY.get(self.scaler["func"])(input_data["mus"], **self.scaler["args"]) + if "sigmas" in input_data.keys(): + input_data["sigmas"] = SCALER_REGISTRY.get(self.scaler["func"])(input_data["sigmas"], **self.scaler["args"]) + return input_data @torch.no_grad() @master_only @@ -69,22 +72,25 @@ def test(self): """ # test loop - prediction = [] - real_value = [] + prediction =[] + target = [] + inputs = [] for _, data in enumerate(self.test_data_loader): - forward_return = list(self.forward(data, epoch=None, iter_num=None, train=False)) + forward_return = self.forward(data, epoch=None, iter_num=None, train=False) if not self.if_evaluate_on_gpu: - forward_return[0], forward_return[1] = forward_return[0].detach().cpu(), forward_return[1].detach().cpu() - prediction.append(forward_return[0]) # preds = forward_return[0] - real_value.append(forward_return[1]) # testy = forward_return[1] + forward_return["prediction"] = forward_return["prediction"].detach().cpu() + forward_return["target"] = forward_return["target"].detach().cpu() + forward_return["inputs"] = forward_return["inputs"].detach().cpu() + prediction.append(forward_return["prediction"]) + target.append(forward_return["target"]) + inputs.append(forward_return["inputs"]) prediction = torch.cat(prediction, dim=0) - real_value = torch.cat(real_value, dim=0) + target = torch.cat(target, dim=0) + inputs = torch.cat(inputs, dim=0) # re-scale data - if self.if_rescale: - prediction = SCALER_REGISTRY.get(self.scaler["func"])(prediction, **self.scaler["args"])[:, -self.output_seq_len:, :, :] - real_value = SCALER_REGISTRY.get(self.scaler["func"])(real_value, **self.scaler["args"])[:, -self.output_seq_len:, :, :] + returns_all = self.rescale_data({"prediction": prediction[:, -self.output_seq_len:, :, :], "target": target[:, -self.output_seq_len:, :, :], "inputs": inputs}) # evaluate - self.evaluate(prediction, real_value) + self.evaluate(returns_all) def forward(self, data: tuple, epoch:int = None, iter_num: int = None, train:bool = True, **kwargs) -> tuple: """feed forward process for train, val, and test. Note that the outputs are NOT re-scaled. @@ -96,21 +102,23 @@ def forward(self, data: tuple, epoch:int = None, iter_num: int = None, train:boo train (bool, optional): if in the training process. Defaults to True. Returns: - tuple: (prediction, real_value) + dict: keys that must be included: inputs, prediction, target """ # preprocess future_data, history_data = data history_data = self.to_running_device(history_data) # B, L, N, C future_data = self.to_running_device(future_data) # B, L, N, C - batch_size, length, num_nodes, _ = future_data.shape - + history_data = self.select_input_features(history_data) future_data_4_dec = self.select_input_features(future_data) - # feed forward - pred_values, real_values, mus, sigmas = self.model(history_data=history_data, future_data=future_data_4_dec, train=train) - # post process - prediction = self.select_target_features(pred_values) - real_value = self.select_target_features(real_values) - return prediction, real_value, mus, sigmas + # model forward + model_return = self.model(history_data=history_data, future_data=future_data_4_dec, batch_seen=iter_num, epoch=epoch, train=train) + + # parse model return + if isinstance(model_return, torch.Tensor): model_return = {"prediction": model_return} + model_return["inputs"] = self.select_target_features(history_data) + if "target" not in model_return: + model_return["target"] = self.select_target_features(future_data) + return model_return diff --git a/baselines/GTS/METR-LA.py b/baselines/GTS/METR-LA.py index c73239a4..d0e0794f 100644 --- a/baselines/GTS/METR-LA.py +++ b/baselines/GTS/METR-LA.py @@ -6,10 +6,10 @@ from easydict import EasyDict from basicts.data import TimeSeriesForecastingDataset from basicts.utils.serialization import load_pkl +from basicts.runners import SimpleTimeSeriesForecastingRunner from .arch import GTS from .loss import gts_loss -from .runner import GTSRunner CFG = EasyDict() @@ -21,7 +21,7 @@ # ================= general ================= # CFG.DESCRIPTION = "GTS model configuration" -CFG.RUNNER = GTSRunner +CFG.RUNNER = SimpleTimeSeriesForecastingRunner CFG.DATASET_CLS = TimeSeriesForecastingDataset CFG.DATASET_NAME = "METR-LA" CFG.DATASET_TYPE = "Traffic speed" diff --git a/baselines/GTS/PEMS-BAY.py b/baselines/GTS/PEMS-BAY.py index 6de438cf..d6e29668 100644 --- a/baselines/GTS/PEMS-BAY.py +++ b/baselines/GTS/PEMS-BAY.py @@ -6,10 +6,10 @@ from easydict import EasyDict from basicts.data import TimeSeriesForecastingDataset from basicts.utils.serialization import load_pkl +from basicts.runners import SimpleTimeSeriesForecastingRunner from .arch import GTS from .loss import gts_loss -from .runner import GTSRunner CFG = EasyDict() @@ -21,7 +21,7 @@ # ================= general ================= # CFG.DESCRIPTION = "GTS model configuration" -CFG.RUNNER = GTSRunner +CFG.RUNNER = SimpleTimeSeriesForecastingRunner CFG.DATASET_CLS = TimeSeriesForecastingDataset CFG.DATASET_NAME = "PEMS-BAY" CFG.DATASET_TYPE = "Traffic speed" diff --git a/baselines/GTS/PEMS03.py b/baselines/GTS/PEMS03.py index f8dcdc78..6a4a0f48 100644 --- a/baselines/GTS/PEMS03.py +++ b/baselines/GTS/PEMS03.py @@ -6,10 +6,10 @@ from easydict import EasyDict from basicts.data import TimeSeriesForecastingDataset from basicts.utils.serialization import load_pkl +from basicts.runners import SimpleTimeSeriesForecastingRunner from .arch import GTS from .loss import gts_loss -from .runner import GTSRunner CFG = EasyDict() @@ -21,7 +21,7 @@ # ================= general ================= # CFG.DESCRIPTION = "GTS model configuration" -CFG.RUNNER = GTSRunner +CFG.RUNNER = SimpleTimeSeriesForecastingRunner CFG.DATASET_CLS = TimeSeriesForecastingDataset CFG.DATASET_NAME = "PEMS03" CFG.DATASET_TYPE = "Traffic flow" diff --git a/baselines/GTS/PEMS04.py b/baselines/GTS/PEMS04.py index f8c57637..69f8160b 100644 --- a/baselines/GTS/PEMS04.py +++ b/baselines/GTS/PEMS04.py @@ -6,10 +6,10 @@ from easydict import EasyDict from basicts.data import TimeSeriesForecastingDataset from basicts.utils.serialization import load_pkl +from basicts.runners import SimpleTimeSeriesForecastingRunner from .arch import GTS from .loss import gts_loss -from .runner import GTSRunner CFG = EasyDict() @@ -21,7 +21,7 @@ # ================= general ================= # CFG.DESCRIPTION = "GTS model configuration" -CFG.RUNNER = GTSRunner +CFG.RUNNER = SimpleTimeSeriesForecastingRunner CFG.DATASET_CLS = TimeSeriesForecastingDataset CFG.DATASET_NAME = "PEMS04" CFG.DATASET_TYPE = "Traffic flow" diff --git a/baselines/GTS/PEMS07.py b/baselines/GTS/PEMS07.py index 62391cff..ce25f176 100644 --- a/baselines/GTS/PEMS07.py +++ b/baselines/GTS/PEMS07.py @@ -6,10 +6,10 @@ from easydict import EasyDict from basicts.data import TimeSeriesForecastingDataset from basicts.utils.serialization import load_pkl +from basicts.runners import SimpleTimeSeriesForecastingRunner from .arch import GTS from .loss import gts_loss -from .runner import GTSRunner CFG = EasyDict() @@ -21,7 +21,7 @@ # ================= general ================= # CFG.DESCRIPTION = "GTS model configuration" -CFG.RUNNER = GTSRunner +CFG.RUNNER = SimpleTimeSeriesForecastingRunner CFG.DATASET_CLS = TimeSeriesForecastingDataset CFG.DATASET_NAME = "PEMS07" CFG.DATASET_TYPE = "Traffic flow" diff --git a/baselines/GTS/PEMS08.py b/baselines/GTS/PEMS08.py index 97fe243a..b53f68e7 100644 --- a/baselines/GTS/PEMS08.py +++ b/baselines/GTS/PEMS08.py @@ -6,10 +6,10 @@ from easydict import EasyDict from basicts.data import TimeSeriesForecastingDataset from basicts.utils.serialization import load_pkl +from basicts.runners import SimpleTimeSeriesForecastingRunner from .arch import GTS from .loss import gts_loss -from .runner import GTSRunner CFG = EasyDict() @@ -21,7 +21,7 @@ # ================= general ================= # CFG.DESCRIPTION = "GTS model configuration" -CFG.RUNNER = GTSRunner +CFG.RUNNER = SimpleTimeSeriesForecastingRunner CFG.DATASET_CLS = TimeSeriesForecastingDataset CFG.DATASET_NAME = "PEMS08" CFG.DATASET_TYPE = "Traffic flow" diff --git a/baselines/GTS/arch/gts_arch.py b/baselines/GTS/arch/gts_arch.py index f255526f..6d0a3f96 100644 --- a/baselines/GTS/arch/gts_arch.py +++ b/baselines/GTS/arch/gts_arch.py @@ -286,5 +286,7 @@ def forward(self, history_data, future_data=None, batch_seen=None, epoch=None, * outputs = self.decoder(encoder_hidden_state, adj, labels, batches_seen=batch_seen) if batch_seen == 0: print("Total trainable parameters {}".format(count_parameters(self))) - - return outputs.transpose(1, 0).unsqueeze(-1), x.softmax(-1)[:, 0].clone().reshape(self.num_nodes, -1), self.prior_adj + prediction = outputs.transpose(1, 0).unsqueeze(-1) + pred_adj = x.softmax(-1)[:, 0].clone().reshape(self.num_nodes, -1) + prior_adj = self.prior_adj + return {"prediction": prediction, "pred_adj": pred_adj, "prior_adj": prior_adj} diff --git a/baselines/GTS/loss/loss.py b/baselines/GTS/loss/loss.py index 2ad0a866..c9af8f63 100644 --- a/baselines/GTS/loss/loss.py +++ b/baselines/GTS/loss/loss.py @@ -3,14 +3,14 @@ from basicts.losses import masked_mae -def gts_loss(prediction, real_value, pred_adj, prior_adj, null_val = np.nan): +def gts_loss(prediction, target, pred_adj, prior_adj, null_val = np.nan): # graph loss prior_label = prior_adj.view(prior_adj.shape[0] * prior_adj.shape[1]).to(pred_adj.device) pred_label = pred_adj.view(pred_adj.shape[0] * pred_adj.shape[1]) graph_loss_function = torch.nn.BCELoss() loss_g = graph_loss_function(pred_label, prior_label) # regression loss - loss_r = masked_mae(prediction, real_value, null_val=null_val) + loss_r = masked_mae(prediction, target, null_val=null_val) # total loss loss = loss_r + loss_g return loss diff --git a/baselines/GTS/runner/__init__.py b/baselines/GTS/runner/__init__.py deleted file mode 100644 index ece7ae16..00000000 --- a/baselines/GTS/runner/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from .gts_runner import GTSRunner \ No newline at end of file diff --git a/baselines/GTS/runner/gts_runner.py b/baselines/GTS/runner/gts_runner.py deleted file mode 100644 index 5da70194..00000000 --- a/baselines/GTS/runner/gts_runner.py +++ /dev/null @@ -1,74 +0,0 @@ -import torch - -from basicts.runners import BaseTimeSeriesForecastingRunner - - -class GTSRunner(BaseTimeSeriesForecastingRunner): - def __init__(self, cfg: dict): - super().__init__(cfg) - self.forward_features = cfg["MODEL"].get("FORWARD_FEATURES", None) - self.target_features = cfg["MODEL"].get("TARGET_FEATURES", None) - - def select_input_features(self, data: torch.Tensor) -> torch.Tensor: - """Select input features and reshape data to fit the target model. - - Args: - data (torch.Tensor): input history data, shape [B, L, N, C]. - - Returns: - torch.Tensor: reshaped data - """ - - # select feature using self.forward_features - if self.forward_features is not None: - data = data[:, :, :, self.forward_features] - return data - - def select_target_features(self, data: torch.Tensor) -> torch.Tensor: - """Select target features and reshape data back to the BasicTS framework - - Args: - data (torch.Tensor): prediction of the model with arbitrary shape. - - Returns: - torch.Tensor: reshaped data with shape [B, L, N, C] - """ - - # select feature using self.target_features - data = data[:, :, :, self.target_features] - return data - - def forward(self, data: tuple, epoch:int = None, iter_num: int = None, train:bool = True, **kwargs) -> tuple: - """feed forward process for train, val, and test. Note that the outputs are NOT re-scaled. - - Args: - data (tuple): data (future data, history data). [B, L, N, C] for each of them - epoch (int, optional): epoch number. Defaults to None. - iter_num (int, optional): iteration number. Defaults to None. - train (bool, optional): if in the training process. Defaults to True. - - Returns: - tuple: (prediction, real_value) - """ - - # preprocess - future_data, history_data = data - history_data = self.to_running_device(history_data) # B, L, N, C - future_data = self.to_running_device(future_data) # B, L, N, C - batch_size, length, num_nodes, _ = future_data.shape - - history_data = self.select_input_features(history_data) - if train: - # teacher forcing only use the first dimension. - future_data_4_dec = future_data[..., [0]] - else: - future_data_4_dec = None - - # feed forward - prediction_data, pred_adj, prior_adj = self.model(history_data=history_data, future_data=future_data_4_dec, batch_seen=iter_num, epoch=epoch) - assert list(prediction_data.shape)[:3] == [batch_size, length, num_nodes], \ - "error shape of the output, edit the forward function to reshape it to [B, L, N, C]" - # post process - prediction = self.select_target_features(prediction_data) - real_value = self.select_target_features(future_data) - return prediction, real_value, pred_adj, prior_adj diff --git a/baselines/LightGBM/PEMS08.ipynb b/baselines/LightGBM/PEMS08.ipynb index 6fa03a1f..6f8a1f67 100644 --- a/baselines/LightGBM/PEMS08.ipynb +++ b/baselines/LightGBM/PEMS08.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "code", - "execution_count": null, + "execution_count": 1, "metadata": {}, "outputs": [], "source": [ @@ -29,7 +29,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 2, "metadata": {}, "outputs": [], "source": [ @@ -66,7 +66,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 3, "metadata": {}, "outputs": [], "source": [ @@ -86,7 +86,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 4, "metadata": {}, "outputs": [], "source": [ diff --git a/baselines/MTGNN/runner/mtgnn_runner.py b/baselines/MTGNN/runner/mtgnn_runner.py index 5435ad25..c6129dd2 100644 --- a/baselines/MTGNN/runner/mtgnn_runner.py +++ b/baselines/MTGNN/runner/mtgnn_runner.py @@ -71,14 +71,16 @@ def forward(self, data: tuple, epoch: int = None, iter_num: int = None, train: b history_data = self.select_input_features(history_data) - prediction_data = self.model( - history_data=history_data, idx=idx, batch_seen=iter_num, epoch=epoch) # B, L, N, C - assert list(prediction_data.shape)[:3] == [ - batch_size, seq_len, num_nodes], "error shape of the output, edit the forward function to reshape it to [B, L, N, C]" - # post process - prediction = self.select_target_features(prediction_data) - real_value = self.select_target_features(future_data) - return prediction, real_value + # model forward + model_return = self.model(history_data=history_data, idx=idx, batch_seen=iter_num, epoch=epoch) # B, L, N, C + + # parse model return + if isinstance(model_return, torch.Tensor): model_return = {"prediction": model_return} + model_return["inputs"] = self.select_target_features(history_data) + model_return["target"] = self.select_target_features(future_data) + assert list(model_return["prediction"].shape)[:3] == [batch_size, seq_len, num_nodes], \ + "error shape of the output, edit the forward function to reshape it to [B, L, N, C]" + return model_return def train_iters(self, epoch: int, iter_index: int, data: Union[torch.Tensor, Tuple]) -> torch.Tensor: """It must be implement to define training detail. diff --git a/baselines/MegaCRN/MegaCRN_METR-LA.py b/baselines/MegaCRN/MegaCRN_METR-LA.py index eaa62848..9427e4da 100644 --- a/baselines/MegaCRN/MegaCRN_METR-LA.py +++ b/baselines/MegaCRN/MegaCRN_METR-LA.py @@ -5,16 +5,16 @@ sys.path.append(os.path.abspath(__file__ + "/../../..")) from easydict import EasyDict from basicts.data import TimeSeriesForecastingDataset +from basicts.runners import SimpleTimeSeriesForecastingRunner from .arch import MegaCRN from .loss import megacrn_loss -from .runner import MegaCRNRunner CFG = EasyDict() # ================= general ================= # CFG.DESCRIPTION = "MegaCRN model configuration" -CFG.RUNNER = MegaCRNRunner +CFG.RUNNER = SimpleTimeSeriesForecastingRunner CFG.DATASET_CLS = TimeSeriesForecastingDataset CFG.DATASET_NAME = "METR-LA" CFG.DATASET_TYPE = "Traffic speed" diff --git a/baselines/MegaCRN/arch/megacrn_arch.py b/baselines/MegaCRN/arch/megacrn_arch.py index 03da2377..b3196a0f 100644 --- a/baselines/MegaCRN/arch/megacrn_arch.py +++ b/baselines/MegaCRN/arch/megacrn_arch.py @@ -217,4 +217,4 @@ def forward(self, history_data, future_data, batch_seen=None, epoch=None, **kwar go = labels[:, t, ...] output = torch.stack(out, dim=1) - return output, h_att, query, pos, neg + return {'prediction': output, 'query': query, 'pos': pos, 'neg': neg} diff --git a/baselines/MegaCRN/loss/loss.py b/baselines/MegaCRN/loss/loss.py index 82d5de4e..b7653ee5 100644 --- a/baselines/MegaCRN/loss/loss.py +++ b/baselines/MegaCRN/loss/loss.py @@ -1,15 +1,13 @@ -import torch from torch import nn -import numpy as np from basicts.losses import masked_mae -def megacrn_loss(prediction, real_value, query, pos, neg, null_val): +def megacrn_loss(prediction, target, query, pos, neg, null_val): separate_loss = nn.TripletMarginLoss(margin=1.0) compact_loss = nn.MSELoss() criterion = masked_mae - loss1 = criterion(prediction, real_value, null_val) + loss1 = criterion(prediction, target, null_val) loss2 = separate_loss(query, pos.detach(), neg.detach()) loss3 = compact_loss(query, pos.detach()) loss = loss1 + 0.01 * loss2 + 0.01 * loss3 diff --git a/baselines/MegaCRN/runner/__init__.py b/baselines/MegaCRN/runner/__init__.py deleted file mode 100644 index dafa6f3a..00000000 --- a/baselines/MegaCRN/runner/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from .megacrn_runner import MegaCRNRunner \ No newline at end of file diff --git a/baselines/MegaCRN/runner/megacrn_runner.py b/baselines/MegaCRN/runner/megacrn_runner.py deleted file mode 100644 index 6001dc26..00000000 --- a/baselines/MegaCRN/runner/megacrn_runner.py +++ /dev/null @@ -1,70 +0,0 @@ -import torch - -from basicts.runners import BaseTimeSeriesForecastingRunner - - -class MegaCRNRunner(BaseTimeSeriesForecastingRunner): - def __init__(self, cfg: dict): - super().__init__(cfg) - self.forward_features = cfg["MODEL"].get("FORWARD_FEATURES", None) - self.target_features = cfg["MODEL"].get("TARGET_FEATURES", None) - - def select_input_features(self, data: torch.Tensor) -> torch.Tensor: - """Select input features and reshape data to fit the target model. - - Args: - data (torch.Tensor): input history data, shape [B, L, N, C]. - - Returns: - torch.Tensor: reshaped data - """ - - # select feature using self.forward_features - if self.forward_features is not None: - data = data[:, :, :, self.forward_features] - return data - - def select_target_features(self, data: torch.Tensor) -> torch.Tensor: - """Select target features and reshape data back to the BasicTS framework - - Args: - data (torch.Tensor): prediction of the model with arbitrary shape. - - Returns: - torch.Tensor: reshaped data with shape [B, L, N, C] - """ - - # select feature using self.target_features - data = data[:, :, :, self.target_features] - return data - - def forward(self, data: tuple, epoch:int = None, iter_num: int = None, train:bool = True, **kwargs) -> tuple: - """feed forward process for train, val, and test. Note that the outputs are NOT re-scaled. - - Args: - data (tuple): data (future data, history data). [B, L, N, C] for each of them - epoch (int, optional): epoch number. Defaults to None. - iter_num (int, optional): iteration number. Defaults to None. - train (bool, optional): if in the training process. Defaults to True. - - Returns: - tuple: (prediction, real_value) - """ - - # preprocess - future_data, history_data = data - history_data = self.to_running_device(history_data) # B, L, N, C - future_data = self.to_running_device(future_data) # B, L, N, C - batch_size, length, num_nodes, _ = future_data.shape - - history_data = self.select_input_features(history_data) - future_data = self.select_input_features(future_data) - - # feed forward - prediction_data, h_att, query, pos, neg = self.model(history_data=history_data, future_data=future_data, batch_seen=iter_num, epoch=epoch) - assert list(prediction_data.shape)[:3] == [batch_size, length, num_nodes], \ - "error shape of the output, edit the forward function to reshape it to [B, L, N, C]" - # post process - prediction = self.select_target_features(prediction_data) - real_value = self.select_target_features(future_data) - return prediction, real_value, query, pos, neg diff --git a/baselines/STEP/README.md b/baselines/STEP/README.md index 224436d8..f8b79f8c 100644 --- a/baselines/STEP/README.md +++ b/baselines/STEP/README.md @@ -1 +1,2 @@ STEP requires a pre-trained TSFormer model. You can download them from [here](https://github.com/zezhishao/STEP/tree/github/tsformer_ckpt) and place them in the `./ckpts/` folder. +In addition, STEP requires `timm` package. You can install it by `pip install timm`. diff --git a/baselines/STEP/STEP_PEMS08.py b/baselines/STEP/STEP_PEMS08.py index 4d12db6f..67a0c816 100644 --- a/baselines/STEP/STEP_PEMS08.py +++ b/baselines/STEP/STEP_PEMS08.py @@ -27,7 +27,7 @@ CFG.DATASET_ARGS = { "seq_len": 288 * 7 * 2 } -CFG.GPU_NUM = 1 +CFG.GPU_NUM = 4 CFG.NULL_VAL = 0.0 # ================= environment ================= # diff --git a/baselines/STEP/step_arch/step.py b/baselines/STEP/step_arch/step.py index da22d1bd..13c71df3 100644 --- a/baselines/STEP/step_arch/step.py +++ b/baselines/STEP/step_arch/step.py @@ -69,4 +69,9 @@ def forward(self, history_data: torch.Tensor, long_history_data: torch.Tensor, f gsl_coefficient = 1 / (int(epoch/6)+1) else: gsl_coefficient = 0 - return y_hat.unsqueeze(-1), bernoulli_unnorm.softmax(-1)[..., 0].clone().reshape(batch_size, num_nodes, num_nodes), adj_knn, gsl_coefficient + + prediction = y_hat.unsqueeze(-1) + pred_adj = bernoulli_unnorm.softmax(-1)[..., 0].clone().reshape(batch_size, num_nodes, num_nodes) + prior_adj = adj_knn + gsl_coefficient = gsl_coefficient + return {"prediction": prediction, "pred_adj": pred_adj, "prior_adj": prior_adj, "gsl_coefficient": gsl_coefficient} diff --git a/baselines/STEP/step_loss/step_loss.py b/baselines/STEP/step_loss/step_loss.py index 8c653f8d..86c009b0 100644 --- a/baselines/STEP/step_loss/step_loss.py +++ b/baselines/STEP/step_loss/step_loss.py @@ -2,15 +2,15 @@ from torch import nn from basicts.losses import masked_mae -def step_loss(prediction, real_value, theta, priori_adj, gsl_coefficient, null_val=np.nan): +def step_loss(prediction, target, pred_adj, prior_adj, gsl_coefficient, null_val=np.nan): # graph structure learning loss - B, N, N = theta.shape - theta = theta.view(B, N*N) - tru = priori_adj.view(B, N*N) + B, N, N = pred_adj.shape + pred_adj = pred_adj.view(B, N*N) + tru = prior_adj.view(B, N*N) BCE_loss = nn.BCELoss() - loss_graph = BCE_loss(theta, tru) + loss_graph = BCE_loss(pred_adj, tru) # prediction loss - loss_pred = masked_mae(preds=prediction, labels=real_value, null_val=null_val) + loss_pred = masked_mae(prediction=prediction, target=target, null_val=null_val) # final loss loss = loss_pred + loss_graph * gsl_coefficient return loss diff --git a/baselines/STEP/step_runner/__init__.py b/baselines/STEP/step_runner/__init__.py index de471b5c..dfd4ad9c 100644 --- a/baselines/STEP/step_runner/__init__.py +++ b/baselines/STEP/step_runner/__init__.py @@ -1,4 +1,3 @@ -from .tsformer_runner import TSFormerRunner from .step_runner import STEPRunner -__all__ = ["TSFormerRunner", "STEPRunner"] +__all__ = ["STEPRunner"] diff --git a/baselines/STEP/step_runner/step_runner.py b/baselines/STEP/step_runner/step_runner.py index b324b5d0..c41f928f 100644 --- a/baselines/STEP/step_runner/step_runner.py +++ b/baselines/STEP/step_runner/step_runner.py @@ -62,13 +62,10 @@ def forward(self, data: tuple, epoch:int = None, iter_num: int = None, train:boo long_history_data = self.select_input_features(long_history_data) # feed forward - prediction, pred_adj, prior_adj, gsl_coefficient = self.model(history_data=history_data, long_history_data=long_history_data, future_data=None, batch_seen=iter_num, epoch=epoch) + model_return = self.model(history_data=history_data, long_history_data=long_history_data, future_data=None, batch_seen=iter_num, epoch=epoch) - batch_size, length, num_nodes, _ = future_data.shape - assert list(prediction.shape)[:3] == [batch_size, length, num_nodes], \ - "error shape of the output, edit the forward function to reshape it to [B, L, N, C]" - - # post process - prediction = self.select_target_features(prediction) - real_value = self.select_target_features(future_data) - return prediction, real_value, pred_adj, prior_adj, gsl_coefficient + # parse model return + if isinstance(model_return, torch.Tensor): model_return = {"prediction": model_return} + model_return["inputs"] = self.select_target_features(history_data) + model_return["target"] = self.select_target_features(future_data) + return model_return diff --git a/baselines/STEP/step_runner/tsformer_runner.py b/baselines/STEP/step_runner/tsformer_runner.py deleted file mode 100644 index 1fd1aa9f..00000000 --- a/baselines/STEP/step_runner/tsformer_runner.py +++ /dev/null @@ -1,90 +0,0 @@ -import torch - -from easytorch.utils.dist import master_only -from basicts.data.registry import SCALER_REGISTRY -from basicts.runners import BaseTimeSeriesForecastingRunner - - -class TSFormerRunner(BaseTimeSeriesForecastingRunner): - def __init__(self, cfg: dict): - super().__init__(cfg) - self.forward_features = cfg["MODEL"].get("FORWARD_FEATURES", None) - self.target_features = cfg["MODEL"].get("TARGET_FEATURES", None) - - def select_input_features(self, data: torch.Tensor) -> torch.Tensor: - """Select input features and reshape data to fit the target model. - - Args: - data (torch.Tensor): input history data, shape [B, L, N, C]. - - Returns: - torch.Tensor: reshaped data - """ - - # select feature using self.forward_features - if self.forward_features is not None: - data = data[:, :, :, self.forward_features] - return data - - def select_target_features(self, data: torch.Tensor) -> torch.Tensor: - """Select target features and reshape data back to the BasicTS framework - - Args: - data (torch.Tensor): prediction of the model with arbitrary shape. - - Returns: - torch.Tensor: reshaped data with shape [B, L, N, C] - """ - - # select feature using self.target_features - data = data[:, :, :, self.target_features] - return data - - def forward(self, data: tuple, epoch:int = None, iter_num: int = None, train:bool = True, **kwargs) -> tuple: - """feed forward process for train, val, and test. Note that the outputs are NOT re-scaled. - - Args: - data (tuple): data (future data, history data). [B, L, N, C] for each of them - epoch (int, optional): epoch number. Defaults to None. - iter_num (int, optional): iteration number. Defaults to None. - train (bool, optional): if in the training process. Defaults to True. - - Returns: - tuple: (prediction, real_value) - """ - - # preprocess - future_data, history_data = data - history_data = self.to_running_device(history_data) # B, L, N, C - future_data = self.to_running_device(future_data) # B, L, N, C - batch_size, length, num_nodes, _ = future_data.shape - - history_data = self.select_input_features(history_data) - - # feed forward - reconstruction_masked_tokens, label_masked_tokens = self.model(history_data=history_data, future_data=None, batch_seen=iter_num, epoch=epoch) - # assert list(prediction_data.shape)[:3] == [batch_size, length, num_nodes], \ - # "error shape of the output, edit the forward function to reshape it to [B, L, N, C]" - # post process - # prediction = self.select_target_features(prediction_data) - # real_value = self.select_target_features(future_data) - return reconstruction_masked_tokens, label_masked_tokens - - @torch.no_grad() - @master_only - def test(self): - """Evaluate the model. - - Args: - train_epoch (int, optional): current epoch if in training process. - """ - - for _, data in enumerate(self.test_data_loader): - forward_return = self.forward(data=data, epoch=None, iter_num=None, train=False) - # re-scale data - prediction_rescaled = SCALER_REGISTRY.get(self.scaler["func"])(forward_return[0], **self.scaler["args"]) - real_value_rescaled = SCALER_REGISTRY.get(self.scaler["func"])(forward_return[1], **self.scaler["args"]) - # metrics - for metric_name, metric_func in self.metrics.items(): - metric_item = metric_func(prediction_rescaled, real_value_rescaled, null_val=self.null_val) - self.update_epoch_meter("test_"+metric_name, metric_item.item()) diff --git a/baselines/STWave/loss.py b/baselines/STWave/loss.py index eb91d228..35f63f7e 100644 --- a/baselines/STWave/loss.py +++ b/baselines/STWave/loss.py @@ -4,7 +4,7 @@ from basicts.losses import masked_mae -def stwave_masked_mae(preds: list, labels: torch.Tensor, null_val: float = np.nan) -> torch.Tensor: +def stwave_masked_mae(prediction: list, target: torch.Tensor, null_val: float = np.nan) -> torch.Tensor: """Masked mean absolute error. Args: @@ -15,7 +15,7 @@ def stwave_masked_mae(preds: list, labels: torch.Tensor, null_val: float = np.na Returns: torch.Tensor: masked mean absolute error """ - lloss = masked_mae(preds[...,1:2], preds[...,2:]) - loss = masked_mae(preds[...,:1], labels) + lloss = masked_mae(prediction[...,1:2], prediction[...,2:], null_val) + loss = masked_mae(prediction[...,:1], target, null_val) return loss + lloss