From 4425c5bc01edbc9542e053ffa09744cedfcad8c5 Mon Sep 17 00:00:00 2001 From: Wenjie Du Date: Fri, 29 Mar 2024 16:11:23 +0800 Subject: [PATCH 01/12] refactor: refactor PatchTST code; --- pypots/imputation/patchtst/model.py | 2 - pypots/imputation/patchtst/modules/core.py | 42 +++++++------------ .../imputation/patchtst/modules/submodules.py | 8 ++-- 3 files changed, 19 insertions(+), 33 deletions(-) diff --git a/pypots/imputation/patchtst/model.py b/pypots/imputation/patchtst/model.py index 5acfb1e4..d0ba98ca 100644 --- a/pypots/imputation/patchtst/model.py +++ b/pypots/imputation/patchtst/model.py @@ -168,7 +168,6 @@ def __init__( # model hype-parameters self.patch_len = patch_len self.stride = stride - self.head_nf = d_model * int((n_steps - patch_len) / stride + 2) self.n_layers = n_layers self.n_heads = n_heads self.d_k = d_k @@ -190,7 +189,6 @@ def __init__( self.d_v, self.patch_len, self.stride, - self.head_nf, self.dropout, self.attn_dropout, ) diff --git a/pypots/imputation/patchtst/modules/core.py b/pypots/imputation/patchtst/modules/core.py index d1a53fa2..1ba4206d 100644 --- a/pypots/imputation/patchtst/modules/core.py +++ b/pypots/imputation/patchtst/modules/core.py @@ -25,31 +25,22 @@ def __init__( d_v: int, patch_len: int, stride: int, - head_nf: int, dropout: float, attn_dropout: float, ): super().__init__() - self.seq_len = n_steps + patch_num = int((n_steps - patch_len) / stride + 2) + head_nf = d_model * patch_num + padding = stride + + self.n_steps = n_steps + self.n_features = n_features self.n_layers = n_layers - padding = stride self.patch_embedding = PatchEmbedding( d_model, patch_len, stride, padding, dropout ) - # self.encoder = Encoder( - # n_layers, - # n_steps, - # n_features, - # d_model, - # d_ffn, - # n_heads, - # d_k, - # d_v, - # dropout, - # attn_dropout, - # ) self.encoder = nn.ModuleList( [ EncoderLayer( @@ -64,30 +55,29 @@ def __init__( for _ in range(n_layers) ] ) - self.head = FlattenHead(n_features, head_nf, n_steps, dropout) - - # for the imputation task, the output dim is the same as input dim - self.projection = nn.Linear(d_model, n_features) + self.head = FlattenHead(head_nf, n_steps, dropout) def forward(self, inputs: dict, training: bool = True) -> dict: X, masks = inputs["X"], inputs["missing_mask"] # do patching and embedding x_enc = X.permute(0, 2, 1) - # u: [bs * nvars x patch_num x d_model] - enc_out, n_vars = self.patch_embedding(x_enc) + # u: [bs * n_features x patch_num x d_model] + enc_out = self.patch_embedding(x_enc) # PatchTST encoder processing - # z: [bs * nvars x patch_num x d_model] + # z: [bs * n_features x patch_num x d_model] for i in range(self.n_layers): enc_out, _ = self.encoder[i](enc_out) - # z: [bs x nvars x patch_num x d_model] - enc_out = enc_out.reshape(-1, n_vars, enc_out.shape[-2], enc_out.shape[-1]) - # z: [bs x nvars x d_model x patch_num] + # z: [bs x n_features x patch_num x d_model] + enc_out = enc_out.reshape( + -1, self.n_features, enc_out.shape[-2], enc_out.shape[-1] + ) + # z: [bs x n_features x d_model x patch_num] enc_out = enc_out.permute(0, 1, 3, 2) # project back the original data space - dec_out = self.head(enc_out) # z: [bs x nvars x target_window] + dec_out = self.head(enc_out) # z: [bs x n_features x target_window] dec_out = dec_out.permute(0, 2, 1) imputed_data = masks * X + (1 - masks) * dec_out diff --git a/pypots/imputation/patchtst/modules/submodules.py b/pypots/imputation/patchtst/modules/submodules.py index 72b1a48b..87576a52 100644 --- a/pypots/imputation/patchtst/modules/submodules.py +++ b/pypots/imputation/patchtst/modules/submodules.py @@ -26,7 +26,6 @@ def __init__(self, d_model, patch_len, stride, padding, dropout): def forward(self, x): # apply patching - n_vars = x.shape[1] x = self.padding_patch_layer(x) x = x.unfold(dimension=-1, size=self.patch_len, step=self.stride) x = x.reshape(x.shape[0] * x.shape[1], x.shape[2], x.shape[3]) @@ -34,19 +33,18 @@ def forward(self, x): x = self.value_embedding(x) x = self.position_embedding(x) x = self.dropout(x) - return x, n_vars + return x class FlattenHead(nn.Module): - def __init__(self, n_vars, nf, target_window, head_dropout=0): + def __init__(self, nf, target_window, head_dropout=0): super().__init__() - self.n_vars = n_vars self.flatten = nn.Flatten(start_dim=-2) self.linear = nn.Linear(nf, target_window) self.dropout = nn.Dropout(head_dropout) def forward(self, x): - # x.shape = [batch_size, n_vars, d_model, patch_num] + # x.shape = [batch_size, n_features, d_model, patch_num] x = self.flatten(x) x = self.linear(x) x = self.dropout(x) From cb24887af4f9a3c8b4a9d42d3e68900650d78cec Mon Sep 17 00:00:00 2001 From: Wenjie Du Date: Sat, 30 Mar 2024 01:02:10 +0800 Subject: [PATCH 02/12] feat: add DLinear as an imputation model; --- pypots/imputation/__init__.py | 2 + pypots/imputation/dlinear/__init__.py | 17 + pypots/imputation/dlinear/data.py | 24 ++ pypots/imputation/dlinear/model.py | 296 ++++++++++++++++++ pypots/imputation/dlinear/modules/__init__.py | 6 + pypots/imputation/dlinear/modules/core.py | 93 ++++++ tests/imputation/dlinear.py | 122 ++++++++ 7 files changed, 560 insertions(+) create mode 100644 pypots/imputation/dlinear/__init__.py create mode 100644 pypots/imputation/dlinear/data.py create mode 100644 pypots/imputation/dlinear/model.py create mode 100644 pypots/imputation/dlinear/modules/__init__.py create mode 100644 pypots/imputation/dlinear/modules/core.py create mode 100644 tests/imputation/dlinear.py diff --git a/pypots/imputation/__init__.py b/pypots/imputation/__init__.py index 64a8a758..2d408d58 100644 --- a/pypots/imputation/__init__.py +++ b/pypots/imputation/__init__.py @@ -14,6 +14,7 @@ from .transformer import Transformer from .timesnet import TimesNet from .autoformer import Autoformer +from .dlinear import DLinear from .patchtst import PatchTST from .usgan import USGAN @@ -28,6 +29,7 @@ "Transformer", "TimesNet", "PatchTST", + "DLinear", "Autoformer", "BRITS", "MRNN", diff --git a/pypots/imputation/dlinear/__init__.py b/pypots/imputation/dlinear/__init__.py new file mode 100644 index 00000000..0b179e70 --- /dev/null +++ b/pypots/imputation/dlinear/__init__.py @@ -0,0 +1,17 @@ +""" +The package of the partially-observed time-series imputation model DLinear. + +Refer to the paper "Wu, H., Xu, J., Wang, J., & Long, M. (2021). +DLinear: Decomposition transformers with auto-correlation for long-term series forecasting. NeurIPS 2021.". + +""" + +# Created by Wenjie Du +# License: BSD-3-Clause + + +from .model import DLinear + +__all__ = [ + "DLinear", +] diff --git a/pypots/imputation/dlinear/data.py b/pypots/imputation/dlinear/data.py new file mode 100644 index 00000000..1884054f --- /dev/null +++ b/pypots/imputation/dlinear/data.py @@ -0,0 +1,24 @@ +""" +Dataset class for DLinear. +""" + +# Created by Wenjie Du +# License: BSD-3-Clause + +from typing import Union + +from ..saits.data import DatasetForSAITS + + +class DatasetForDLinear(DatasetForSAITS): + """Actually DLinear uses the same data strategy as SAITS, needs MIT for training.""" + + def __init__( + self, + data: Union[dict, str], + return_X_ori: bool, + return_labels: bool, + file_type: str = "h5py", + rate: float = 0.2, + ): + super().__init__(data, return_X_ori, return_labels, file_type, rate) diff --git a/pypots/imputation/dlinear/model.py b/pypots/imputation/dlinear/model.py new file mode 100644 index 00000000..33bce716 --- /dev/null +++ b/pypots/imputation/dlinear/model.py @@ -0,0 +1,296 @@ +""" +The implementation of DLinear for the partially-observed time-series imputation task. + +Refer to the paper "Zeng, A., Chen, M., Zhang, L., & Xu, Q. (2023). +Are transformers effective for time series forecasting? AAAI 2023". + +Notes +----- +Partial implementation uses code from https://github.com/thuml/Time-Series-Library + +""" + +# Created by Wenjie Du +# License: BSD-3-Clause + +from typing import Union, Optional + +import numpy as np +import torch +from torch.utils.data import DataLoader + +from .data import DatasetForDLinear +from .modules.core import _DLinear +from ..base import BaseNNImputer +from ...data.base import BaseDataset +from ...data.checking import check_X_ori_in_val_set +from ...optim.adam import Adam +from ...optim.base import Optimizer +from ...utils.logging import logger + + +class DLinear(BaseNNImputer): + """The PyTorch implementation of the DLinear model. + DLinear is originally proposed by Zeng et al. in :cite:`zeng2023dlinear`. + + Parameters + ---------- + n_steps : + The number of time steps in the time-series data sample. + + n_features : + The number of features in the time-series data sample. + + moving_avg_window_size : + The window size of moving average. + + individual : + Whether to share model across different features. + + batch_size : + The batch size for training and evaluating the model. + + epochs : + The number of epochs for training the model. + + patience : + The patience for the early-stopping mechanism. Given a positive integer, the training process will be + stopped when the model does not perform better after that number of epochs. + Leaving it default as None will disable the early-stopping. + + optimizer : + The optimizer for model training. + If not given, will use a default Adam optimizer. + + num_workers : + The number of subprocesses to use for data loading. + `0` means data loading will be in the main process, i.e. there won't be subprocesses. + + device : + The device for the model to run on. It can be a string, a :class:`torch.device` object, or a list of them. + If not given, will try to use CUDA devices first (will use the default CUDA device if there are multiple), + then CPUs, considering CUDA and CPU are so far the main devices for people to train ML models. + If given a list of devices, e.g. ['cuda:0', 'cuda:1'], or [torch.device('cuda:0'), torch.device('cuda:1')] , the + model will be parallely trained on the multiple devices (so far only support parallel training on CUDA devices). + Other devices like Google TPU and Apple Silicon accelerator MPS may be added in the future. + + saving_path : + The path for automatically saving model checkpoints and tensorboard files (i.e. loss values recorded during + training into a tensorboard file). Will not save if not given. + + model_saving_strategy : + The strategy to save model checkpoints. It has to be one of [None, "best", "better", "all"]. + No model will be saved when it is set as None. + The "best" strategy will only automatically save the best model after the training finished. + The "better" strategy will automatically save the model during training whenever the model performs + better than in previous epochs. + The "all" strategy will save every model after each epoch training. + + References + ---------- + .. [1] `Zeng, Ailing, Muxi Chen, Lei Zhang, and Qiang Xu. + "Are transformers effective for time series forecasting?". + In Proceedings of the AAAI conference on artificial intelligence, vol. 37, no. 9, pp. 11121-11128. 2023. + `_ + + """ + + def __init__( + self, + n_steps: int, + n_features: int, + moving_avg_window_size: int, + individual: bool = False, + batch_size: int = 32, + epochs: int = 100, + patience: int = None, + optimizer: Optional[Optimizer] = Adam(), + num_workers: int = 0, + device: Optional[Union[str, torch.device, list]] = None, + saving_path: str = None, + model_saving_strategy: Optional[str] = "best", + ): + super().__init__( + batch_size, + epochs, + patience, + num_workers, + device, + saving_path, + model_saving_strategy, + ) + + self.n_steps = n_steps + self.n_features = n_features + # model hype-parameters + self.moving_avg_window_size = moving_avg_window_size + self.individual = individual + + # set up the model + self.model = _DLinear( + n_steps, + n_features, + moving_avg_window_size, + individual, + ) + self._send_model_to_given_device() + self._print_model_size() + + # set up the optimizer + self.optimizer = optimizer + self.optimizer.init_optimizer(self.model.parameters()) + + def _assemble_input_for_training(self, data: list) -> dict: + ( + indices, + X, + missing_mask, + X_ori, + indicating_mask, + ) = self._send_data_to_given_device(data) + + inputs = { + "X": X, + "missing_mask": missing_mask, + "X_ori": X_ori, + "indicating_mask": indicating_mask, + } + + return inputs + + def _assemble_input_for_validating(self, data: list) -> dict: + return self._assemble_input_for_training(data) + + def _assemble_input_for_testing(self, data: list) -> dict: + indices, X, missing_mask = self._send_data_to_given_device(data) + + inputs = { + "X": X, + "missing_mask": missing_mask, + } + + return inputs + + def fit( + self, + train_set: Union[dict, str], + val_set: Optional[Union[dict, str]] = None, + file_type: str = "h5py", + ) -> None: + # Step 1: wrap the input data with classes Dataset and DataLoader + training_set = DatasetForDLinear( + train_set, return_X_ori=False, return_labels=False, file_type=file_type + ) + training_loader = DataLoader( + training_set, + batch_size=self.batch_size, + shuffle=True, + num_workers=self.num_workers, + ) + val_loader = None + if val_set is not None: + if not check_X_ori_in_val_set(val_set): + raise ValueError("val_set must contain 'X_ori' for model validation.") + val_set = DatasetForDLinear( + val_set, return_X_ori=True, return_labels=False, file_type=file_type + ) + val_loader = DataLoader( + val_set, + batch_size=self.batch_size, + shuffle=False, + num_workers=self.num_workers, + ) + + # Step 2: train the model and freeze it + self._train_model(training_loader, val_loader) + self.model.load_state_dict(self.best_model_dict) + self.model.eval() # set the model as eval status to freeze it. + + # Step 3: save the model if necessary + self._auto_save_model_if_necessary(confirm_saving=True) + + def predict( + self, + test_set: Union[dict, str], + file_type: str = "h5py", + ) -> dict: + """Make predictions for the input data with the trained model. + + Parameters + ---------- + test_set : dict or str + The dataset for model validating, should be a dictionary including keys as 'X', + or a path string locating a data file supported by PyPOTS (e.g. h5 file). + If it is a dict, X should be array-like of shape [n_samples, sequence length (time steps), n_features], + which is time-series data for validating, can contain missing values, and y should be array-like of shape + [n_samples], which is classification labels of X. + If it is a path string, the path should point to a data file, e.g. a h5 file, which contains + key-value pairs like a dict, and it has to include keys as 'X' and 'y'. + + file_type : str + The type of the given file if test_set is a path string. + + Returns + ------- + result_dict : dict, + The dictionary containing the clustering results and latent variables if necessary. + + """ + # Step 1: wrap the input data with classes Dataset and DataLoader + self.model.eval() # set the model as eval status to freeze it. + test_set = BaseDataset( + test_set, return_X_ori=False, return_labels=False, file_type=file_type + ) + test_loader = DataLoader( + test_set, + batch_size=self.batch_size, + shuffle=False, + num_workers=self.num_workers, + ) + imputation_collector = [] + + # Step 2: process the data with the model + with torch.no_grad(): + for idx, data in enumerate(test_loader): + inputs = self._assemble_input_for_testing(data) + results = self.model.forward(inputs, training=False) + imputation_collector.append(results["imputed_data"]) + + # Step 3: output collection and return + imputation = torch.cat(imputation_collector).cpu().detach().numpy() + result_dict = { + "imputation": imputation, + } + return result_dict + + def impute( + self, + X: Union[dict, str], + file_type="h5py", + ) -> np.ndarray: + """Impute missing values in the given data with the trained model. + + Warnings + -------- + The method impute is deprecated. Please use `predict()` instead. + + Parameters + ---------- + X : + The data samples for testing, should be array-like of shape [n_samples, sequence length (time steps), + n_features], or a path string locating a data file, e.g. h5 file. + + file_type : + The type of the given file if X is a path string. + + Returns + ------- + array-like, shape [n_samples, sequence length (time steps), n_features], + Imputed data. + """ + logger.warning( + "🚨DeprecationWarning: The method impute is deprecated. Please use `predict` instead." + ) + + results_dict = self.predict(X, file_type=file_type) + return results_dict["imputation"] diff --git a/pypots/imputation/dlinear/modules/__init__.py b/pypots/imputation/dlinear/modules/__init__.py new file mode 100644 index 00000000..ceaa7ee3 --- /dev/null +++ b/pypots/imputation/dlinear/modules/__init__.py @@ -0,0 +1,6 @@ +""" + +""" + +# Created by Wenjie Du +# License: BSD-3-Clause diff --git a/pypots/imputation/dlinear/modules/core.py b/pypots/imputation/dlinear/modules/core.py new file mode 100644 index 00000000..e8e5ec35 --- /dev/null +++ b/pypots/imputation/dlinear/modules/core.py @@ -0,0 +1,93 @@ +""" + +""" + +# Created by Wenjie Du +# License: BSD-3-Clause + +import torch +import torch.nn as nn + +from ...autoformer.modules.submodules import SeriesDecompositionBlock +from ....utils.metrics import calc_mse + + +class _DLinear(nn.Module): + def __init__( + self, + n_steps: int, + n_features: int, + moving_avg_window_size: int, + individual: bool = False, + ): + super().__init__() + + self.n_steps = n_steps + self.n_features = n_features + self.series_decomp = SeriesDecompositionBlock(moving_avg_window_size) + self.individual = individual + + if individual: + self.Linear_Seasonal = nn.ModuleList() + self.Linear_Trend = nn.ModuleList() + + for i in range(self.n_features): + self.Linear_Seasonal.append(nn.Linear(self.n_steps, self.n_steps)) + self.Linear_Trend.append(nn.Linear(self.n_steps, self.n_steps)) + + self.Linear_Seasonal[i].weight = nn.Parameter( + (1 / self.n_steps) * torch.ones([self.n_steps, self.n_steps]) + ) + self.Linear_Trend[i].weight = nn.Parameter( + (1 / self.n_steps) * torch.ones([self.n_steps, self.n_steps]) + ) + else: + self.Linear_Seasonal = nn.Linear(self.n_steps, self.n_steps) + self.Linear_Trend = nn.Linear(self.n_steps, self.n_steps) + + self.Linear_Seasonal.weight = nn.Parameter( + (1 / self.n_steps) * torch.ones([self.n_steps, self.n_steps]) + ) + self.Linear_Trend.weight = nn.Parameter( + (1 / self.n_steps) * torch.ones([self.n_steps, self.n_steps]) + ) + + def forward(self, inputs: dict, training: bool = True) -> dict: + X, masks = inputs["X"], inputs["missing_mask"] + + # DLinear encoder processing + seasonal_init, trend_init = self.series_decomp(X) + seasonal_init, trend_init = seasonal_init.permute(0, 2, 1), trend_init.permute( + 0, 2, 1 + ) + if self.individual: + seasonal_output = torch.zeros( + [seasonal_init.size(0), seasonal_init.size(1), self.n_steps], + dtype=seasonal_init.dtype, + ).to(seasonal_init.device) + trend_output = torch.zeros( + [trend_init.size(0), trend_init.size(1), self.n_steps], + dtype=trend_init.dtype, + ).to(trend_init.device) + for i in range(self.n_features): + seasonal_output[:, i, :] = self.Linear_Seasonal[i]( + seasonal_init[:, i, :] + ) + trend_output[:, i, :] = self.Linear_Trend[i](trend_init[:, i, :]) + else: + seasonal_output = self.Linear_Seasonal(seasonal_init) + trend_output = self.Linear_Trend(trend_init) + output = seasonal_output + trend_output + output = output.permute(0, 2, 1) + + imputed_data = masks * X + (1 - masks) * output + results = { + "imputed_data": imputed_data, + } + + if training: + # `loss` is always the item for backward propagating to update the model + loss = calc_mse(output, inputs["X_ori"], inputs["indicating_mask"]) + results["loss"] = loss + + return results diff --git a/tests/imputation/dlinear.py b/tests/imputation/dlinear.py new file mode 100644 index 00000000..e2680b23 --- /dev/null +++ b/tests/imputation/dlinear.py @@ -0,0 +1,122 @@ +""" +Test cases for DLinear imputation model. +""" + +# Created by Wenjie Du +# License: BSD-3-Clause + + +import os.path +import unittest + +import numpy as np +import pytest + +from pypots.imputation import DLinear +from pypots.optim import Adam +from pypots.utils.logging import logger +from pypots.utils.metrics import calc_mse +from tests.global_test_config import ( + DATA, + EPOCHS, + DEVICE, + TRAIN_SET, + VAL_SET, + TEST_SET, + H5_TRAIN_SET_PATH, + H5_VAL_SET_PATH, + H5_TEST_SET_PATH, + RESULT_SAVING_DIR_FOR_IMPUTATION, + check_tb_and_model_checkpoints_existence, +) + + +class TestDLinear(unittest.TestCase): + logger.info("Running tests for an imputation model DLinear...") + + # set the log and model saving path + saving_path = os.path.join(RESULT_SAVING_DIR_FOR_IMPUTATION, "DLinear") + model_save_name = "saved_dlinear_model.pypots" + + # initialize an Adam optimizer + optimizer = Adam(lr=0.001, weight_decay=1e-5) + + # initialize a DLinear model + dlinear = DLinear( + DATA["n_steps"], + DATA["n_features"], + moving_avg_window_size=3, + individual=False, + epochs=EPOCHS, + saving_path=saving_path, + optimizer=optimizer, + device=DEVICE, + ) + + @pytest.mark.xdist_group(name="imputation-dlinear") + def test_0_fit(self): + self.dlinear.fit(TRAIN_SET, VAL_SET) + + @pytest.mark.xdist_group(name="imputation-dlinear") + def test_1_impute(self): + imputation_results = self.dlinear.predict(TEST_SET) + assert not np.isnan( + imputation_results["imputation"] + ).any(), "Output still has missing values after running impute()." + + test_MSE = calc_mse( + imputation_results["imputation"], + DATA["test_X_ori"], + DATA["test_X_indicating_mask"], + ) + logger.info(f"DLinear test_MSE: {test_MSE}") + + @pytest.mark.xdist_group(name="imputation-dlinear") + def test_2_parameters(self): + assert hasattr(self.dlinear, "model") and self.dlinear.model is not None + + assert hasattr(self.dlinear, "optimizer") and self.dlinear.optimizer is not None + + assert hasattr(self.dlinear, "best_loss") + self.assertNotEqual(self.dlinear.best_loss, float("inf")) + + assert ( + hasattr(self.dlinear, "best_model_dict") + and self.dlinear.best_model_dict is not None + ) + + @pytest.mark.xdist_group(name="imputation-dlinear") + def test_3_saving_path(self): + # whether the root saving dir exists, which should be created by save_log_into_tb_file + assert os.path.exists( + self.saving_path + ), f"file {self.saving_path} does not exist" + + # check if the tensorboard file and model checkpoints exist + check_tb_and_model_checkpoints_existence(self.dlinear) + + # save the trained model into file, and check if the path exists + saved_model_path = os.path.join(self.saving_path, self.model_save_name) + self.dlinear.save(saved_model_path) + + # test loading the saved model, not necessary, but need to test + self.dlinear.load(saved_model_path) + + @pytest.mark.xdist_group(name="imputation-dlinear") + def test_4_lazy_loading(self): + self.dlinear.fit(H5_TRAIN_SET_PATH, H5_VAL_SET_PATH) + imputation_results = self.dlinear.predict(H5_TEST_SET_PATH) + assert not np.isnan( + imputation_results["imputation"] + ).any(), "Output still has missing values after running impute()." + + test_MSE = calc_mse( + imputation_results["imputation"], + DATA["test_X_ori"], + DATA["test_X_indicating_mask"], + ) + logger.info(f"Lazy-loading DLinear test_MSE: {test_MSE}") + + +if __name__ == "__main__": + unittest.main() From a812c42235c4a6a85288745005050cd6449f7fdd Mon Sep 17 00:00:00 2001 From: Wenjie Du Date: Sat, 30 Mar 2024 10:27:39 +0800 Subject: [PATCH 03/12] docs: add references for Autoformer, PatchTST, and TimesNet; --- pypots/imputation/autoformer/model.py | 11 +++++------ pypots/imputation/patchtst/model.py | 11 +++++------ pypots/imputation/timesnet/model.py | 11 +++++------ 3 files changed, 15 insertions(+), 18 deletions(-) diff --git a/pypots/imputation/autoformer/model.py b/pypots/imputation/autoformer/model.py index 350277a2..b6e6c96d 100644 --- a/pypots/imputation/autoformer/model.py +++ b/pypots/imputation/autoformer/model.py @@ -101,13 +101,12 @@ class Autoformer(BaseNNImputer): better than in previous epochs. The "all" strategy will save every model after each epoch training. - Attributes + References ---------- - model : :class:`torch.nn.Module` - The underlying Transformer model. - - optimizer : :class:`pypots.optim.Optimizer` - The optimizer for model training. + .. [1] `Wu, Haixu, Jiehui Xu, Jianmin Wang, and Mingsheng Long. + "Autoformer: Decomposition transformers with auto-correlation for long-term series forecasting". + Advances in neural information processing systems 34 (2021): 22419-22430. + `_ """ diff --git a/pypots/imputation/patchtst/model.py b/pypots/imputation/patchtst/model.py index d0ba98ca..fbd6567e 100644 --- a/pypots/imputation/patchtst/model.py +++ b/pypots/imputation/patchtst/model.py @@ -111,13 +111,12 @@ class PatchTST(BaseNNImputer): better than in previous epochs. The "all" strategy will save every model after each epoch training. - Attributes + References ---------- - model : :class:`torch.nn.Module` - The underlying Transformer model. - - optimizer : :class:`pypots.optim.Optimizer` - The optimizer for model training. + .. [1] `Nie, Yuqi, Nam H. Nguyen, Phanwadee Sinthong, and Jayant Kalagnanam. + "A time series is worth 64 words: Long-term forecasting with transformers". + ICLR 2023. + `_ """ diff --git a/pypots/imputation/timesnet/model.py b/pypots/imputation/timesnet/model.py index 9e93d2f9..8b648d2d 100644 --- a/pypots/imputation/timesnet/model.py +++ b/pypots/imputation/timesnet/model.py @@ -103,13 +103,12 @@ class TimesNet(BaseNNImputer): better than in previous epochs. The "all" strategy will save every model after each epoch training. - Attributes + References ---------- - model : :class:`torch.nn.Module` - The underlying Transformer model. - - optimizer : :class:`pypots.optim.Optimizer` - The optimizer for model training. + .. [1] `Wu, Haixu, Tengge Hu, Yong Liu, Hang Zhou, Jianmin Wang, and Mingsheng Long. + "TimesNet: Temporal 2d-variation modeling for general time series analysis". + ICLR 2022. + `_ """ From 5f02bded31cbb83701dfe9a639d1cdb8194bc00d Mon Sep 17 00:00:00 2001 From: Wenjie Du Date: Sat, 30 Mar 2024 11:23:08 +0800 Subject: [PATCH 04/12] feat: add ETSformer as an imputation model; --- pypots/imputation/__init__.py | 2 + pypots/imputation/etsformer/__init__.py | 17 + pypots/imputation/etsformer/data.py | 24 ++ pypots/imputation/etsformer/model.py | 324 ++++++++++++++++ .../imputation/etsformer/modules/__init__.py | 6 + pypots/imputation/etsformer/modules/core.py | 101 +++++ .../etsformer/modules/submodules.py | 354 ++++++++++++++++++ tests/imputation/etsformer.py | 130 +++++++ 8 files changed, 958 insertions(+) create mode 100644 pypots/imputation/etsformer/__init__.py create mode 100644 pypots/imputation/etsformer/data.py create mode 100644 pypots/imputation/etsformer/model.py create mode 100644 pypots/imputation/etsformer/modules/__init__.py create mode 100644 pypots/imputation/etsformer/modules/core.py create mode 100644 pypots/imputation/etsformer/modules/submodules.py create mode 100644 tests/imputation/etsformer.py diff --git a/pypots/imputation/__init__.py b/pypots/imputation/__init__.py index 2d408d58..f1c4d381 100644 --- a/pypots/imputation/__init__.py +++ b/pypots/imputation/__init__.py @@ -13,6 +13,7 @@ from .saits import SAITS from .transformer import Transformer from .timesnet import TimesNet +from .etsformer import ETSformer from .autoformer import Autoformer from .dlinear import DLinear from .patchtst import PatchTST @@ -27,6 +28,7 @@ # neural network imputation methods "SAITS", "Transformer", + "ETSformer", "TimesNet", "PatchTST", "DLinear", diff --git a/pypots/imputation/etsformer/__init__.py b/pypots/imputation/etsformer/__init__.py new file mode 100644 index 00000000..1e5c8417 --- /dev/null +++ b/pypots/imputation/etsformer/__init__.py @@ -0,0 +1,17 @@ +""" +The package of the partially-observed time-series imputation model ETSformer. + +Refer to the paper "Wu, H., Xu, J., Wang, J., & Long, M. (2021). +ETSformer: Decomposition transformers with auto-correlation for long-term series forecasting. NeurIPS 2021.". + +""" + +# Created by Wenjie Du +# License: BSD-3-Clause + + +from .model import ETSformer + +__all__ = [ + "ETSformer", +] diff --git a/pypots/imputation/etsformer/data.py b/pypots/imputation/etsformer/data.py new file mode 100644 index 00000000..f03a4e61 --- /dev/null +++ b/pypots/imputation/etsformer/data.py @@ -0,0 +1,24 @@ +""" +Dataset class for ETSformer. +""" + +# Created by Wenjie Du +# License: BSD-3-Clause + +from typing import Union + +from ..saits.data import DatasetForSAITS + + +class DatasetForETSformer(DatasetForSAITS): + """Actually ETSformer uses the same data strategy as SAITS, needs MIT for training.""" + + def __init__( + self, + data: Union[dict, str], + return_X_ori: bool, + return_labels: bool, + file_type: str = "h5py", + rate: float = 0.2, + ): + super().__init__(data, return_X_ori, return_labels, file_type, rate) diff --git a/pypots/imputation/etsformer/model.py b/pypots/imputation/etsformer/model.py new file mode 100644 index 00000000..50281ba2 --- /dev/null +++ b/pypots/imputation/etsformer/model.py @@ -0,0 +1,324 @@ +""" +The implementation of ETSformer for the partially-observed time-series imputation task. + +Refer to the paper "Woo, G., Liu, C., Sahoo, D., Kumar, A., & Hoi, S. (2023). +ETSformer: Exponential Smoothing Transformers for Time-series Forecasting. ICLR 2023.". + +Notes +----- +Partial implementation uses code from https://github.com/salesforce/ETSformer + +""" + +# Created by Wenjie Du +# License: BSD-3-Clause + +from typing import Union, Optional + +import numpy as np +import torch +from torch.utils.data import DataLoader + +from .data import DatasetForETSformer +from .modules.core import _ETSformer +from ..base import BaseNNImputer +from ...data.base import BaseDataset +from ...data.checking import check_X_ori_in_val_set +from ...optim.adam import Adam +from ...optim.base import Optimizer +from ...utils.logging import logger + + +class ETSformer(BaseNNImputer): + """The PyTorch implementation of the ETSformer model. + ETSformer is originally proposed by Woo et al. in :cite:`woo2023etsformer`. + + Parameters + ---------- + n_steps : + The number of time steps in the time-series data sample. + + n_features : + The number of features in the time-series data sample. + + n_e_layers : + The number of layers in the ETSformer encoder. + + n_d_layers : + The number of layers in the ETSformer decoder. + + n_heads : + The number of heads in each layer of ETSformer. + + d_model : + The dimension of the model. + + d_ffn : + The dimension of the feed-forward network. + + + dropout : + The dropout rate for the model. + + batch_size : + The batch size for training and evaluating the model. + + epochs : + The number of epochs for training the model. + + patience : + The patience for the early-stopping mechanism. Given a positive integer, the training process will be + stopped when the model does not perform better after that number of epochs. + Leaving it default as None will disable the early-stopping. + + optimizer : + The optimizer for model training. + If not given, will use a default Adam optimizer. + + num_workers : + The number of subprocesses to use for data loading. + `0` means data loading will be in the main process, i.e. there won't be subprocesses. + + device : + The device for the model to run on. It can be a string, a :class:`torch.device` object, or a list of them. + If not given, will try to use CUDA devices first (will use the default CUDA device if there are multiple), + then CPUs, considering CUDA and CPU are so far the main devices for people to train ML models. + If given a list of devices, e.g. ['cuda:0', 'cuda:1'], or [torch.device('cuda:0'), torch.device('cuda:1')] , the + model will be parallely trained on the multiple devices (so far only support parallel training on CUDA devices). + Other devices like Google TPU and Apple Silicon accelerator MPS may be added in the future. + + saving_path : + The path for automatically saving model checkpoints and tensorboard files (i.e. loss values recorded during + training into a tensorboard file). Will not save if not given. + + model_saving_strategy : + The strategy to save model checkpoints. It has to be one of [None, "best", "better", "all"]. + No model will be saved when it is set as None. + The "best" strategy will only automatically save the best model after the training finished. + The "better" strategy will automatically save the model during training whenever the model performs + better than in previous epochs. + The "all" strategy will save every model after each epoch training. + + References + ---------- + .. [1] `Woo, Gerald, Chenghao Liu, Doyen Sahoo, Akshat Kumar, and Steven Hoi. + "ETSformer: Exponential Smoothing Transformers for Time-series Forecasting ". + ICLR 2023. + `_ + + """ + + def __init__( + self, + n_steps, + n_features, + n_e_layers, + n_d_layers, + n_heads, + d_model, + d_ffn, + top_k, + dropout: float = 0, + batch_size: int = 32, + epochs: int = 100, + patience: int = None, + optimizer: Optional[Optimizer] = Adam(), + num_workers: int = 0, + device: Optional[Union[str, torch.device, list]] = None, + saving_path: str = None, + model_saving_strategy: Optional[str] = "best", + ): + super().__init__( + batch_size, + epochs, + patience, + num_workers, + device, + saving_path, + model_saving_strategy, + ) + + self.n_steps = n_steps + self.n_features = n_features + # model hype-parameters + self.n_heads = n_heads + self.n_e_layers = n_e_layers + self.n_d_layers = n_d_layers + self.d_model = d_model + self.d_ffn = d_ffn + self.dropout = dropout + self.top_k = top_k + + # set up the model + self.model = _ETSformer( + self.n_steps, + self.n_features, + self.n_e_layers, + self.n_d_layers, + self.n_heads, + self.d_model, + self.d_ffn, + self.dropout, + self.top_k, + ) + self._send_model_to_given_device() + self._print_model_size() + + # set up the optimizer + self.optimizer = optimizer + self.optimizer.init_optimizer(self.model.parameters()) + + def _assemble_input_for_training(self, data: list) -> dict: + ( + indices, + X, + missing_mask, + X_ori, + indicating_mask, + ) = self._send_data_to_given_device(data) + + inputs = { + "X": X, + "missing_mask": missing_mask, + "X_ori": X_ori, + "indicating_mask": indicating_mask, + } + + return inputs + + def _assemble_input_for_validating(self, data: list) -> dict: + return self._assemble_input_for_training(data) + + def _assemble_input_for_testing(self, data: list) -> dict: + indices, X, missing_mask = self._send_data_to_given_device(data) + + inputs = { + "X": X, + "missing_mask": missing_mask, + } + + return inputs + + def fit( + self, + train_set: Union[dict, str], + val_set: Optional[Union[dict, str]] = None, + file_type: str = "h5py", + ) -> None: + # Step 1: wrap the input data with classes Dataset and DataLoader + training_set = DatasetForETSformer( + train_set, return_X_ori=False, return_labels=False, file_type=file_type + ) + training_loader = DataLoader( + training_set, + batch_size=self.batch_size, + shuffle=True, + num_workers=self.num_workers, + ) + val_loader = None + if val_set is not None: + if not check_X_ori_in_val_set(val_set): + raise ValueError("val_set must contain 'X_ori' for model validation.") + val_set = DatasetForETSformer( + val_set, return_X_ori=True, return_labels=False, file_type=file_type + ) + val_loader = DataLoader( + val_set, + batch_size=self.batch_size, + shuffle=False, + num_workers=self.num_workers, + ) + + # Step 2: train the model and freeze it + self._train_model(training_loader, val_loader) + self.model.load_state_dict(self.best_model_dict) + self.model.eval() # set the model as eval status to freeze it. + + # Step 3: save the model if necessary + self._auto_save_model_if_necessary(confirm_saving=True) + + def predict( + self, + test_set: Union[dict, str], + file_type: str = "h5py", + ) -> dict: + """Make predictions for the input data with the trained model. + + Parameters + ---------- + test_set : dict or str + The dataset for model validating, should be a dictionary including keys as 'X', + or a path string locating a data file supported by PyPOTS (e.g. h5 file). + If it is a dict, X should be array-like of shape [n_samples, sequence length (time steps), n_features], + which is time-series data for validating, can contain missing values, and y should be array-like of shape + [n_samples], which is classification labels of X. + If it is a path string, the path should point to a data file, e.g. a h5 file, which contains + key-value pairs like a dict, and it has to include keys as 'X' and 'y'. + + file_type : str + The type of the given file if test_set is a path string. + + Returns + ------- + result_dict : dict, + The dictionary containing the clustering results and latent variables if necessary. + + """ + # Step 1: wrap the input data with classes Dataset and DataLoader + self.model.eval() # set the model as eval status to freeze it. + test_set = BaseDataset( + test_set, return_X_ori=False, return_labels=False, file_type=file_type + ) + test_loader = DataLoader( + test_set, + batch_size=self.batch_size, + shuffle=False, + num_workers=self.num_workers, + ) + imputation_collector = [] + + # Step 2: process the data with the model + with torch.no_grad(): + for idx, data in enumerate(test_loader): + inputs = self._assemble_input_for_testing(data) + results = self.model.forward(inputs, training=False) + imputation_collector.append(results["imputed_data"]) + + # Step 3: output collection and return + imputation = torch.cat(imputation_collector).cpu().detach().numpy() + result_dict = { + "imputation": imputation, + } + return result_dict + + def impute( + self, + X: Union[dict, str], + file_type="h5py", + ) -> np.ndarray: + """Impute missing values in the given data with the trained model. + + Warnings + -------- + The method impute is deprecated. Please use `predict()` instead. + + Parameters + ---------- + X : + The data samples for testing, should be array-like of shape [n_samples, sequence length (time steps), + n_features], or a path string locating a data file, e.g. h5 file. + + file_type : + The type of the given file if X is a path string. + + Returns + ------- + array-like, shape [n_samples, sequence length (time steps), n_features], + Imputed data. + """ + logger.warning( + "🚨DeprecationWarning: The method impute is deprecated. Please use `predict` instead." + ) + + results_dict = self.predict(X, file_type=file_type) + return results_dict["imputation"] diff --git a/pypots/imputation/etsformer/modules/__init__.py b/pypots/imputation/etsformer/modules/__init__.py new file mode 100644 index 00000000..ceaa7ee3 --- /dev/null +++ b/pypots/imputation/etsformer/modules/__init__.py @@ -0,0 +1,6 @@ +""" + +""" + +# Created by Wenjie Du +# License: BSD-3-Clause diff --git a/pypots/imputation/etsformer/modules/core.py b/pypots/imputation/etsformer/modules/core.py new file mode 100644 index 00000000..13f692fc --- /dev/null +++ b/pypots/imputation/etsformer/modules/core.py @@ -0,0 +1,101 @@ +""" + +""" + +# Created by Wenjie Du +# License: BSD-3-Clause + +import torch.nn as nn + +from .submodules import ( + Transform, + ETSformerEncoderLayer, + ETSformerEncoder, + ETSformerDecoderLayer, + ETSformerDecoder, +) +from ...timesnet.modules.embedding import DataEmbedding +from ....utils.metrics import calc_mse + + +class _ETSformer(nn.Module): + def __init__( + self, + n_steps, + n_features, + n_e_layers, + n_d_layers, + n_heads, + d_model, + d_ffn, + dropout, + top_k, + activation="sigmoid", + ): + super().__init__() + + self.n_steps = n_steps + + self.enc_embedding = DataEmbedding( + n_features, + d_model, + dropout=dropout, + ) + + # Encoder + self.encoder = ETSformerEncoder( + [ + ETSformerEncoderLayer( + d_model, + n_heads, + n_features, + n_steps, + n_steps, + top_k, + dim_feedforward=d_ffn, + dropout=dropout, + activation=activation, + ) + for _ in range(n_e_layers) + ] + ) + # Decoder + self.decoder = ETSformerDecoder( + [ + ETSformerDecoderLayer( + d_model, + n_heads, + n_features, + n_steps, + dropout=dropout, + ) + for _ in range(n_d_layers) + ], + ) + self.transform = Transform(sigma=0.2) + + # for the imputation task, the output dim is the same as input dim + self.projection = nn.Linear(d_model, n_features) + + def forward(self, inputs: dict, training: bool = True) -> dict: + X, masks = inputs["X"], inputs["missing_mask"] + + # embedding + res = self.enc_embedding(X) + + # ETSformer encoder processing + level, growths, seasons = self.encoder(res, X, attn_mask=None) + growth, season = self.decoder(growths, seasons) + output = level[:, -1:] + growth + season + + imputed_data = masks * X + (1 - masks) * output + results = { + "imputed_data": imputed_data, + } + + if training: + # `loss` is always the item for backward propagating to update the model + loss = calc_mse(output, inputs["X_ori"], inputs["indicating_mask"]) + results["loss"] = loss + + return results diff --git a/pypots/imputation/etsformer/modules/submodules.py b/pypots/imputation/etsformer/modules/submodules.py new file mode 100644 index 00000000..d1a1c7bb --- /dev/null +++ b/pypots/imputation/etsformer/modules/submodules.py @@ -0,0 +1,354 @@ +""" + +""" + +# Created by Wenjie Du +# License: BSD-3-Clause + +import math + +import torch +import torch.fft as fft +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange, reduce, repeat +from scipy.fftpack import next_fast_len + + +class Transform: + def __init__(self, sigma): + self.sigma = sigma + + @torch.no_grad() + def transform(self, x): + return self.jitter(self.shift(self.scale(x))) + + def jitter(self, x): + return x + (torch.randn(x.shape).to(x.device) * self.sigma) + + def scale(self, x): + return x * (torch.randn(x.size(-1)).to(x.device) * self.sigma + 1) + + def shift(self, x): + return x + (torch.randn(x.size(-1)).to(x.device) * self.sigma) + + +def conv1d_fft(f, g, dim=-1): + N = f.size(dim) + M = g.size(dim) + + fast_len = next_fast_len(N + M - 1) + + F_f = fft.rfft(f, fast_len, dim=dim) + F_g = fft.rfft(g, fast_len, dim=dim) + + F_fg = F_f * F_g.conj() + out = fft.irfft(F_fg, fast_len, dim=dim) + out = out.roll((-1,), dims=(dim,)) + idx = torch.as_tensor(range(fast_len - N, fast_len)).to(out.device) + out = out.index_select(dim, idx) + + return out + + +class ExponentialSmoothing(nn.Module): + def __init__(self, dim, nhead, dropout=0.1, aux=False): + super().__init__() + self._smoothing_weight = nn.Parameter(torch.randn(nhead, 1)) + self.v0 = nn.Parameter(torch.randn(1, 1, nhead, dim)) + self.dropout = nn.Dropout(dropout) + if aux: + self.aux_dropout = nn.Dropout(dropout) + + def forward(self, values, aux_values=None): + b, t, h, d = values.shape + + init_weight, weight = self.get_exponential_weight(t) + output = conv1d_fft(self.dropout(values), weight, dim=1) + output = init_weight * self.v0 + output + + if aux_values is not None: + aux_weight = weight / (1 - self.weight) * self.weight + aux_output = conv1d_fft(self.aux_dropout(aux_values), aux_weight) + output = output + aux_output + + return output + + def get_exponential_weight(self, T): + # Generate array [0, 1, ..., T-1] + powers = torch.arange(T, dtype=torch.float, device=self.weight.device) + + # (1 - \alpha) * \alpha^t, for all t = T-1, T-2, ..., 0] + weight = (1 - self.weight) * (self.weight ** torch.flip(powers, dims=(0,))) + + # \alpha^t for all t = 1, 2, ..., T + init_weight = self.weight ** (powers + 1) + + return rearrange(init_weight, "h t -> 1 t h 1"), rearrange( + weight, "h t -> 1 t h 1" + ) + + @property + def weight(self): + return torch.sigmoid(self._smoothing_weight) + + +class Feedforward(nn.Module): + def __init__(self, d_model, dim_feedforward, dropout=0.1, activation="sigmoid"): + # Implementation of Feedforward model + super().__init__() + self.linear1 = nn.Linear(d_model, dim_feedforward, bias=False) + self.dropout1 = nn.Dropout(dropout) + self.linear2 = nn.Linear(dim_feedforward, d_model, bias=False) + self.dropout2 = nn.Dropout(dropout) + self.activation = getattr(F, activation) + + def forward(self, x): + x = self.linear2(self.dropout1(self.activation(self.linear1(x)))) + return self.dropout2(x) + + +class GrowthLayer(nn.Module): + def __init__(self, d_model, nhead, d_head=None, dropout=0.1): + super().__init__() + self.d_head = d_head or (d_model // nhead) + self.d_model = d_model + self.nhead = nhead + + self.z0 = nn.Parameter(torch.randn(self.nhead, self.d_head)) + self.in_proj = nn.Linear(self.d_model, self.d_head * self.nhead) + self.es = ExponentialSmoothing(self.d_head, self.nhead, dropout=dropout) + self.out_proj = nn.Linear(self.d_head * self.nhead, self.d_model) + + assert ( + self.d_head * self.nhead == self.d_model + ), "d_model must be divisible by nhead" + + def forward(self, inputs): + """ + :param inputs: shape: (batch, seq_len, dim) + :return: shape: (batch, seq_len, dim) + """ + b, t, d = inputs.shape + values = self.in_proj(inputs).view(b, t, self.nhead, -1) + values = torch.cat([repeat(self.z0, "h d -> b 1 h d", b=b), values], dim=1) + values = values[:, 1:] - values[:, :-1] + out = self.es(values) + out = torch.cat([repeat(self.es.v0, "1 1 h d -> b 1 h d", b=b), out], dim=1) + out = rearrange(out, "b t h d -> b t (h d)") + return self.out_proj(out) + + +class FourierLayer(nn.Module): + def __init__(self, d_model, pred_len, k=None, low_freq=1): + super().__init__() + self.d_model = d_model + self.pred_len = pred_len + self.k = k + self.low_freq = low_freq + + def forward(self, x): + """x: (b, t, d)""" + b, t, d = x.shape + x_freq = fft.rfft(x, dim=1) + + if t % 2 == 0: + x_freq = x_freq[:, self.low_freq : -1] + f = fft.rfftfreq(t)[self.low_freq : -1] + else: + x_freq = x_freq[:, self.low_freq :] + f = fft.rfftfreq(t)[self.low_freq :] + + x_freq, index_tuple = self.topk_freq(x_freq) + f = repeat(f, "f -> b f d", b=x_freq.size(0), d=x_freq.size(2)) + f = rearrange(f[index_tuple], "b f d -> b f () d").to(x_freq.device) + + return self.extrapolate(x_freq, f, t) + + def extrapolate(self, x_freq, f, t): + x_freq = torch.cat([x_freq, x_freq.conj()], dim=1) + f = torch.cat([f, -f], dim=1) + t_val = rearrange( + torch.arange(t + self.pred_len, dtype=torch.float), "t -> () () t ()" + ).to(x_freq.device) + + amp = rearrange(x_freq.abs() / t, "b f d -> b f () d") + phase = rearrange(x_freq.angle(), "b f d -> b f () d") + + x_time = amp * torch.cos(2 * math.pi * f * t_val + phase) + + return reduce(x_time, "b f t d -> b t d", "sum") + + def topk_freq(self, x_freq): + values, indices = torch.topk( + x_freq.abs(), self.k, dim=1, largest=True, sorted=True + ) + mesh_a, mesh_b = torch.meshgrid( + torch.arange(x_freq.size(0)), torch.arange(x_freq.size(2)) + ) + index_tuple = (mesh_a.unsqueeze(1), indices, mesh_b.unsqueeze(1)) + x_freq = x_freq[index_tuple] + + return x_freq, index_tuple + + +class LevelLayer(nn.Module): + def __init__(self, d_model, c_out, dropout=0.1): + super().__init__() + self.d_model = d_model + self.c_out = c_out + + self.es = ExponentialSmoothing(1, self.c_out, dropout=dropout, aux=True) + self.growth_pred = nn.Linear(self.d_model, self.c_out) + self.season_pred = nn.Linear(self.d_model, self.c_out) + + def forward(self, level, growth, season): + b, t, _ = level.shape + growth = self.growth_pred(growth).view(b, t, self.c_out, 1) + season = self.season_pred(season).view(b, t, self.c_out, 1) + growth = growth.view(b, t, self.c_out, 1) + season = season.view(b, t, self.c_out, 1) + level = level.view(b, t, self.c_out, 1) + out = self.es(level - season, aux_values=growth) + out = rearrange(out, "b t h d -> b t (h d)") + return out + + +class ETSformerEncoderLayer(nn.Module): + def __init__( + self, + d_model, + n_heads, + c_out, + seq_len, + pred_len, + k, + dim_feedforward=None, + dropout=0.1, + activation="sigmoid", + layer_norm_eps=1e-5, + ): + super().__init__() + self.d_model = d_model + self.nhead = n_heads + self.c_out = c_out + self.seq_len = seq_len + self.pred_len = pred_len + dim_feedforward = dim_feedforward or 4 * d_model + self.dim_feedforward = dim_feedforward + + self.growth_layer = GrowthLayer(d_model, n_heads, dropout=dropout) + self.seasonal_layer = FourierLayer(d_model, pred_len, k=k) + self.level_layer = LevelLayer(d_model, c_out, dropout=dropout) + + # Implementation of Feedforward model + self.ff = Feedforward( + d_model, dim_feedforward, dropout=dropout, activation=activation + ) + self.norm1 = nn.LayerNorm(d_model, eps=layer_norm_eps) + self.norm2 = nn.LayerNorm(d_model, eps=layer_norm_eps) + + self.dropout1 = nn.Dropout(dropout) + self.dropout2 = nn.Dropout(dropout) + + def forward(self, res, level, attn_mask=None): + season = self._season_block(res) + res = res - season[:, : -self.pred_len] + growth = self._growth_block(res) + res = self.norm1(res - growth[:, 1:]) + res = self.norm2(res + self.ff(res)) + + level = self.level_layer(level, growth[:, :-1], season[:, : -self.pred_len]) + return res, level, growth, season + + def _growth_block(self, x): + x = self.growth_layer(x) + return self.dropout1(x) + + def _season_block(self, x): + x = self.seasonal_layer(x) + return self.dropout2(x) + + +class ETSformerEncoder(nn.Module): + def __init__(self, layers): + super().__init__() + self.layers = nn.ModuleList(layers) + + def forward(self, res, level, attn_mask=None): + growths = [] + seasons = [] + for layer in self.layers: + res, level, growth, season = layer(res, level, attn_mask=None) + growths.append(growth) + seasons.append(season) + + return level, growths, seasons + + +class DampingLayer(nn.Module): + def __init__(self, pred_len, nhead, dropout=0.1): + super().__init__() + self.pred_len = pred_len + self.nhead = nhead + self._damping_factor = nn.Parameter(torch.randn(1, nhead)) + self.dropout = nn.Dropout(dropout) + + def forward(self, x): + x = repeat(x, "b 1 d -> b t d", t=self.pred_len) + b, t, d = x.shape + + powers = torch.arange(self.pred_len).to(self._damping_factor.device) + 1 + powers = powers.view(self.pred_len, 1) + damping_factors = self.damping_factor**powers + damping_factors = damping_factors.cumsum(dim=0) + x = x.view(b, t, self.nhead, -1) + x = self.dropout(x) * damping_factors.unsqueeze(-1) + return x.view(b, t, d) + + @property + def damping_factor(self): + return torch.sigmoid(self._damping_factor) + + +class ETSformerDecoderLayer(nn.Module): + def __init__(self, d_model, nhead, c_out, pred_len, dropout=0.1): + super().__init__() + self.d_model = d_model + self.nhead = nhead + self.c_out = c_out + self.pred_len = pred_len + + self.growth_damping = DampingLayer(pred_len, nhead, dropout=dropout) + self.dropout1 = nn.Dropout(dropout) + + def forward(self, growth, season): + growth_horizon = self.growth_damping(growth[:, -1:]) + growth_horizon = self.dropout1(growth_horizon) + + seasonal_horizon = season[:, -self.pred_len :] + return growth_horizon, seasonal_horizon + + +class ETSformerDecoder(nn.Module): + def __init__(self, layers): + super().__init__() + self.d_model = layers[0].d_model + self.c_out = layers[0].c_out + self.pred_len = layers[0].pred_len + self.nhead = layers[0].nhead + + self.layers = nn.ModuleList(layers) + self.pred = nn.Linear(self.d_model, self.c_out) + + def forward(self, growths, seasons): + growth_repr = [] + season_repr = [] + + for idx, layer in enumerate(self.layers): + growth_horizon, season_horizon = layer(growths[idx], seasons[idx]) + growth_repr.append(growth_horizon) + season_repr.append(season_horizon) + growth_repr = sum(growth_repr) + season_repr = sum(season_repr) + return self.pred(growth_repr), self.pred(season_repr) diff --git a/tests/imputation/etsformer.py b/tests/imputation/etsformer.py new file mode 100644 index 00000000..c098b79f --- /dev/null +++ b/tests/imputation/etsformer.py @@ -0,0 +1,130 @@ +""" +Test cases for ETSformer imputation model. +""" + +# Created by Wenjie Du +# License: BSD-3-Clause + + +import os.path +import unittest + +import numpy as np +import pytest + +from pypots.imputation import ETSformer +from pypots.optim import Adam +from pypots.utils.logging import logger +from pypots.utils.metrics import calc_mse +from tests.global_test_config import ( + DATA, + EPOCHS, + DEVICE, + TRAIN_SET, + VAL_SET, + TEST_SET, + H5_TRAIN_SET_PATH, + H5_VAL_SET_PATH, + H5_TEST_SET_PATH, + RESULT_SAVING_DIR_FOR_IMPUTATION, + check_tb_and_model_checkpoints_existence, +) + + +class TestETSformer(unittest.TestCase): + logger.info("Running tests for an imputation model ETSformer...") + + # set the log and model saving path + saving_path = os.path.join(RESULT_SAVING_DIR_FOR_IMPUTATION, "ETSformer") + model_save_name = "saved_etsformer_model.pypots" + + # initialize an Adam optimizer + optimizer = Adam(lr=0.001, weight_decay=1e-5) + + # initialize a ETSformer model + etsformer = ETSformer( + DATA["n_steps"], + DATA["n_features"], + n_e_layers=2, + n_d_layers=2, + n_heads=2, + d_model=128, + d_ffn=256, + top_k=3, + dropout=0, + epochs=EPOCHS, + saving_path=saving_path, + optimizer=optimizer, + device=DEVICE, + ) + + @pytest.mark.xdist_group(name="imputation-etsformer") + def test_0_fit(self): + self.etsformer.fit(TRAIN_SET, VAL_SET) + + @pytest.mark.xdist_group(name="imputation-etsformer") + def test_1_impute(self): + imputation_results = self.etsformer.predict(TEST_SET) + assert not np.isnan( + imputation_results["imputation"] + ).any(), "Output still has missing values after running impute()." + + test_MSE = calc_mse( + imputation_results["imputation"], + DATA["test_X_ori"], + DATA["test_X_indicating_mask"], + ) + logger.info(f"ETSformer test_MSE: {test_MSE}") + + @pytest.mark.xdist_group(name="imputation-etsformer") + def test_2_parameters(self): + assert hasattr(self.etsformer, "model") and self.etsformer.model is not None + + assert ( + hasattr(self.etsformer, "optimizer") + and self.etsformer.optimizer is not None + ) + + assert hasattr(self.etsformer, "best_loss") + self.assertNotEqual(self.etsformer.best_loss, float("inf")) + + assert ( + hasattr(self.etsformer, "best_model_dict") + and self.etsformer.best_model_dict is not None + ) + + @pytest.mark.xdist_group(name="imputation-etsformer") + def test_3_saving_path(self): + # whether the root saving dir exists, which should be created by save_log_into_tb_file + assert os.path.exists( + self.saving_path + ), f"file {self.saving_path} does not exist" + + # check if the tensorboard file and model checkpoints exist + check_tb_and_model_checkpoints_existence(self.etsformer) + + # save the trained model into file, and check if the path exists + saved_model_path = os.path.join(self.saving_path, self.model_save_name) + self.etsformer.save(saved_model_path) + + # test loading the saved model, not necessary, but need to test + self.etsformer.load(saved_model_path) + + @pytest.mark.xdist_group(name="imputation-etsformer") + def test_4_lazy_loading(self): + self.etsformer.fit(H5_TRAIN_SET_PATH, H5_VAL_SET_PATH) + imputation_results = self.etsformer.predict(H5_TEST_SET_PATH) + assert not np.isnan( + imputation_results["imputation"] + ).any(), "Output still has missing values after running impute()." + + test_MSE = calc_mse( + imputation_results["imputation"], + DATA["test_X_ori"], + DATA["test_X_indicating_mask"], + ) + logger.info(f"Lazy-loading ETSformer test_MSE: {test_MSE}") + + +if __name__ == "__main__": + unittest.main() From 5e98d3b2f8cb6ebf7da8b017c51f689d85f9afac Mon Sep 17 00:00:00 2001 From: Wenjie Du Date: Sat, 30 Mar 2024 13:56:56 +0800 Subject: [PATCH 05/12] refactor: remove unused modules in ETSformer; --- pypots/imputation/etsformer/model.py | 2 ++ pypots/imputation/etsformer/modules/core.py | 6 +----- 2 files changed, 3 insertions(+), 5 deletions(-) diff --git a/pypots/imputation/etsformer/model.py b/pypots/imputation/etsformer/model.py index 50281ba2..2e50fb79 100644 --- a/pypots/imputation/etsformer/model.py +++ b/pypots/imputation/etsformer/model.py @@ -56,6 +56,8 @@ class ETSformer(BaseNNImputer): d_ffn : The dimension of the feed-forward network. + top_k : + Top-K Fourier bases. dropout : The dropout rate for the model. diff --git a/pypots/imputation/etsformer/modules/core.py b/pypots/imputation/etsformer/modules/core.py index 13f692fc..81009dd2 100644 --- a/pypots/imputation/etsformer/modules/core.py +++ b/pypots/imputation/etsformer/modules/core.py @@ -8,7 +8,6 @@ import torch.nn as nn from .submodules import ( - Transform, ETSformerEncoderLayer, ETSformerEncoder, ETSformerDecoderLayer, @@ -72,10 +71,7 @@ def __init__( for _ in range(n_d_layers) ], ) - self.transform = Transform(sigma=0.2) - - # for the imputation task, the output dim is the same as input dim - self.projection = nn.Linear(d_model, n_features) + # self.transform = Transform(sigma=0.2) # for forecasting def forward(self, inputs: dict, training: bool = True) -> dict: X, masks = inputs["X"], inputs["missing_mask"] From b26d4074c9d559637386bd4ac8188dd6bfd1ef61 Mon Sep 17 00:00:00 2001 From: Wenjie Du Date: Sat, 30 Mar 2024 13:59:48 +0800 Subject: [PATCH 06/12] depen: add einops as an dependency; --- environment-dev.yml | 1 + requirements.txt | 1 + setup.cfg | 3 ++- setup.py | 1 + tests/environment_for_conda_test.yml | 3 ++- 5 files changed, 7 insertions(+), 2 deletions(-) diff --git a/environment-dev.yml b/environment-dev.yml index 2b3bea45..d3a47627 100644 --- a/environment-dev.yml +++ b/environment-dev.yml @@ -13,6 +13,7 @@ dependencies: - conda-forge::numpy - conda-forge::scipy - conda-forge::python + - conda-forge::einops - conda-forge::pandas - conda-forge::matplotlib - conda-forge::tensorboard diff --git a/requirements.txt b/requirements.txt index 13a82508..cfbab72f 100644 --- a/requirements.txt +++ b/requirements.txt @@ -4,6 +4,7 @@ h5py numpy scipy +einops pandas matplotlib tensorboard diff --git a/setup.cfg b/setup.cfg index bbc3f761..a36169c0 100644 --- a/setup.cfg +++ b/setup.cfg @@ -28,7 +28,8 @@ basic = numpy scikit-learn matplotlib - pandas<2.0.0 + einops + pandas torch>=1.10.0 tensorboard scipy diff --git a/setup.py b/setup.py index 4630e69c..92011eeb 100644 --- a/setup.py +++ b/setup.py @@ -49,6 +49,7 @@ "h5py", "numpy", "scipy", + "einops", "pandas", "matplotlib", "tensorboard", diff --git a/tests/environment_for_conda_test.yml b/tests/environment_for_conda_test.yml index 2cb128d9..54f07925 100644 --- a/tests/environment_for_conda_test.yml +++ b/tests/environment_for_conda_test.yml @@ -13,7 +13,8 @@ dependencies: - conda-forge::scipy - conda-forge::numpy - conda-forge::scikit-learn - - conda-forge::pandas <2.0.0 + - conda-forge::einops + - conda-forge::pandas - conda-forge::h5py - conda-forge::tensorboard - conda-forge::pygrinder >=0.4 From 0d577acdfec822d711171e1f5fe4f88f20750fa4 Mon Sep 17 00:00:00 2001 From: Wenjie Du Date: Sat, 30 Mar 2024 16:15:59 +0800 Subject: [PATCH 07/12] feat: add Crossformer as an imputation model; --- pypots/imputation/__init__.py | 2 + pypots/imputation/crossformer/__init__.py | 17 + pypots/imputation/crossformer/data.py | 24 ++ pypots/imputation/crossformer/model.py | 326 ++++++++++++++++++ .../crossformer/modules/__init__.py | 6 + pypots/imputation/crossformer/modules/core.py | 101 ++++++ .../crossformer/modules/submodules.py | 250 ++++++++++++++ tests/imputation/crossformer.py | 131 +++++++ 8 files changed, 857 insertions(+) create mode 100644 pypots/imputation/crossformer/__init__.py create mode 100644 pypots/imputation/crossformer/data.py create mode 100644 pypots/imputation/crossformer/model.py create mode 100644 pypots/imputation/crossformer/modules/__init__.py create mode 100644 pypots/imputation/crossformer/modules/core.py create mode 100644 pypots/imputation/crossformer/modules/submodules.py create mode 100644 tests/imputation/crossformer.py diff --git a/pypots/imputation/__init__.py b/pypots/imputation/__init__.py index f1c4d381..0f41b3f7 100644 --- a/pypots/imputation/__init__.py +++ b/pypots/imputation/__init__.py @@ -14,6 +14,7 @@ from .transformer import Transformer from .timesnet import TimesNet from .etsformer import ETSformer +from .crossformer import Crossformer from .autoformer import Autoformer from .dlinear import DLinear from .patchtst import PatchTST @@ -29,6 +30,7 @@ "SAITS", "Transformer", "ETSformer", + "Crossformer", "TimesNet", "PatchTST", "DLinear", diff --git a/pypots/imputation/crossformer/__init__.py b/pypots/imputation/crossformer/__init__.py new file mode 100644 index 00000000..c289a487 --- /dev/null +++ b/pypots/imputation/crossformer/__init__.py @@ -0,0 +1,17 @@ +""" +The package of the partially-observed time-series imputation model Transformer. + +Refer to the paper "Du, W., Cote, D., & Liu, Y. (2023). SAITS: Self-Attention-based Imputation for Time Series. +Expert systems with applications." + +""" + +# Created by Wenjie Du +# License: BSD-3-Clause + + +from .model import Crossformer + +__all__ = [ + "Crossformer", +] diff --git a/pypots/imputation/crossformer/data.py b/pypots/imputation/crossformer/data.py new file mode 100644 index 00000000..056486f8 --- /dev/null +++ b/pypots/imputation/crossformer/data.py @@ -0,0 +1,24 @@ +""" +Dataset class for Crossformer. +""" + +# Created by Wenjie Du +# License: BSD-3-Clause + +from typing import Union + +from ..saits.data import DatasetForSAITS + + +class DatasetForCrossformer(DatasetForSAITS): + """Actually Crossformer uses the same data strategy as SAITS, needs MIT for training.""" + + def __init__( + self, + data: Union[dict, str], + return_X_ori: bool, + return_labels: bool, + file_type: str = "h5py", + rate: float = 0.2, + ): + super().__init__(data, return_X_ori, return_labels, file_type, rate) diff --git a/pypots/imputation/crossformer/model.py b/pypots/imputation/crossformer/model.py new file mode 100644 index 00000000..89ea4c27 --- /dev/null +++ b/pypots/imputation/crossformer/model.py @@ -0,0 +1,326 @@ +""" +The implementation of Crossformer for the partially-observed time-series imputation task. + +Refer to the paper "Zhang, Y., & Yan, J. (2023). +Crossformer: Transformer utilizing cross-dimension dependency for multivariate time series forecasting. ICLR 2023" + +Notes +----- +Partial implementation uses code from +https://github.com/Thinklab-SJTU/Crossformer and https://github.com/thuml/Time-Series-Library + +""" + +# Created by Wenjie Du +# License: BSD-3-Clause + +from typing import Union, Optional + +import numpy as np +import torch +from torch.utils.data import DataLoader + +from .data import DatasetForCrossformer +from .modules.core import _Crossformer +from ..base import BaseNNImputer +from ...data.base import BaseDataset +from ...data.checking import check_X_ori_in_val_set +from ...optim.adam import Adam +from ...optim.base import Optimizer +from ...utils.logging import logger + + +class Crossformer(BaseNNImputer): + """The PyTorch implementation of the Crossformer model. + Crossformer is originally proposed by Zhang et al. in :cite:`zhang2023crossformer`. + + Parameters + ---------- + n_steps : + The number of time steps in the time-series data sample. + + n_features : + The number of features in the time-series data sample. + + n_layers : + The number of layers in the 1st and 2nd DMSA blocks in the SAITS model. + + n_heads: + The number of heads in the multi-head attention mechanism. + + d_model : + The dimension of the model. + + d_ffn : + The dimension of the feed-forward network. + + factor : + The num of routers in Cross-Dimension Stage of TSA (c). + + seg_len : + The length of the segment in the model. + + win_size : + The window size for merging segment. + + dropout : + The dropout rate for the model. + + batch_size : + The batch size for training and evaluating the model. + + epochs : + The number of epochs for training the model. + + patience : + The patience for the early-stopping mechanism. Given a positive integer, the training process will be + stopped when the model does not perform better after that number of epochs. + Leaving it default as None will disable the early-stopping. + + optimizer : + The optimizer for model training. + If not given, will use a default Adam optimizer. + + num_workers : + The number of subprocesses to use for data loading. + `0` means data loading will be in the main process, i.e. there won't be subprocesses. + + device : + The device for the model to run on. It can be a string, a :class:`torch.device` object, or a list of them. + If not given, will try to use CUDA devices first (will use the default CUDA device if there are multiple), + then CPUs, considering CUDA and CPU are so far the main devices for people to train ML models. + If given a list of devices, e.g. ['cuda:0', 'cuda:1'], or [torch.device('cuda:0'), torch.device('cuda:1')] , the + model will be parallely trained on the multiple devices (so far only support parallel training on CUDA devices). + Other devices like Google TPU and Apple Silicon accelerator MPS may be added in the future. + + saving_path : + The path for automatically saving model checkpoints and tensorboard files (i.e. loss values recorded during + training into a tensorboard file). Will not save if not given. + + model_saving_strategy : + The strategy to save model checkpoints. It has to be one of [None, "best", "better", "all"]. + No model will be saved when it is set as None. + The "best" strategy will only automatically save the best model after the training finished. + The "better" strategy will automatically save the model during training whenever the model performs + better than in previous epochs. + The "all" strategy will save every model after each epoch training. + + """ + + def __init__( + self, + n_steps: int, + n_features: int, + n_layers: int, + n_heads: int, + d_model: int, + d_ffn: int, + factor: int, + seg_len: int, + win_size: int, + dropout: float = 0, + batch_size: int = 32, + epochs: int = 100, + patience: int = None, + optimizer: Optional[Optimizer] = Adam(), + num_workers: int = 0, + device: Optional[Union[str, torch.device, list]] = None, + saving_path: str = None, + model_saving_strategy: Optional[str] = "best", + ): + super().__init__( + batch_size, + epochs, + patience, + num_workers, + device, + saving_path, + model_saving_strategy, + ) + + self.n_steps = n_steps + self.n_features = n_features + # model hype-parameters + self.n_layers = n_layers + self.n_heads = n_heads + self.d_model = d_model + self.d_ffn = d_ffn + self.factor = factor + self.seg_len = seg_len + self.win_size = win_size + self.dropout = dropout + + # set up the model + self.model = _Crossformer( + self.n_steps, + self.n_features, + self.n_layers, + self.n_heads, + self.d_model, + self.d_ffn, + self.factor, + self.seg_len, + self.win_size, + self.dropout, + ) + self._send_model_to_given_device() + self._print_model_size() + + # set up the optimizer + self.optimizer = optimizer + self.optimizer.init_optimizer(self.model.parameters()) + + def _assemble_input_for_training(self, data: list) -> dict: + ( + indices, + X, + missing_mask, + X_ori, + indicating_mask, + ) = self._send_data_to_given_device(data) + + inputs = { + "X": X, + "missing_mask": missing_mask, + "X_ori": X_ori, + "indicating_mask": indicating_mask, + } + + return inputs + + def _assemble_input_for_validating(self, data: list) -> dict: + return self._assemble_input_for_training(data) + + def _assemble_input_for_testing(self, data: list) -> dict: + indices, X, missing_mask = self._send_data_to_given_device(data) + + inputs = { + "X": X, + "missing_mask": missing_mask, + } + + return inputs + + def fit( + self, + train_set: Union[dict, str], + val_set: Optional[Union[dict, str]] = None, + file_type: str = "h5py", + ) -> None: + # Step 1: wrap the input data with classes Dataset and DataLoader + training_set = DatasetForCrossformer( + train_set, return_X_ori=False, return_labels=False, file_type=file_type + ) + training_loader = DataLoader( + training_set, + batch_size=self.batch_size, + shuffle=True, + num_workers=self.num_workers, + ) + val_loader = None + if val_set is not None: + if not check_X_ori_in_val_set(val_set): + raise ValueError("val_set must contain 'X_ori' for model validation.") + val_set = DatasetForCrossformer( + val_set, return_X_ori=True, return_labels=False, file_type=file_type + ) + val_loader = DataLoader( + val_set, + batch_size=self.batch_size, + shuffle=False, + num_workers=self.num_workers, + ) + + # Step 2: train the model and freeze it + self._train_model(training_loader, val_loader) + self.model.load_state_dict(self.best_model_dict) + self.model.eval() # set the model as eval status to freeze it. + + # Step 3: save the model if necessary + self._auto_save_model_if_necessary(confirm_saving=True) + + def predict( + self, + test_set: Union[dict, str], + file_type: str = "h5py", + ) -> dict: + """Make predictions for the input data with the trained model. + + Parameters + ---------- + test_set : dict or str + The dataset for model validating, should be a dictionary including keys as 'X', + or a path string locating a data file supported by PyPOTS (e.g. h5 file). + If it is a dict, X should be array-like of shape [n_samples, sequence length (time steps), n_features], + which is time-series data for validating, can contain missing values, and y should be array-like of shape + [n_samples], which is classification labels of X. + If it is a path string, the path should point to a data file, e.g. a h5 file, which contains + key-value pairs like a dict, and it has to include keys as 'X' and 'y'. + + file_type : str + The type of the given file if test_set is a path string. + + Returns + ------- + result_dict : dict, + The dictionary containing the clustering results and latent variables if necessary. + + """ + # Step 1: wrap the input data with classes Dataset and DataLoader + self.model.eval() # set the model as eval status to freeze it. + test_set = BaseDataset( + test_set, return_X_ori=False, return_labels=False, file_type=file_type + ) + test_loader = DataLoader( + test_set, + batch_size=self.batch_size, + shuffle=False, + num_workers=self.num_workers, + ) + imputation_collector = [] + + # Step 2: process the data with the model + with torch.no_grad(): + for idx, data in enumerate(test_loader): + inputs = self._assemble_input_for_testing(data) + results = self.model.forward(inputs, training=False) + imputation_collector.append(results["imputed_data"]) + + # Step 3: output collection and return + imputation = torch.cat(imputation_collector).cpu().detach().numpy() + result_dict = { + "imputation": imputation, + } + return result_dict + + def impute( + self, + X: Union[dict, str], + file_type="h5py", + ) -> np.ndarray: + """Impute missing values in the given data with the trained model. + + Warnings + -------- + The method impute is deprecated. Please use `predict()` instead. + + Parameters + ---------- + X : + The data samples for testing, should be array-like of shape [n_samples, sequence length (time steps), + n_features], or a path string locating a data file, e.g. h5 file. + + file_type : + The type of the given file if X is a path string. + + Returns + ------- + array-like, shape [n_samples, sequence length (time steps), n_features], + Imputed data. + """ + logger.warning( + "🚨DeprecationWarning: The method impute is deprecated. Please use `predict` instead." + ) + + results_dict = self.predict(X, file_type=file_type) + return results_dict["imputation"] diff --git a/pypots/imputation/crossformer/modules/__init__.py b/pypots/imputation/crossformer/modules/__init__.py new file mode 100644 index 00000000..ceaa7ee3 --- /dev/null +++ b/pypots/imputation/crossformer/modules/__init__.py @@ -0,0 +1,6 @@ +""" + +""" + +# Created by Wenjie Du +# License: BSD-3-Clause diff --git a/pypots/imputation/crossformer/modules/core.py b/pypots/imputation/crossformer/modules/core.py new file mode 100644 index 00000000..0cc9b07a --- /dev/null +++ b/pypots/imputation/crossformer/modules/core.py @@ -0,0 +1,101 @@ +""" + +""" + +# Created by Wenjie Du +# License: BSD-3-Clause + +from math import ceil + +import torch +import torch.nn as nn +from einops import rearrange + +from .submodules import CrossformerEncoder, ScaleBlock +from ...patchtst.modules.submodules import FlattenHead, PatchEmbedding +from ....utils.metrics import calc_mse + + +class _Crossformer(nn.Module): + def __init__( + self, + n_steps, + n_features, + n_layers, + n_heads, + d_model, + d_ffn, + factor, + seg_len, + win_size, + dropout, + ): + super().__init__() + + self.n_features = n_features + + # The padding operation to handle invisible sgemnet length + pad_in_len = ceil(1.0 * n_steps / seg_len) * seg_len + in_seg_num = pad_in_len // seg_len + out_seg_num = ceil(in_seg_num / (win_size ** (n_layers - 1))) + head_nf = d_model * out_seg_num + + # Embedding + self.enc_value_embedding = PatchEmbedding( + d_model, + seg_len, + seg_len, + pad_in_len - n_steps, + 0, + ) + self.enc_pos_embedding = nn.Parameter( + torch.randn(1, n_features, in_seg_num, d_model) + ) + self.pre_norm = nn.LayerNorm(d_model) + + # Encoder + self.encoder = CrossformerEncoder( + [ + ScaleBlock( + 1 if layer == 0 else win_size, + d_model, + n_heads, + d_ffn, + 1, + dropout, + in_seg_num if layer == 0 else ceil(in_seg_num / win_size**layer), + factor, + ) + for layer in range(n_layers) + ] + ) + + self.head = FlattenHead(head_nf, n_steps, dropout) + + def forward(self, inputs: dict, training: bool = True) -> dict: + X, masks = inputs["X"], inputs["missing_mask"] + + # embedding + x_enc = self.enc_value_embedding(X.permute(0, 2, 1)) + + # Crossformer processing + x_enc = rearrange( + x_enc, "(b d) seg_num d_model -> b d seg_num d_model", d=self.n_features + ) + x_enc += self.enc_pos_embedding + x_enc = self.pre_norm(x_enc) + enc_out, attns = self.encoder(x_enc) + # project back the original data space + dec_out = self.head(enc_out[-1].permute(0, 1, 3, 2)).permute(0, 2, 1) + + imputed_data = masks * X + (1 - masks) * dec_out + results = { + "imputed_data": imputed_data, + } + + if training: + # `loss` is always the item for backward propagating to update the model + loss = calc_mse(dec_out, inputs["X_ori"], inputs["indicating_mask"]) + results["loss"] = loss + + return results diff --git a/pypots/imputation/crossformer/modules/submodules.py b/pypots/imputation/crossformer/modules/submodules.py new file mode 100644 index 00000000..0df19b81 --- /dev/null +++ b/pypots/imputation/crossformer/modules/submodules.py @@ -0,0 +1,250 @@ +""" + +""" + +# Created by Wenjie Du +# License: BSD-3-Clause + +import torch +import torch.nn as nn +from einops import rearrange, repeat + +from ....nn.modules.transformer import MultiHeadAttention + + +class TwoStageAttentionLayer(nn.Module): + """ + The Two Stage Attention (TSA) Layer + input/output shape: [batch_size, Data_dim(D), Seg_num(L), d_model] + """ + + def __init__( + self, + seg_num, + factor, + d_model, + n_heads, + d_k, + d_v, + d_ff=None, + dropout=0.1, + attn_dropout=0.1, + ): + super().__init__() + d_ff = 4 * d_model if d_ff is None else d_ff + self.time_attention = MultiHeadAttention( + n_heads, d_model, d_k, d_v, attn_dropout + ) + self.dim_sender = MultiHeadAttention(n_heads, d_model, d_k, d_v, attn_dropout) + self.dim_receiver = MultiHeadAttention(n_heads, d_model, d_k, d_v, attn_dropout) + self.router = nn.Parameter(torch.randn(seg_num, factor, d_model)) + + self.dropout = nn.Dropout(dropout) + + self.norm1 = nn.LayerNorm(d_model) + self.norm2 = nn.LayerNorm(d_model) + self.norm3 = nn.LayerNorm(d_model) + self.norm4 = nn.LayerNorm(d_model) + + self.MLP1 = nn.Sequential( + nn.Linear(d_model, d_ff), nn.GELU(), nn.Linear(d_ff, d_model) + ) + self.MLP2 = nn.Sequential( + nn.Linear(d_model, d_ff), nn.GELU(), nn.Linear(d_ff, d_model) + ) + + def forward(self, x): + # Cross Time Stage: Directly apply MSA to each dimension + batch, ts_d, seg_num, d_model = x.shape + time_in = rearrange(x, "b ts_d seg_num d_model -> (b ts_d) seg_num d_model") + # time_in = x.reshape(-1, seg_num, d_model) + time_enc, attn = self.time_attention(time_in, time_in, time_in, attn_mask=None) + dim_in = time_in + self.dropout(time_enc) + dim_in = self.norm1(dim_in) + dim_in = dim_in + self.dropout(self.MLP1(dim_in)) + dim_in = self.norm2(dim_in) + + # Cross dimension stage: use a small set of learnable vectors to + # aggregate and distribute messages to build the D-to-D connection + dim_send = rearrange( + dim_in, "(b ts_d) seg_num d_model -> (b seg_num) ts_d d_model", b=batch + ) + # dim_send = dim_in.reshape() + batch_router = repeat( + self.router, + "seg_num factor d_model -> (repeat seg_num) factor d_model", + repeat=batch, + ) + dim_buffer, attn = self.dim_sender( + batch_router, dim_send, dim_send, attn_mask=None + ) + dim_receive, attn = self.dim_receiver( + dim_send, dim_buffer, dim_buffer, attn_mask=None + ) + dim_enc = dim_send + self.dropout(dim_receive) + dim_enc = self.norm3(dim_enc) + dim_enc = dim_enc + self.dropout(self.MLP2(dim_enc)) + dim_enc = self.norm4(dim_enc) + + final_out = rearrange( + dim_enc, "(b seg_num) ts_d d_model -> b ts_d seg_num d_model", b=batch + ) + + return final_out + + +class SegMerging(nn.Module): + def __init__(self, d_model, win_size, norm_layer=nn.LayerNorm): + super().__init__() + self.d_model = d_model + self.win_size = win_size + self.linear_trans = nn.Linear(win_size * d_model, d_model) + self.norm = norm_layer(win_size * d_model) + + def forward(self, x): + batch_size, ts_d, seg_num, d_model = x.shape + pad_num = seg_num % self.win_size + if pad_num != 0: + pad_num = self.win_size - pad_num + x = torch.cat((x, x[:, :, -pad_num:, :]), dim=-2) + + seg_to_merge = [] + for i in range(self.win_size): + seg_to_merge.append(x[:, :, i :: self.win_size, :]) + x = torch.cat(seg_to_merge, -1) + + x = self.norm(x) + x = self.linear_trans(x) + + return x + + +class ScaleBlock(nn.Module): + def __init__( + self, + win_size, + d_model, + n_heads, + d_ff, + depth, + dropout, + seg_num=10, + factor=10, + ): + super().__init__() + + if win_size > 1: + self.merge_layer = SegMerging(d_model, win_size, nn.LayerNorm) + else: + self.merge_layer = None + + self.encode_layers = nn.ModuleList() + + for i in range(depth): + self.encode_layers.append( + TwoStageAttentionLayer(seg_num, factor, d_model, n_heads, d_ff, dropout) + ) + + def forward(self, x, attn_mask=None, tau=None, delta=None): + _, ts_dim, _, _ = x.shape + + if self.merge_layer is not None: + x = self.merge_layer(x) + + for layer in self.encode_layers: + x = layer(x) + + return x, None + + +class CrossformerEncoder(nn.Module): + def __init__(self, attn_layers): + super().__init__() + self.encode_blocks = nn.ModuleList(attn_layers) + + def forward(self, x): + encode_x = [] + encode_x.append(x) + + for block in self.encode_blocks: + x, attns = block(x) + encode_x.append(x) + + return encode_x, None + + +class CrossformerDecoderLayer(nn.Module): + def __init__( + self, self_attention, cross_attention, seg_len, d_model, d_ff=None, dropout=0.1 + ): + super().__init__() + self.self_attention = self_attention + self.cross_attention = cross_attention + self.norm1 = nn.LayerNorm(d_model) + self.norm2 = nn.LayerNorm(d_model) + self.dropout = nn.Dropout(dropout) + self.MLP1 = nn.Sequential( + nn.Linear(d_model, d_model), nn.GELU(), nn.Linear(d_model, d_model) + ) + self.linear_pred = nn.Linear(d_model, seg_len) + + def forward(self, x, cross): + batch = x.shape[0] + x = self.self_attention(x) + x = rearrange(x, "b ts_d out_seg_num d_model -> (b ts_d) out_seg_num d_model") + + cross = rearrange( + cross, "b ts_d in_seg_num d_model -> (b ts_d) in_seg_num d_model" + ) + tmp, attn = self.cross_attention( + x, + cross, + cross, + None, + None, + None, + ) + x = x + self.dropout(tmp) + y = x = self.norm1(x) + y = self.MLP1(y) + dec_output = self.norm2(x + y) + + dec_output = rearrange( + dec_output, + "(b ts_d) seg_dec_num d_model -> b ts_d seg_dec_num d_model", + b=batch, + ) + layer_predict = self.linear_pred(dec_output) + layer_predict = rearrange( + layer_predict, "b out_d seg_num seg_len -> b (out_d seg_num) seg_len" + ) + + return dec_output, layer_predict + + +class CrossformerDecoder(nn.Module): + def __init__(self, layers): + super().__init__() + self.decode_layers = nn.ModuleList(layers) + + def forward(self, x, cross): + final_predict = None + i = 0 + + ts_d = x.shape[1] + for layer in self.decode_layers: + cross_enc = cross[i] + x, layer_predict = layer(x, cross_enc) + if final_predict is None: + final_predict = layer_predict + else: + final_predict = final_predict + layer_predict + i += 1 + + final_predict = rearrange( + final_predict, + "b (out_d seg_num) seg_len -> b (seg_num seg_len) out_d", + out_d=ts_d, + ) + + return final_predict diff --git a/tests/imputation/crossformer.py b/tests/imputation/crossformer.py new file mode 100644 index 00000000..fe8f6467 --- /dev/null +++ b/tests/imputation/crossformer.py @@ -0,0 +1,131 @@ +""" +Test cases for Crossformer imputation model. +""" + +# Created by Wenjie Du +# License: BSD-3-Clause + + +import os.path +import unittest + +import numpy as np +import pytest + +from pypots.imputation import Crossformer +from pypots.optim import Adam +from pypots.utils.logging import logger +from pypots.utils.metrics import calc_mse +from tests.global_test_config import ( + DATA, + EPOCHS, + DEVICE, + TRAIN_SET, + VAL_SET, + TEST_SET, + H5_TRAIN_SET_PATH, + H5_VAL_SET_PATH, + H5_TEST_SET_PATH, + RESULT_SAVING_DIR_FOR_IMPUTATION, + check_tb_and_model_checkpoints_existence, +) + + +class TestCrossformer(unittest.TestCase): + logger.info("Running tests for an imputation model Crossformer...") + + # set the log and model saving path + saving_path = os.path.join(RESULT_SAVING_DIR_FOR_IMPUTATION, "Crossformer") + model_save_name = "saved_crossformer_model.pypots" + + # initialize an Adam optimizer + optimizer = Adam(lr=0.001, weight_decay=1e-5) + + # initialize a Crossformer model + crossformer = Crossformer( + DATA["n_steps"], + DATA["n_features"], + n_layers=2, + n_heads=2, + d_model=128, + d_ffn=256, + factor=10, + seg_len=12, + win_size=2, + dropout=0, + epochs=EPOCHS, + saving_path=saving_path, + optimizer=optimizer, + device=DEVICE, + ) + + @pytest.mark.xdist_group(name="imputation-crossformer") + def test_0_fit(self): + self.crossformer.fit(TRAIN_SET, VAL_SET) + + @pytest.mark.xdist_group(name="imputation-crossformer") + def test_1_impute(self): + imputation_results = self.crossformer.predict(TEST_SET) + assert not np.isnan( + imputation_results["imputation"] + ).any(), "Output still has missing values after running impute()." + + test_MSE = calc_mse( + imputation_results["imputation"], + DATA["test_X_ori"], + DATA["test_X_indicating_mask"], + ) + logger.info(f"Crossformer test_MSE: {test_MSE}") + + @pytest.mark.xdist_group(name="imputation-crossformer") + def test_2_parameters(self): + assert hasattr(self.crossformer, "model") and self.crossformer.model is not None + + assert ( + hasattr(self.crossformer, "optimizer") + and self.crossformer.optimizer is not None + ) + + assert hasattr(self.crossformer, "best_loss") + self.assertNotEqual(self.crossformer.best_loss, float("inf")) + + assert ( + hasattr(self.crossformer, "best_model_dict") + and self.crossformer.best_model_dict is not None + ) + + @pytest.mark.xdist_group(name="imputation-crossformer") + def test_3_saving_path(self): + # whether the root saving dir exists, which should be created by save_log_into_tb_file + assert os.path.exists( + self.saving_path + ), f"file {self.saving_path} does not exist" + + # check if the tensorboard file and model checkpoints exist + check_tb_and_model_checkpoints_existence(self.crossformer) + + # save the trained model into file, and check if the path exists + saved_model_path = os.path.join(self.saving_path, self.model_save_name) + self.crossformer.save(saved_model_path) + + # test loading the saved model, not necessary, but need to test + self.crossformer.load(saved_model_path) + + @pytest.mark.xdist_group(name="imputation-crossformer") + def test_4_lazy_loading(self): + self.crossformer.fit(H5_TRAIN_SET_PATH, H5_VAL_SET_PATH) + imputation_results = self.crossformer.predict(H5_TEST_SET_PATH) + assert not np.isnan( + imputation_results["imputation"] + ).any(), "Output still has missing values after running impute()." + + test_MSE = calc_mse( + imputation_results["imputation"], + DATA["test_X_ori"], + DATA["test_X_indicating_mask"], + ) + logger.info(f"Lazy-loading Crossformer test_MSE: {test_MSE}") + + +if __name__ == "__main__": + unittest.main() From 09db14ca5132f7d0cc1d8759b42a6dfac7d923ac Mon Sep 17 00:00:00 2001 From: Wenjie Du Date: Sat, 30 Mar 2024 23:32:15 +0800 Subject: [PATCH 08/12] feat: add FEDformer as an imputation model; --- pypots/imputation/__init__.py | 2 + pypots/imputation/fedformer/__init__.py | 17 + pypots/imputation/fedformer/data.py | 24 + pypots/imputation/fedformer/model.py | 340 +++++++ .../imputation/fedformer/modules/__init__.py | 6 + pypots/imputation/fedformer/modules/core.py | 100 ++ .../fedformer/modules/submodules.py | 915 ++++++++++++++++++ tests/imputation/fedformer.py | 132 +++ 8 files changed, 1536 insertions(+) create mode 100644 pypots/imputation/fedformer/__init__.py create mode 100644 pypots/imputation/fedformer/data.py create mode 100644 pypots/imputation/fedformer/model.py create mode 100644 pypots/imputation/fedformer/modules/__init__.py create mode 100644 pypots/imputation/fedformer/modules/core.py create mode 100644 pypots/imputation/fedformer/modules/submodules.py create mode 100644 tests/imputation/fedformer.py diff --git a/pypots/imputation/__init__.py b/pypots/imputation/__init__.py index 0f41b3f7..a7052dcc 100644 --- a/pypots/imputation/__init__.py +++ b/pypots/imputation/__init__.py @@ -14,6 +14,7 @@ from .transformer import Transformer from .timesnet import TimesNet from .etsformer import ETSformer +from .fedformer import FEDformer from .crossformer import Crossformer from .autoformer import Autoformer from .dlinear import DLinear @@ -30,6 +31,7 @@ "SAITS", "Transformer", "ETSformer", + "FEDformer", "Crossformer", "TimesNet", "PatchTST", diff --git a/pypots/imputation/fedformer/__init__.py b/pypots/imputation/fedformer/__init__.py new file mode 100644 index 00000000..049302f2 --- /dev/null +++ b/pypots/imputation/fedformer/__init__.py @@ -0,0 +1,17 @@ +""" +The package of the partially-observed time-series imputation model FEDformer. + +Refer to the paper "Wu, H., Xu, J., Wang, J., & Long, M. (2021). +FEDformer: Decomposition transformers with auto-correlation for long-term series forecasting. NeurIPS 2021.". + +""" + +# Created by Wenjie Du +# License: BSD-3-Clause + + +from .model import FEDformer + +__all__ = [ + "FEDformer", +] diff --git a/pypots/imputation/fedformer/data.py b/pypots/imputation/fedformer/data.py new file mode 100644 index 00000000..a5982636 --- /dev/null +++ b/pypots/imputation/fedformer/data.py @@ -0,0 +1,24 @@ +""" +Dataset class for FEDformer. +""" + +# Created by Wenjie Du +# License: BSD-3-Clause + +from typing import Union + +from ..saits.data import DatasetForSAITS + + +class DatasetForFEDformer(DatasetForSAITS): + """Actually FEDformer uses the same data strategy as SAITS, needs MIT for training.""" + + def __init__( + self, + data: Union[dict, str], + return_X_ori: bool, + return_labels: bool, + file_type: str = "h5py", + rate: float = 0.2, + ): + super().__init__(data, return_X_ori, return_labels, file_type, rate) diff --git a/pypots/imputation/fedformer/model.py b/pypots/imputation/fedformer/model.py new file mode 100644 index 00000000..4d7182cc --- /dev/null +++ b/pypots/imputation/fedformer/model.py @@ -0,0 +1,340 @@ +""" +The implementation of FEDformer for the partially-observed time-series imputation task. + +Refer to the paper "Zhou, T., Ma, Z., Wen, Q., Wang, X., Sun, L., & Jin, R. (2022). +FEDformer: Frequency enhanced decomposed transformer for long-term series forecasting. ICML 2022.". + +Notes +----- +Partial implementation uses code from https://github.com/MAZiqing/FEDformer + +""" + +# Created by Wenjie Du +# License: BSD-3-Clause + +from typing import Union, Optional + +import numpy as np +import torch +from torch.utils.data import DataLoader + +from .data import DatasetForFEDformer +from .modules.core import _FEDformer +from ..base import BaseNNImputer +from ...data.base import BaseDataset +from ...data.checking import check_X_ori_in_val_set +from ...optim.adam import Adam +from ...optim.base import Optimizer +from ...utils.logging import logger + + +class FEDformer(BaseNNImputer): + """The PyTorch implementation of the FEDformer model. + FEDformer is originally proposed by Woo et al. in :cite:`zhou2022fedformer`. + + Parameters + ---------- + n_steps : + The number of time steps in the time-series data sample. + + n_features : + The number of features in the time-series data sample. + + n_layers : + The number of layers in the FEDformer. + + n_heads : + The number of heads in the multi-head attention mechanism. + + d_model : + The dimension of the model. + + d_ffn : + The dimension of the feed-forward network. + + moving_avg_window_size : + The window size of moving average. + + dropout : + The dropout rate for the model. + + version : + The version of the model. It has to be one of ["Wavelets", "Fourier"]. + The default value is "Fourier". + + modes : + The number of modes to be selected. The default value is 32. + + mode_select : + Get modes on frequency domain. It has to "random" or "low". The default value is "random". + 'random' means sampling randomly; 'low' means sampling the lowest modes; + + batch_size : + The batch size for training and evaluating the model. + + epochs : + The number of epochs for training the model. + + patience : + The patience for the early-stopping mechanism. Given a positive integer, the training process will be + stopped when the model does not perform better after that number of epochs. + Leaving it default as None will disable the early-stopping. + + optimizer : + The optimizer for model training. + If not given, will use a default Adam optimizer. + + num_workers : + The number of subprocesses to use for data loading. + `0` means data loading will be in the main process, i.e. there won't be subprocesses. + + device : + The device for the model to run on. It can be a string, a :class:`torch.device` object, or a list of them. + If not given, will try to use CUDA devices first (will use the default CUDA device if there are multiple), + then CPUs, considering CUDA and CPU are so far the main devices for people to train ML models. + If given a list of devices, e.g. ['cuda:0', 'cuda:1'], or [torch.device('cuda:0'), torch.device('cuda:1')] , the + model will be parallely trained on the multiple devices (so far only support parallel training on CUDA devices). + Other devices like Google TPU and Apple Silicon accelerator MPS may be added in the future. + + saving_path : + The path for automatically saving model checkpoints and tensorboard files (i.e. loss values recorded during + training into a tensorboard file). Will not save if not given. + + model_saving_strategy : + The strategy to save model checkpoints. It has to be one of [None, "best", "better", "all"]. + No model will be saved when it is set as None. + The "best" strategy will only automatically save the best model after the training finished. + The "better" strategy will automatically save the model during training whenever the model performs + better than in previous epochs. + The "all" strategy will save every model after each epoch training. + + References + ---------- + .. [1] `Zhou, Tian, Ziqing Ma, Qingsong Wen, Xue Wang, Liang Sun, and Rong Jin. + "FEDformer: Frequency Enhanced Decomposed Transformer for Long-term Series Forecasting". + ICML 2022. + `_ + + """ + + def __init__( + self, + n_steps, + n_features, + n_layers, + n_heads, + d_model, + d_ffn, + moving_avg_window_size, + dropout: float = 0, + version="Fourier", + modes=32, + mode_select="random", + batch_size: int = 32, + epochs: int = 100, + patience: int = None, + optimizer: Optional[Optimizer] = Adam(), + num_workers: int = 0, + device: Optional[Union[str, torch.device, list]] = None, + saving_path: str = None, + model_saving_strategy: Optional[str] = "best", + ): + super().__init__( + batch_size, + epochs, + patience, + num_workers, + device, + saving_path, + model_saving_strategy, + ) + + self.n_steps = n_steps + self.n_features = n_features + # model hype-parameters + self.n_layers = n_layers + self.n_heads = n_heads + self.d_model = d_model + self.d_ffn = d_ffn + self.modes = modes + self.mode_select = mode_select + self.moving_avg_window_size = moving_avg_window_size + self.dropout = dropout + self.version = version + + # set up the model + self.model = _FEDformer( + self.n_steps, + self.n_features, + self.n_layers, + self.n_heads, + self.d_model, + self.d_ffn, + self.moving_avg_window_size, + self.dropout, + self.version, + self.modes, + self.mode_select, + ) + self._send_model_to_given_device() + self._print_model_size() + + # set up the optimizer + self.optimizer = optimizer + self.optimizer.init_optimizer(self.model.parameters()) + + def _assemble_input_for_training(self, data: list) -> dict: + ( + indices, + X, + missing_mask, + X_ori, + indicating_mask, + ) = self._send_data_to_given_device(data) + + inputs = { + "X": X, + "missing_mask": missing_mask, + "X_ori": X_ori, + "indicating_mask": indicating_mask, + } + + return inputs + + def _assemble_input_for_validating(self, data: list) -> dict: + return self._assemble_input_for_training(data) + + def _assemble_input_for_testing(self, data: list) -> dict: + indices, X, missing_mask = self._send_data_to_given_device(data) + + inputs = { + "X": X, + "missing_mask": missing_mask, + } + + return inputs + + def fit( + self, + train_set: Union[dict, str], + val_set: Optional[Union[dict, str]] = None, + file_type: str = "h5py", + ) -> None: + # Step 1: wrap the input data with classes Dataset and DataLoader + training_set = DatasetForFEDformer( + train_set, return_X_ori=False, return_labels=False, file_type=file_type + ) + training_loader = DataLoader( + training_set, + batch_size=self.batch_size, + shuffle=True, + num_workers=self.num_workers, + ) + val_loader = None + if val_set is not None: + if not check_X_ori_in_val_set(val_set): + raise ValueError("val_set must contain 'X_ori' for model validation.") + val_set = DatasetForFEDformer( + val_set, return_X_ori=True, return_labels=False, file_type=file_type + ) + val_loader = DataLoader( + val_set, + batch_size=self.batch_size, + shuffle=False, + num_workers=self.num_workers, + ) + + # Step 2: train the model and freeze it + self._train_model(training_loader, val_loader) + self.model.load_state_dict(self.best_model_dict) + self.model.eval() # set the model as eval status to freeze it. + + # Step 3: save the model if necessary + self._auto_save_model_if_necessary(confirm_saving=True) + + def predict( + self, + test_set: Union[dict, str], + file_type: str = "h5py", + ) -> dict: + """Make predictions for the input data with the trained model. + + Parameters + ---------- + test_set : dict or str + The dataset for model validating, should be a dictionary including keys as 'X', + or a path string locating a data file supported by PyPOTS (e.g. h5 file). + If it is a dict, X should be array-like of shape [n_samples, sequence length (time steps), n_features], + which is time-series data for validating, can contain missing values, and y should be array-like of shape + [n_samples], which is classification labels of X. + If it is a path string, the path should point to a data file, e.g. a h5 file, which contains + key-value pairs like a dict, and it has to include keys as 'X' and 'y'. + + file_type : str + The type of the given file if test_set is a path string. + + Returns + ------- + result_dict : dict, + The dictionary containing the clustering results and latent variables if necessary. + + """ + # Step 1: wrap the input data with classes Dataset and DataLoader + self.model.eval() # set the model as eval status to freeze it. + test_set = BaseDataset( + test_set, return_X_ori=False, return_labels=False, file_type=file_type + ) + test_loader = DataLoader( + test_set, + batch_size=self.batch_size, + shuffle=False, + num_workers=self.num_workers, + ) + imputation_collector = [] + + # Step 2: process the data with the model + with torch.no_grad(): + for idx, data in enumerate(test_loader): + inputs = self._assemble_input_for_testing(data) + results = self.model.forward(inputs, training=False) + imputation_collector.append(results["imputed_data"]) + + # Step 3: output collection and return + imputation = torch.cat(imputation_collector).cpu().detach().numpy() + result_dict = { + "imputation": imputation, + } + return result_dict + + def impute( + self, + X: Union[dict, str], + file_type="h5py", + ) -> np.ndarray: + """Impute missing values in the given data with the trained model. + + Warnings + -------- + The method impute is deprecated. Please use `predict()` instead. + + Parameters + ---------- + X : + The data samples for testing, should be array-like of shape [n_samples, sequence length (time steps), + n_features], or a path string locating a data file, e.g. h5 file. + + file_type : + The type of the given file if X is a path string. + + Returns + ------- + array-like, shape [n_samples, sequence length (time steps), n_features], + Imputed data. + """ + logger.warning( + "🚨DeprecationWarning: The method impute is deprecated. Please use `predict` instead." + ) + + results_dict = self.predict(X, file_type=file_type) + return results_dict["imputation"] diff --git a/pypots/imputation/fedformer/modules/__init__.py b/pypots/imputation/fedformer/modules/__init__.py new file mode 100644 index 00000000..ceaa7ee3 --- /dev/null +++ b/pypots/imputation/fedformer/modules/__init__.py @@ -0,0 +1,6 @@ +""" + +""" + +# Created by Wenjie Du +# License: BSD-3-Clause diff --git a/pypots/imputation/fedformer/modules/core.py b/pypots/imputation/fedformer/modules/core.py new file mode 100644 index 00000000..2be002a9 --- /dev/null +++ b/pypots/imputation/fedformer/modules/core.py @@ -0,0 +1,100 @@ +""" + +""" + +# Created by Wenjie Du +# License: BSD-3-Clause + +import torch.nn as nn + +from .submodules import MultiWaveletTransform, FourierBlock +from ...autoformer.modules.submodules import ( + AutoformerEncoder, + AutoformerEncoderLayer, + AutoCorrelationLayer, + SeasonalLayerNorm, +) +from ...timesnet.modules.embedding import DataEmbedding +from ....utils.metrics import calc_mse + + +class _FEDformer(nn.Module): + def __init__( + self, + n_steps, + n_features, + n_layers, + n_heads, + d_model, + d_ffn, + moving_avg_window_size, + dropout, + version="Fourier", + modes=32, + mode_select="random", + activation="relu", + ): + super().__init__() + + self.enc_embedding = DataEmbedding( + n_features, + d_model, + dropout=dropout, + ) + + if version == "Wavelets": + encoder_self_att = MultiWaveletTransform(ich=d_model, L=1, base="legendre") + elif version == "Fourier": + encoder_self_att = FourierBlock( + in_channels=d_model, + out_channels=d_model, + seq_len=n_steps, + modes=modes, + mode_select_method=mode_select, + ) + else: + raise ValueError( + f"Unsupported version: {version}. Please choose from ['Wavelets', 'Fourier']." + ) + + self.encoder = AutoformerEncoder( + [ + AutoformerEncoderLayer( + AutoCorrelationLayer( + encoder_self_att, # instead of multi-head attention in transformer + d_model, + n_heads, + ), + d_model, + d_ffn, + moving_avg_window_size, + dropout, + activation, + ) + for _ in range(n_layers) + ], + norm_layer=SeasonalLayerNorm(d_model), + ) + self.projection = nn.Linear(d_model, n_features) + + def forward(self, inputs: dict, training: bool = True) -> dict: + X, masks = inputs["X"], inputs["missing_mask"] + + # embedding + enc_out = self.enc_embedding(X) + + # FEDformer encoder processing + enc_out, attns = self.encoder(enc_out) + output = self.projection(enc_out) + + imputed_data = masks * X + (1 - masks) * output + results = { + "imputed_data": imputed_data, + } + + if training: + # `loss` is always the item for backward propagating to update the model + loss = calc_mse(output, inputs["X_ori"], inputs["indicating_mask"]) + results["loss"] = loss + + return results diff --git a/pypots/imputation/fedformer/modules/submodules.py b/pypots/imputation/fedformer/modules/submodules.py new file mode 100644 index 00000000..24ad5907 --- /dev/null +++ b/pypots/imputation/fedformer/modules/submodules.py @@ -0,0 +1,915 @@ +""" + +""" + +# Created by Wenjie Du +# License: BSD-3-Clause + +import math +from functools import partial +from typing import List, Tuple + +import numpy as np +import torch +import torch.nn.functional as F +from scipy.special import eval_legendre +from sympy import Poly, legendre, Symbol, chebyshevt +from torch import Tensor +from torch import nn + + +def legendreDer(k, x): + def _legendre(k, x): + return (2 * k + 1) * eval_legendre(k, x) + + out = 0 + for i in np.arange(k - 1, -1, -2): + out += _legendre(i, x) + return out + + +def phi_(phi_c, x, lb=0, ub=1): + mask = np.logical_or(x < lb, x > ub) * 1.0 + return np.polynomial.polynomial.Polynomial(phi_c)(x) * (1 - mask) + + +def get_phi_psi(k, base): + x = Symbol("x") + phi_coeff = np.zeros((k, k)) + phi_2x_coeff = np.zeros((k, k)) + if base == "legendre": + for ki in range(k): + coeff_ = Poly(legendre(ki, 2 * x - 1), x).all_coeffs() + phi_coeff[ki, : ki + 1] = np.flip( + np.sqrt(2 * ki + 1) * np.array(coeff_).astype(np.float64) + ) + coeff_ = Poly(legendre(ki, 4 * x - 1), x).all_coeffs() + phi_2x_coeff[ki, : ki + 1] = np.flip( + np.sqrt(2) * np.sqrt(2 * ki + 1) * np.array(coeff_).astype(np.float64) + ) + + psi1_coeff = np.zeros((k, k)) + psi2_coeff = np.zeros((k, k)) + for ki in range(k): + psi1_coeff[ki, :] = phi_2x_coeff[ki, :] + for i in range(k): + a = phi_2x_coeff[ki, : ki + 1] + b = phi_coeff[i, : i + 1] + prod_ = np.convolve(a, b) + prod_[np.abs(prod_) < 1e-8] = 0 + proj_ = ( + prod_ + * 1 + / (np.arange(len(prod_)) + 1) + * np.power(0.5, 1 + np.arange(len(prod_))) + ).sum() + psi1_coeff[ki, :] -= proj_ * phi_coeff[i, :] + psi2_coeff[ki, :] -= proj_ * phi_coeff[i, :] + for j in range(ki): + a = phi_2x_coeff[ki, : ki + 1] + b = psi1_coeff[j, :] + prod_ = np.convolve(a, b) + prod_[np.abs(prod_) < 1e-8] = 0 + proj_ = ( + prod_ + * 1 + / (np.arange(len(prod_)) + 1) + * np.power(0.5, 1 + np.arange(len(prod_))) + ).sum() + psi1_coeff[ki, :] -= proj_ * psi1_coeff[j, :] + psi2_coeff[ki, :] -= proj_ * psi2_coeff[j, :] + + a = psi1_coeff[ki, :] + prod_ = np.convolve(a, a) + prod_[np.abs(prod_) < 1e-8] = 0 + norm1 = ( + prod_ + * 1 + / (np.arange(len(prod_)) + 1) + * np.power(0.5, 1 + np.arange(len(prod_))) + ).sum() + + a = psi2_coeff[ki, :] + prod_ = np.convolve(a, a) + prod_[np.abs(prod_) < 1e-8] = 0 + norm2 = ( + prod_ + * 1 + / (np.arange(len(prod_)) + 1) + * (1 - np.power(0.5, 1 + np.arange(len(prod_)))) + ).sum() + norm_ = np.sqrt(norm1 + norm2) + psi1_coeff[ki, :] /= norm_ + psi2_coeff[ki, :] /= norm_ + psi1_coeff[np.abs(psi1_coeff) < 1e-8] = 0 + psi2_coeff[np.abs(psi2_coeff) < 1e-8] = 0 + + phi = [np.poly1d(np.flip(phi_coeff[i, :])) for i in range(k)] + psi1 = [np.poly1d(np.flip(psi1_coeff[i, :])) for i in range(k)] + psi2 = [np.poly1d(np.flip(psi2_coeff[i, :])) for i in range(k)] + + elif base == "chebyshev": + for ki in range(k): + if ki == 0: + phi_coeff[ki, : ki + 1] = np.sqrt(2 / np.pi) + phi_2x_coeff[ki, : ki + 1] = np.sqrt(2 / np.pi) * np.sqrt(2) + else: + coeff_ = Poly(chebyshevt(ki, 2 * x - 1), x).all_coeffs() + phi_coeff[ki, : ki + 1] = np.flip( + 2 / np.sqrt(np.pi) * np.array(coeff_).astype(np.float64) + ) + coeff_ = Poly(chebyshevt(ki, 4 * x - 1), x).all_coeffs() + phi_2x_coeff[ki, : ki + 1] = np.flip( + np.sqrt(2) + * 2 + / np.sqrt(np.pi) + * np.array(coeff_).astype(np.float64) + ) + + phi = [partial(phi_, phi_coeff[i, :]) for i in range(k)] + + x = Symbol("x") + kUse = 2 * k + roots = Poly(chebyshevt(kUse, 2 * x - 1)).all_roots() + x_m = np.array([rt.evalf(20) for rt in roots]).astype(np.float64) + # x_m[x_m==0.5] = 0.5 + 1e-8 # add small noise to avoid the case of 0.5 belonging to both phi(2x) and phi(2x-1) + # not needed for our purpose here, we use even k always to avoid + wm = np.pi / kUse / 2 + + psi1_coeff = np.zeros((k, k)) + psi2_coeff = np.zeros((k, k)) + + psi1 = [[] for _ in range(k)] + psi2 = [[] for _ in range(k)] + + for ki in range(k): + psi1_coeff[ki, :] = phi_2x_coeff[ki, :] + for i in range(k): + proj_ = (wm * phi[i](x_m) * np.sqrt(2) * phi[ki](2 * x_m)).sum() + psi1_coeff[ki, :] -= proj_ * phi_coeff[i, :] + psi2_coeff[ki, :] -= proj_ * phi_coeff[i, :] + + for j in range(ki): + proj_ = (wm * psi1[j](x_m) * np.sqrt(2) * phi[ki](2 * x_m)).sum() + psi1_coeff[ki, :] -= proj_ * psi1_coeff[j, :] + psi2_coeff[ki, :] -= proj_ * psi2_coeff[j, :] + + psi1[ki] = partial(phi_, psi1_coeff[ki, :], lb=0, ub=0.5) + psi2[ki] = partial(phi_, psi2_coeff[ki, :], lb=0.5, ub=1) + + norm1 = (wm * psi1[ki](x_m) * psi1[ki](x_m)).sum() + norm2 = (wm * psi2[ki](x_m) * psi2[ki](x_m)).sum() + + norm_ = np.sqrt(norm1 + norm2) + psi1_coeff[ki, :] /= norm_ + psi2_coeff[ki, :] /= norm_ + psi1_coeff[np.abs(psi1_coeff) < 1e-8] = 0 + psi2_coeff[np.abs(psi2_coeff) < 1e-8] = 0 + + psi1[ki] = partial(phi_, psi1_coeff[ki, :], lb=0, ub=0.5 + 1e-16) + psi2[ki] = partial(phi_, psi2_coeff[ki, :], lb=0.5 + 1e-16, ub=1) + + return phi, psi1, psi2 + + +def get_filter(base, k): + def psi(psi1, psi2, i, inp): + mask = (inp <= 0.5) * 1.0 + return psi1[i](inp) * mask + psi2[i](inp) * (1 - mask) + + if base not in ["legendre", "chebyshev"]: + raise Exception("Base not supported") + + x = Symbol("x") + H0 = np.zeros((k, k)) + H1 = np.zeros((k, k)) + G0 = np.zeros((k, k)) + G1 = np.zeros((k, k)) + PHI0 = np.zeros((k, k)) + PHI1 = np.zeros((k, k)) + phi, psi1, psi2 = get_phi_psi(k, base) + if base == "legendre": + roots = Poly(legendre(k, 2 * x - 1)).all_roots() + x_m = np.array([rt.evalf(20) for rt in roots]).astype(np.float64) + wm = 1 / k / legendreDer(k, 2 * x_m - 1) / eval_legendre(k - 1, 2 * x_m - 1) + + for ki in range(k): + for kpi in range(k): + H0[ki, kpi] = ( + 1 / np.sqrt(2) * (wm * phi[ki](x_m / 2) * phi[kpi](x_m)).sum() + ) + G0[ki, kpi] = ( + 1 + / np.sqrt(2) + * (wm * psi(psi1, psi2, ki, x_m / 2) * phi[kpi](x_m)).sum() + ) + H1[ki, kpi] = ( + 1 / np.sqrt(2) * (wm * phi[ki]((x_m + 1) / 2) * phi[kpi](x_m)).sum() + ) + G1[ki, kpi] = ( + 1 + / np.sqrt(2) + * (wm * psi(psi1, psi2, ki, (x_m + 1) / 2) * phi[kpi](x_m)).sum() + ) + + PHI0 = np.eye(k) + PHI1 = np.eye(k) + + elif base == "chebyshev": + x = Symbol("x") + kUse = 2 * k + roots = Poly(chebyshevt(kUse, 2 * x - 1)).all_roots() + x_m = np.array([rt.evalf(20) for rt in roots]).astype(np.float64) + # x_m[x_m==0.5] = 0.5 + 1e-8 # add small noise to avoid the case of 0.5 belonging to both phi(2x) and phi(2x-1) + # not needed for our purpose here, we use even k always to avoid + wm = np.pi / kUse / 2 + + for ki in range(k): + for kpi in range(k): + H0[ki, kpi] = ( + 1 / np.sqrt(2) * (wm * phi[ki](x_m / 2) * phi[kpi](x_m)).sum() + ) + G0[ki, kpi] = ( + 1 + / np.sqrt(2) + * (wm * psi(psi1, psi2, ki, x_m / 2) * phi[kpi](x_m)).sum() + ) + H1[ki, kpi] = ( + 1 / np.sqrt(2) * (wm * phi[ki]((x_m + 1) / 2) * phi[kpi](x_m)).sum() + ) + G1[ki, kpi] = ( + 1 + / np.sqrt(2) + * (wm * psi(psi1, psi2, ki, (x_m + 1) / 2) * phi[kpi](x_m)).sum() + ) + + PHI0[ki, kpi] = (wm * phi[ki](2 * x_m) * phi[kpi](2 * x_m)).sum() * 2 + PHI1[ki, kpi] = ( + wm * phi[ki](2 * x_m - 1) * phi[kpi](2 * x_m - 1) + ).sum() * 2 + + PHI0[np.abs(PHI0) < 1e-8] = 0 + PHI1[np.abs(PHI1) < 1e-8] = 0 + + H0[np.abs(H0) < 1e-8] = 0 + H1[np.abs(H1) < 1e-8] = 0 + G0[np.abs(G0) < 1e-8] = 0 + G1[np.abs(G1) < 1e-8] = 0 + + return H0, H1, G0, G1, PHI0, PHI1 + + +class sparseKernelFT1d(nn.Module): + def __init__(self, k, alpha, c=1, nl=1, initializer=None, **kwargs): + super().__init__() + + self.modes1 = alpha + self.scale = 1 / (c * k * c * k) + self.weights1 = nn.Parameter( + self.scale * torch.rand(c * k, c * k, self.modes1, dtype=torch.float) + ) + self.weights2 = nn.Parameter( + self.scale * torch.rand(c * k, c * k, self.modes1, dtype=torch.float) + ) + self.weights1.requires_grad = True + self.weights2.requires_grad = True + self.k = k + + def compl_mul1d(self, order, x, weights): + x_flag = True + w_flag = True + if not torch.is_complex(x): + x_flag = False + x = torch.complex(x, torch.zeros_like(x).to(x.device)) + if not torch.is_complex(weights): + w_flag = False + weights = torch.complex( + weights, torch.zeros_like(weights).to(weights.device) + ) + if x_flag or w_flag: + return torch.complex( + torch.einsum(order, x.real, weights.real) + - torch.einsum(order, x.imag, weights.imag), + torch.einsum(order, x.real, weights.imag) + + torch.einsum(order, x.imag, weights.real), + ) + else: + return torch.einsum(order, x.real, weights.real) + + def forward(self, x): + B, N, c, k = x.shape # (B, N, c, k) + + x = x.view(B, N, -1) + x = x.permute(0, 2, 1) + x_fft = torch.fft.rfft(x) + # Multiply relevant Fourier modes + mode = min(self.modes1, N // 2 + 1) + out_ft = torch.zeros(B, c * k, N // 2 + 1, device=x.device, dtype=torch.cfloat) + out_ft[:, :, :mode] = self.compl_mul1d( + "bix,iox->box", + x_fft[:, :, :mode], + torch.complex(self.weights1, self.weights2)[:, :, :mode], + ) + x = torch.fft.irfft(out_ft, n=N) + x = x.permute(0, 2, 1).view(B, N, c, k) + return x + + +class MWT_CZ1d(nn.Module): + def __init__( + self, k=3, alpha=64, L=0, c=1, base="legendre", initializer=None, **kwargs + ): + super().__init__() + + self.k = k + self.L = L + H0, H1, G0, G1, PHI0, PHI1 = get_filter(base, k) + H0r = H0 @ PHI0 + G0r = G0 @ PHI0 + H1r = H1 @ PHI1 + G1r = G1 @ PHI1 + + H0r[np.abs(H0r) < 1e-8] = 0 + H1r[np.abs(H1r) < 1e-8] = 0 + G0r[np.abs(G0r) < 1e-8] = 0 + G1r[np.abs(G1r) < 1e-8] = 0 + self.max_item = 3 + + self.A = sparseKernelFT1d(k, alpha, c) + self.B = sparseKernelFT1d(k, alpha, c) + self.C = sparseKernelFT1d(k, alpha, c) + + self.T0 = nn.Linear(k, k) + + self.register_buffer("ec_s", torch.Tensor(np.concatenate((H0.T, H1.T), axis=0))) + self.register_buffer("ec_d", torch.Tensor(np.concatenate((G0.T, G1.T), axis=0))) + + self.register_buffer("rc_e", torch.Tensor(np.concatenate((H0r, G0r), axis=0))) + self.register_buffer("rc_o", torch.Tensor(np.concatenate((H1r, G1r), axis=0))) + + def forward(self, x): + B, N, c, k = x.shape # (B, N, k) + ns = math.floor(np.log2(N)) + nl = pow(2, math.ceil(np.log2(N))) + extra_x = x[:, 0 : nl - N, :, :] + x = torch.cat([x, extra_x], 1) + Ud = torch.jit.annotate(List[Tensor], []) + Us = torch.jit.annotate(List[Tensor], []) + for i in range(ns - self.L): + d, x = self.wavelet_transform(x) + Ud += [self.A(d) + self.B(x)] + Us += [self.C(d)] + x = self.T0(x) # coarsest scale transform + + # reconstruct + for i in range(ns - 1 - self.L, -1, -1): + x = x + Us[i] + x = torch.cat((x, Ud[i]), -1) + x = self.evenOdd(x) + x = x[:, :N, :, :] + + return x + + def wavelet_transform(self, x): + xa = torch.cat( + [ + x[:, ::2, :, :], + x[:, 1::2, :, :], + ], + -1, + ) + d = torch.matmul(xa, self.ec_d) + s = torch.matmul(xa, self.ec_s) + return d, s + + def evenOdd(self, x): + + B, N, c, ich = x.shape # (B, N, c, k) + assert ich == 2 * self.k + x_e = torch.matmul(x, self.rc_e) + x_o = torch.matmul(x, self.rc_o) + + x = torch.zeros(B, N * 2, c, self.k, device=x.device) + x[..., ::2, :, :] = x_e + x[..., 1::2, :, :] = x_o + return x + + +class MultiWaveletTransform(nn.Module): + """ + 1D multiwavelet block. + """ + + def __init__( + self, + ich=1, + k=8, + alpha=16, + c=128, + nCZ=1, + L=0, + base="legendre", + attention_dropout=0.1, + ): + super().__init__() + # print("base", base) + self.k = k + self.c = c + self.L = L + self.nCZ = nCZ + self.Lk0 = nn.Linear(ich, c * k) + self.Lk1 = nn.Linear(c * k, ich) + self.ich = ich + self.MWT_CZ = nn.ModuleList(MWT_CZ1d(k, alpha, L, c, base) for i in range(nCZ)) + + def forward(self, queries, keys, values, attn_mask): + B, L, H, E = queries.shape + _, S, _, D = values.shape + if L > S: + zeros = torch.zeros_like(queries[:, : (L - S), :]).float() + values = torch.cat([values, zeros], dim=1) + keys = torch.cat([keys, zeros], dim=1) + else: + values = values[:, :L, :, :] + keys = keys[:, :L, :, :] + values = values.view(B, L, -1) + + V = self.Lk0(values).view(B, L, self.c, -1) + for i in range(self.nCZ): + V = self.MWT_CZ[i](V) + if i < self.nCZ - 1: + V = F.relu(V) + + V = self.Lk1(V.view(B, L, -1)) + V = V.view(B, L, -1, D) + return (V.contiguous(), None) + + +class FourierCrossAttentionW(nn.Module): + def __init__( + self, + in_channels, + out_channels, + seq_len_q, + seq_len_kv, + modes=16, + activation="tanh", + mode_select_method="random", + ): + super().__init__() + # print("corss fourier correlation used!") + self.in_channels = in_channels + self.out_channels = out_channels + self.modes1 = modes + self.activation = activation + + def compl_mul1d(self, order, x, weights): + x_flag = True + w_flag = True + if not torch.is_complex(x): + x_flag = False + x = torch.complex(x, torch.zeros_like(x).to(x.device)) + if not torch.is_complex(weights): + w_flag = False + weights = torch.complex( + weights, torch.zeros_like(weights).to(weights.device) + ) + if x_flag or w_flag: + return torch.complex( + torch.einsum(order, x.real, weights.real) + - torch.einsum(order, x.imag, weights.imag), + torch.einsum(order, x.real, weights.imag) + + torch.einsum(order, x.imag, weights.real), + ) + else: + return torch.einsum(order, x.real, weights.real) + + def forward(self, q, k, v, mask): + B, L, E, H = q.shape + + xq = q.permute(0, 3, 2, 1) # size = [B, H, E, L] torch.Size([3, 8, 64, 512]) + xk = k.permute(0, 3, 2, 1) + xv = v.permute(0, 3, 2, 1) + self.index_q = list(range(0, min(int(L // 2), self.modes1))) + self.index_k_v = list(range(0, min(int(xv.shape[3] // 2), self.modes1))) + + # Compute Fourier coefficients + xq_ft_ = torch.zeros( + B, H, E, len(self.index_q), device=xq.device, dtype=torch.cfloat + ) + xq_ft = torch.fft.rfft(xq, dim=-1) + for i, j in enumerate(self.index_q): + xq_ft_[:, :, :, i] = xq_ft[:, :, :, j] + + xk_ft_ = torch.zeros( + B, H, E, len(self.index_k_v), device=xq.device, dtype=torch.cfloat + ) + xk_ft = torch.fft.rfft(xk, dim=-1) + for i, j in enumerate(self.index_k_v): + xk_ft_[:, :, :, i] = xk_ft[:, :, :, j] + xqk_ft = self.compl_mul1d("bhex,bhey->bhxy", xq_ft_, xk_ft_) + if self.activation == "tanh": + xqk_ft = torch.complex(xqk_ft.real.tanh(), xqk_ft.imag.tanh()) + elif self.activation == "softmax": + xqk_ft = torch.softmax(abs(xqk_ft), dim=-1) + xqk_ft = torch.complex(xqk_ft, torch.zeros_like(xqk_ft)) + else: + raise Exception( + "{} actiation function is not implemented".format(self.activation) + ) + xqkv_ft = self.compl_mul1d("bhxy,bhey->bhex", xqk_ft, xk_ft_) + + xqkvw = xqkv_ft + out_ft = torch.zeros(B, H, E, L // 2 + 1, device=xq.device, dtype=torch.cfloat) + for i, j in enumerate(self.index_q): + out_ft[:, :, :, j] = xqkvw[:, :, :, i] + + out = torch.fft.irfft( + out_ft / self.in_channels / self.out_channels, n=xq.size(-1) + ).permute(0, 3, 2, 1) + # size = [B, L, H, E] + return (out, None) + + +class MultiWaveletCross(nn.Module): + """ + 1D Multiwavelet Cross Attention layer. + """ + + def __init__( + self, + in_channels, + out_channels, + seq_len_q, + seq_len_kv, + modes, + c=64, + k=8, + ich=512, + L=0, + base="legendre", + mode_select_method="random", + initializer=None, + activation="tanh", + **kwargs, + ): + super().__init__() + # print("base", base) + + self.c = c + self.k = k + self.L = L + H0, H1, G0, G1, PHI0, PHI1 = get_filter(base, k) + H0r = H0 @ PHI0 + G0r = G0 @ PHI0 + H1r = H1 @ PHI1 + G1r = G1 @ PHI1 + + H0r[np.abs(H0r) < 1e-8] = 0 + H1r[np.abs(H1r) < 1e-8] = 0 + G0r[np.abs(G0r) < 1e-8] = 0 + G1r[np.abs(G1r) < 1e-8] = 0 + self.max_item = 3 + + self.attn1 = FourierCrossAttentionW( + in_channels=in_channels, + out_channels=out_channels, + seq_len_q=seq_len_q, + seq_len_kv=seq_len_kv, + modes=modes, + activation=activation, + mode_select_method=mode_select_method, + ) + self.attn2 = FourierCrossAttentionW( + in_channels=in_channels, + out_channels=out_channels, + seq_len_q=seq_len_q, + seq_len_kv=seq_len_kv, + modes=modes, + activation=activation, + mode_select_method=mode_select_method, + ) + self.attn3 = FourierCrossAttentionW( + in_channels=in_channels, + out_channels=out_channels, + seq_len_q=seq_len_q, + seq_len_kv=seq_len_kv, + modes=modes, + activation=activation, + mode_select_method=mode_select_method, + ) + self.attn4 = FourierCrossAttentionW( + in_channels=in_channels, + out_channels=out_channels, + seq_len_q=seq_len_q, + seq_len_kv=seq_len_kv, + modes=modes, + activation=activation, + mode_select_method=mode_select_method, + ) + self.T0 = nn.Linear(k, k) + self.register_buffer("ec_s", torch.Tensor(np.concatenate((H0.T, H1.T), axis=0))) + self.register_buffer("ec_d", torch.Tensor(np.concatenate((G0.T, G1.T), axis=0))) + + self.register_buffer("rc_e", torch.Tensor(np.concatenate((H0r, G0r), axis=0))) + self.register_buffer("rc_o", torch.Tensor(np.concatenate((H1r, G1r), axis=0))) + + self.Lk = nn.Linear(ich, c * k) + self.Lq = nn.Linear(ich, c * k) + self.Lv = nn.Linear(ich, c * k) + self.out = nn.Linear(c * k, ich) + self.modes1 = modes + + def forward(self, q, k, v, mask=None): + B, N, H, E = q.shape # (B, N, H, E) torch.Size([3, 768, 8, 2]) + _, S, _, _ = k.shape # (B, S, H, E) torch.Size([3, 96, 8, 2]) + + q = q.view(q.shape[0], q.shape[1], -1) + k = k.view(k.shape[0], k.shape[1], -1) + v = v.view(v.shape[0], v.shape[1], -1) + q = self.Lq(q) + q = q.view(q.shape[0], q.shape[1], self.c, self.k) + k = self.Lk(k) + k = k.view(k.shape[0], k.shape[1], self.c, self.k) + v = self.Lv(v) + v = v.view(v.shape[0], v.shape[1], self.c, self.k) + + if N > S: + zeros = torch.zeros_like(q[:, : (N - S), :]).float() + v = torch.cat([v, zeros], dim=1) + k = torch.cat([k, zeros], dim=1) + else: + v = v[:, :N, :, :] + k = k[:, :N, :, :] + + ns = math.floor(np.log2(N)) + nl = pow(2, math.ceil(np.log2(N))) + extra_q = q[:, 0 : nl - N, :, :] + extra_k = k[:, 0 : nl - N, :, :] + extra_v = v[:, 0 : nl - N, :, :] + q = torch.cat([q, extra_q], 1) + k = torch.cat([k, extra_k], 1) + v = torch.cat([v, extra_v], 1) + + Ud_q = torch.jit.annotate(List[Tuple[Tensor]], []) + Ud_k = torch.jit.annotate(List[Tuple[Tensor]], []) + Ud_v = torch.jit.annotate(List[Tuple[Tensor]], []) + + Us_q = torch.jit.annotate(List[Tensor], []) + Us_k = torch.jit.annotate(List[Tensor], []) + Us_v = torch.jit.annotate(List[Tensor], []) + + Ud = torch.jit.annotate(List[Tensor], []) + Us = torch.jit.annotate(List[Tensor], []) + + # decompose + for i in range(ns - self.L): + d, q = self.wavelet_transform(q) + Ud_q += [tuple([d, q])] + Us_q += [d] + for i in range(ns - self.L): + d, k = self.wavelet_transform(k) + Ud_k += [tuple([d, k])] + Us_k += [d] + for i in range(ns - self.L): + d, v = self.wavelet_transform(v) + Ud_v += [tuple([d, v])] + Us_v += [d] + for i in range(ns - self.L): + dk, sk = Ud_k[i], Us_k[i] + dq, sq = Ud_q[i], Us_q[i] + dv, sv = Ud_v[i], Us_v[i] + Ud += [ + self.attn1(dq[0], dk[0], dv[0], mask)[0] + + self.attn2(dq[1], dk[1], dv[1], mask)[0] + ] + Us += [self.attn3(sq, sk, sv, mask)[0]] + v = self.attn4(q, k, v, mask)[0] + + # reconstruct + for i in range(ns - 1 - self.L, -1, -1): + v = v + Us[i] + v = torch.cat((v, Ud[i]), -1) + v = self.evenOdd(v) + v = self.out(v[:, :N, :, :].contiguous().view(B, N, -1)) + return (v.contiguous(), None) + + def wavelet_transform(self, x): + xa = torch.cat( + [ + x[:, ::2, :, :], + x[:, 1::2, :, :], + ], + -1, + ) + d = torch.matmul(xa, self.ec_d) + s = torch.matmul(xa, self.ec_s) + return d, s + + def evenOdd(self, x): + B, N, c, ich = x.shape # (B, N, c, k) + assert ich == 2 * self.k + x_e = torch.matmul(x, self.rc_e) + x_o = torch.matmul(x, self.rc_o) + + x = torch.zeros(B, N * 2, c, self.k, device=x.device) + x[..., ::2, :, :] = x_e + x[..., 1::2, :, :] = x_o + return x + + +def get_frequency_modes(seq_len, modes=64, mode_select_method="random"): + """ + get modes on frequency domain: + 'random' means sampling randomly; + 'else' means sampling the lowest modes; + """ + modes = min(modes, seq_len // 2) + if mode_select_method == "random": + index = list(range(0, seq_len // 2)) + np.random.shuffle(index) + index = index[:modes] + else: + index = list(range(0, modes)) + index.sort() + return index + + +# ########## fourier layer ############# +class FourierBlock(nn.Module): + def __init__( + self, in_channels, out_channels, seq_len, modes=0, mode_select_method="random" + ): + super().__init__() + # print("fourier enhanced block used!") + """ + 1D Fourier block. It performs representation learning on frequency domain, + it does FFT, linear transform, and Inverse FFT. + """ + # get modes on frequency domain + self.index = get_frequency_modes( + seq_len, modes=modes, mode_select_method=mode_select_method + ) + # print("modes={}, index={}".format(modes, self.index)) + + self.scale = 1 / (in_channels * out_channels) + self.weights1 = nn.Parameter( + self.scale + * torch.rand( + 8, + in_channels // 8, + out_channels // 8, + len(self.index), + dtype=torch.cfloat, + ) + ) + + # Complex multiplication + def compl_mul1d(self, input, weights): + # (batch, in_channel, x ), (in_channel, out_channel, x) -> (batch, out_channel, x) + return torch.einsum("bhi,hio->bho", input, weights) + + def forward(self, q, k, v, mask): + # size = [B, L, H, E] + B, L, H, E = q.shape + x = q.permute(0, 2, 3, 1) + # Compute Fourier coefficients + x_ft = torch.fft.rfft(x, dim=-1) + # Perform Fourier neural operations + out_ft = torch.zeros(B, H, E, L // 2 + 1, device=x.device, dtype=torch.cfloat) + for wi, i in enumerate(self.index): + out_ft[:, :, :, wi] = self.compl_mul1d( + x_ft[:, :, :, i], self.weights1[:, :, :, wi] + ) + # Return to time domain + x = torch.fft.irfft(out_ft, n=x.size(-1)) + return (x, None) + + +# ########## Fourier Cross Former #################### +class FourierCrossAttention(nn.Module): + def __init__( + self, + in_channels, + out_channels, + seq_len_q, + seq_len_kv, + modes=64, + mode_select_method="random", + activation="tanh", + policy=0, + num_heads=8, + ): + super().__init__() + # print("fourier enhanced cross attention used!") + """ + 1D Fourier Cross Attention layer. It does FFT, linear transform, attention mechanism and Inverse FFT. + """ + self.activation = activation + self.in_channels = in_channels + self.out_channels = out_channels + # get modes for queries and keys (& values) on frequency domain + self.index_q = get_frequency_modes( + seq_len_q, modes=modes, mode_select_method=mode_select_method + ) + self.index_kv = get_frequency_modes( + seq_len_kv, modes=modes, mode_select_method=mode_select_method + ) + + # print("modes_q={}, index_q={}".format(len(self.index_q), self.index_q)) + # print("modes_kv={}, index_kv={}".format(len(self.index_kv), self.index_kv)) + + self.scale = 1 / (in_channels * out_channels) + self.weights1 = nn.Parameter( + self.scale + * torch.rand( + num_heads, + in_channels // num_heads, + out_channels // num_heads, + len(self.index_q), + dtype=torch.float, + ) + ) + self.weights2 = nn.Parameter( + self.scale + * torch.rand( + num_heads, + in_channels // num_heads, + out_channels // num_heads, + len(self.index_q), + dtype=torch.float, + ) + ) + + # Complex multiplication + def compl_mul1d(self, order, x, weights): + x_flag = True + w_flag = True + if not torch.is_complex(x): + x_flag = False + x = torch.complex(x, torch.zeros_like(x).to(x.device)) + if not torch.is_complex(weights): + w_flag = False + weights = torch.complex( + weights, torch.zeros_like(weights).to(weights.device) + ) + if x_flag or w_flag: + return torch.complex( + torch.einsum(order, x.real, weights.real) + - torch.einsum(order, x.imag, weights.imag), + torch.einsum(order, x.real, weights.imag) + + torch.einsum(order, x.imag, weights.real), + ) + else: + return torch.einsum(order, x.real, weights.real) + + def forward(self, q, k, v, mask): + # size = [B, L, H, E] + B, L, H, E = q.shape + xq = q.permute(0, 2, 3, 1) # size = [B, H, E, L] + xk = k.permute(0, 2, 3, 1) + # xv = v.permute(0, 2, 3, 1) + + # Compute Fourier coefficients + xq_ft_ = torch.zeros( + B, H, E, len(self.index_q), device=xq.device, dtype=torch.cfloat + ) + xq_ft = torch.fft.rfft(xq, dim=-1) + for i, j in enumerate(self.index_q): + if j >= xq_ft.shape[3]: + continue + xq_ft_[:, :, :, i] = xq_ft[:, :, :, j] + xk_ft_ = torch.zeros( + B, H, E, len(self.index_kv), device=xq.device, dtype=torch.cfloat + ) + xk_ft = torch.fft.rfft(xk, dim=-1) + for i, j in enumerate(self.index_kv): + if j >= xk_ft.shape[3]: + continue + xk_ft_[:, :, :, i] = xk_ft[:, :, :, j] + + # perform attention mechanism on frequency domain + xqk_ft = self.compl_mul1d("bhex,bhey->bhxy", xq_ft_, xk_ft_) + if self.activation == "tanh": + xqk_ft = torch.complex(xqk_ft.real.tanh(), xqk_ft.imag.tanh()) + elif self.activation == "softmax": + xqk_ft = torch.softmax(abs(xqk_ft), dim=-1) + xqk_ft = torch.complex(xqk_ft, torch.zeros_like(xqk_ft)) + else: + raise Exception( + "{} actiation function is not implemented".format(self.activation) + ) + xqkv_ft = self.compl_mul1d("bhxy,bhey->bhex", xqk_ft, xk_ft_) + xqkvw = self.compl_mul1d( + "bhex,heox->bhox", xqkv_ft, torch.complex(self.weights1, self.weights2) + ) + out_ft = torch.zeros(B, H, E, L // 2 + 1, device=xq.device, dtype=torch.cfloat) + for i, j in enumerate(self.index_q): + if i >= xqkvw.shape[3] or j >= out_ft.shape[3]: + continue + out_ft[:, :, :, j] = xqkvw[:, :, :, i] + # Return to time domain + out = torch.fft.irfft( + out_ft / self.in_channels / self.out_channels, n=xq.size(-1) + ) + return (out, None) diff --git a/tests/imputation/fedformer.py b/tests/imputation/fedformer.py new file mode 100644 index 00000000..e1cc54f4 --- /dev/null +++ b/tests/imputation/fedformer.py @@ -0,0 +1,132 @@ +""" +Test cases for FEDformer imputation model. +""" + +# Created by Wenjie Du +# License: BSD-3-Clause + + +import os.path +import unittest + +import numpy as np +import pytest + +from pypots.imputation import FEDformer +from pypots.optim import Adam +from pypots.utils.logging import logger +from pypots.utils.metrics import calc_mse +from tests.global_test_config import ( + DATA, + EPOCHS, + DEVICE, + TRAIN_SET, + VAL_SET, + TEST_SET, + H5_TRAIN_SET_PATH, + H5_VAL_SET_PATH, + H5_TEST_SET_PATH, + RESULT_SAVING_DIR_FOR_IMPUTATION, + check_tb_and_model_checkpoints_existence, +) + + +class TestFEDformer(unittest.TestCase): + logger.info("Running tests for an imputation model FEDformer...") + + # set the log and model saving path + saving_path = os.path.join(RESULT_SAVING_DIR_FOR_IMPUTATION, "FEDformer") + model_save_name = "saved_fedformer_model.pypots" + + # initialize an Adam optimizer + optimizer = Adam(lr=0.001, weight_decay=1e-5) + + # initialize a FEDformer model + fedformer = FEDformer( + DATA["n_steps"], + DATA["n_features"], + n_layers=2, + n_heads=2, + d_model=128, + d_ffn=256, + moving_avg_window_size=3, + dropout=0, + version="Fourier", + modes=32, + mode_select="random", + epochs=EPOCHS, + saving_path=saving_path, + optimizer=optimizer, + device=DEVICE, + ) + + @pytest.mark.xdist_group(name="imputation-fedformer") + def test_0_fit(self): + self.fedformer.fit(TRAIN_SET, VAL_SET) + + @pytest.mark.xdist_group(name="imputation-fedformer") + def test_1_impute(self): + imputation_results = self.fedformer.predict(TEST_SET) + assert not np.isnan( + imputation_results["imputation"] + ).any(), "Output still has missing values after running impute()." + + test_MSE = calc_mse( + imputation_results["imputation"], + DATA["test_X_ori"], + DATA["test_X_indicating_mask"], + ) + logger.info(f"FEDformer test_MSE: {test_MSE}") + + @pytest.mark.xdist_group(name="imputation-fedformer") + def test_2_parameters(self): + assert hasattr(self.fedformer, "model") and self.fedformer.model is not None + + assert ( + hasattr(self.fedformer, "optimizer") + and self.fedformer.optimizer is not None + ) + + assert hasattr(self.fedformer, "best_loss") + self.assertNotEqual(self.fedformer.best_loss, float("inf")) + + assert ( + hasattr(self.fedformer, "best_model_dict") + and self.fedformer.best_model_dict is not None + ) + + @pytest.mark.xdist_group(name="imputation-fedformer") + def test_3_saving_path(self): + # whether the root saving dir exists, which should be created by save_log_into_tb_file + assert os.path.exists( + self.saving_path + ), f"file {self.saving_path} does not exist" + + # check if the tensorboard file and model checkpoints exist + check_tb_and_model_checkpoints_existence(self.fedformer) + + # save the trained model into file, and check if the path exists + saved_model_path = os.path.join(self.saving_path, self.model_save_name) + self.fedformer.save(saved_model_path) + + # test loading the saved model, not necessary, but need to test + self.fedformer.load(saved_model_path) + + @pytest.mark.xdist_group(name="imputation-fedformer") + def test_4_lazy_loading(self): + self.fedformer.fit(H5_TRAIN_SET_PATH, H5_VAL_SET_PATH) + imputation_results = self.fedformer.predict(H5_TEST_SET_PATH) + assert not np.isnan( + imputation_results["imputation"] + ).any(), "Output still has missing values after running impute()." + + test_MSE = calc_mse( + imputation_results["imputation"], + DATA["test_X_ori"], + DATA["test_X_indicating_mask"], + ) + logger.info(f"Lazy-loading FEDformer test_MSE: {test_MSE}") + + +if __name__ == "__main__": + unittest.main() From 777b64fcd93fcef7f0651bd8f36cc423d7e7c73a Mon Sep 17 00:00:00 2001 From: Wenjie Du Date: Sat, 30 Mar 2024 23:33:56 +0800 Subject: [PATCH 09/12] fix: Fourier version FEDformer does not work for now, use Wavelets; --- tests/imputation/fedformer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/imputation/fedformer.py b/tests/imputation/fedformer.py index e1cc54f4..5d32fc6a 100644 --- a/tests/imputation/fedformer.py +++ b/tests/imputation/fedformer.py @@ -51,7 +51,7 @@ class TestFEDformer(unittest.TestCase): d_ffn=256, moving_avg_window_size=3, dropout=0, - version="Fourier", + version="Wavelets", # TODO: Fourier version does not work modes=32, mode_select="random", epochs=EPOCHS, From e61a3e80aeea28c3434626d08998e0441ee9837c Mon Sep 17 00:00:00 2001 From: Wenjie Du Date: Sun, 31 Mar 2024 01:40:43 +0800 Subject: [PATCH 10/12] depen: add sympy as a dependency; --- environment-dev.yml | 3 ++- requirements.txt | 1 + setup.cfg | 11 ++++++----- setup.py | 1 + tests/environment_for_conda_test.yml | 9 +++++---- 5 files changed, 15 insertions(+), 10 deletions(-) diff --git a/environment-dev.yml b/environment-dev.yml index d3a47627..2e22b173 100644 --- a/environment-dev.yml +++ b/environment-dev.yml @@ -12,6 +12,7 @@ dependencies: - conda-forge::h5py - conda-forge::numpy - conda-forge::scipy + - conda-forge::sympy - conda-forge::python - conda-forge::einops - conda-forge::pandas @@ -26,8 +27,8 @@ dependencies: # optional - pyg::pyg - - pyg::pytorch-scatter - pyg::pytorch-sparse + - pyg::pytorch-scatter # test - conda-forge::pytest-cov diff --git a/requirements.txt b/requirements.txt index cfbab72f..1435b245 100644 --- a/requirements.txt +++ b/requirements.txt @@ -4,6 +4,7 @@ h5py numpy scipy +sympy einops pandas matplotlib diff --git a/setup.cfg b/setup.cfg index a36169c0..1d322e96 100644 --- a/setup.cfg +++ b/setup.cfg @@ -25,15 +25,16 @@ exclude = pypots/*/template # basic dependencies basic = + h5py numpy - scikit-learn - matplotlib + scipy + sympy einops pandas - torch>=1.10.0 + matplotlib tensorboard - scipy - h5py + scikit-learn + torch>=1.10.0 tsdb>=0.2 pygrinder>=0.4 diff --git a/setup.py b/setup.py index 92011eeb..986910c6 100644 --- a/setup.py +++ b/setup.py @@ -49,6 +49,7 @@ "h5py", "numpy", "scipy", + "sympy", "einops", "pandas", "matplotlib", diff --git a/tests/environment_for_conda_test.yml b/tests/environment_for_conda_test.yml index 54f07925..f603e806 100644 --- a/tests/environment_for_conda_test.yml +++ b/tests/environment_for_conda_test.yml @@ -10,23 +10,24 @@ dependencies: # basic - conda-forge::python - conda-forge::pip + - conda-forge::h5py - conda-forge::scipy + - conda-forge::sympy - conda-forge::numpy - - conda-forge::scikit-learn - conda-forge::einops - conda-forge::pandas - - conda-forge::h5py + - conda-forge::matplotlib - conda-forge::tensorboard + - conda-forge::scikit-learn - conda-forge::pygrinder >=0.4 - conda-forge::tsdb >=0.2 - conda-forge::protobuf <=4.21.12 - - conda-forge::matplotlib - pytorch::pytorch >=1.10.0 # optional - pyg::pyg - - pyg::pytorch-scatter - pyg::pytorch-sparse + - pyg::pytorch-scatter # test - conda-forge::pytest-cov From 0a044d29452789583a0164a42d2638c718a9bacf Mon Sep 17 00:00:00 2001 From: Wenjie Du Date: Sun, 31 Mar 2024 01:59:38 +0800 Subject: [PATCH 11/12] docs: update the docs; --- README.md | 18 ++++++++++--- docs/index.rst | 45 ++++++++++++++++++++------------- docs/pypots.imputation.rst | 45 +++++++++++++++++++++++++++++++++ docs/references.bib | 52 ++++++++++++++++++++++++++++++++++++++ 4 files changed, 139 insertions(+), 21 deletions(-) diff --git a/README.md b/README.md index af6a91e5..7df93dfe 100644 --- a/README.md +++ b/README.md @@ -174,7 +174,7 @@ dataset = {"X": X} # X for model input print(X.shape) # (11988, 48, 37), 11988 samples and each sample has 48 time steps, 37 features # Model training. This is PyPOTS showtime. -saits = SAITS(n_steps=48, n_features=37, n_layers=2, d_model=256, d_inner=128, n_heads=4, d_k=64, d_v=64, dropout=0.1, epochs=10) +saits = SAITS(n_steps=48, n_features=37, n_layers=2, d_model=256, d_ffn=128, n_heads=4, d_k=64, d_v=64, dropout=0.1, epochs=10) # Here I use the whole dataset as the training set because ground truth is not visible to the model, you can also split it into train/val/test sets saits.fit(dataset) imputation = saits.impute(dataset) # impute the originally-missing values and artificially-missing values @@ -197,14 +197,21 @@ This functionality is implemented with the [Microsoft NNI](https://github.com/mi | **Type** | **Abbr.** | **Full name of the algorithm/model** | **Year** | | Neural Net | SAITS | Self-Attention-based Imputation for Time Series [^1] | 2023 | | Neural Net | Transformer | Attention is All you Need [^2];
Self-Attention-based Imputation for Time Series [^1];
Note: proposed in [^2], and re-implemented as an imputation model in [^1]. | 2017 | +| Neural Net | Crossformer | Transformer Utilizing Cross-Dimension Dependency for Multivariate Time Series Forecasting [^16] | 2023 | | Neural Net | TimesNet | Temporal 2D-Variation Modeling for General Time Series Analysis [^14] | 2023 | -| Neural Net | Autoformer | Decomposition transformers with auto-correlation for long-term series forecasting [^15] | 2021 | +| Neural Net | PatchTST | A Time Series is Worth 64 Words: Long-Term Forecasting with Transformers [^18] | 2023 | +| Neural Net | DLinear | Are Transformers Effective for Time Series Forecasting? [^17] | 2023 | +| Neural Net | ETSformer | Exponential Smoothing Transformers for Time-series Forecasting [^19] | 2023 | +| Neural Net | FEDformer | Frequency Enhanced Decomposed Transformer for Long-term Series Forecasting [^20] | 2022 | +| Neural Net | Autoformer | Decomposition Transformers with Auto-Correlation for Long-Term Series Forecasting [^15] | 2021 | | Neural Net | CSDI | Conditional Score-based Diffusion Models for Probabilistic Time Series Imputation [^12] | 2021 | | Neural Net | US-GAN | Unsupervised GAN for Multivariate Time Series Imputation [^10] | 2021 | | Neural Net | GP-VAE | Gaussian Process Variational Autoencoder [^11] | 2020 | | Neural Net | BRITS | Bidirectional Recurrent Imputation for Time Series [^3] | 2018 | | Neural Net | M-RNN | Multi-directional Recurrent Neural Network [^9] | 2019 | -| Naive | LOCF | Last Observation Carried Forward | - | +| Naive | LOCF/NOCB | Last Observation Carried Forward / Next Observation Carried Backward | - | +| Naive | Median | Median Value Imputation | - | +| Naive | Mean | Mean Value Imputation | - | | ***`Classification`*** | 🚥 | 🚥 | 🚥 | | **Type** | **Abbr.** | **Full name of the algorithm/model/paper** | **Year** | | Neural Net | BRITS | Bidirectional Recurrent Imputation for Time Series [^3] | 2018 | @@ -320,6 +327,11 @@ PyPOTS community is open, transparent, and surely friendly. Let's work together [^13]: Rubin, D. B. (1976). [Inference and missing data](https://academic.oup.com/biomet/article-abstract/63/3/581/270932). *Biometrika*. [^14]: Wu, H., Hu, T., Liu, Y., Zhou, H., Wang, J., & Long, M. (2023). [TimesNet: Temporal 2d-variation modeling for general time series analysis](https://openreview.net/forum?id=ju_Uqw384Oq). *ICLR 2023* [^15]: Wu, H., Xu, J., Wang, J., & Long, M. (2021). [Autoformer: Decomposition transformers with auto-correlation for long-term series forecasting](https://proceedings.neurips.cc/paper/2021/hash/bcc0d400288793e8bdcd7c19a8ac0c2b-Abstract.html). *NeurIPS 2021*. +[^16]: Zhang, Y., & Yan, J. (2023). [Crossformer: Transformer utilizing cross-dimension dependency for multivariate time series forecasting](https://openreview.net/forum?id=vSVLM2j9eie). *ICLR 2023*. +[^17]: Zeng, A., Chen, M., Zhang, L., & Xu, Q. (2023). [Are transformers effective for time series forecasting?](https://ojs.aaai.org/index.php/AAAI/article/view/26317). *AAAI 2023* +[^18]: Nie, Y., Nguyen, N. H., Sinthong, P., & Kalagnanam, J. (2023). [A time series is worth 64 words: Long-term forecasting with transformers](https://openreview.net/forum?id=Jbdc0vTOcol). *ICLR 2023* +[^19]: Woo, G., Liu, C., Sahoo, D., Kumar, A., & Hoi, S. (2023). [ETSformer: Exponential Smoothing Transformers for Time-series Forecasting](https://openreview.net/forum?id=5m_3whfo483). *ICLR 2023* +[^20]: Zhou, T., Ma, Z., Wen, Q., Wang, X., Sun, L., & Jin, R. (2022). [FEDformer: Frequency enhanced decomposed transformer for long-term series forecasting](https://proceedings.mlr.press/v162/zhou22g.html). *ICML 2022*.
diff --git a/docs/index.rst b/docs/index.rst index 32e7eb54..63d12f8f 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -170,24 +170,33 @@ PyPOTS supports imputation, classification, clustering, and forecasting tasks on 🌟 Since **v0.2**, all neural-network models in PyPOTS has got hyperparameter-optimization support. This functionality is implemented with the `Microsoft NNI `_ framework. -============================== ================ ======================================================================================== ====== ========= -Task Type Algorithm Year Reference -============================== ================ ======================================================================================== ====== ========= -Imputation Neural Net SAITS (Self-Attention-based Imputation for Time Series) 2023 :cite:`du2023SAITS` -Imputation Neural Net Transformer 2017 :cite:`vaswani2017Transformer`, :cite:`du2023SAITS` -Imputation Neural Net TimesNet 2023 :cite:`wu2023timesnet` -Imputation Neural Net US-GAN (Unsupervised GAN for Multivariate Time Series Imputation) 2021 :cite:`miao2021SSGAN` -Imputation Neural Net CSDI (Conditional Score-based Diffusion Models for Probabilistic Time Series Imputation) 2021 :cite:`tashiro2021csdi` -Imputation Neural Net GP-VAE (Gaussian Process Variational Autoencoder) 2020 :cite:`fortuin2020GPVAEDeep` -Imputation, Classification Neural Net BRITS (Bidirectional Recurrent Imputation for Time Series) 2018 :cite:`cao2018BRITS` -Imputation Neural Net M-RNN (Multi-directional Recurrent Neural Network) 2019 :cite:`yoon2019MRNN` -Imputation Naive LOCF (Last Observation Carried Forward) / / -Classification Neural Net GRU-D 2018 :cite:`che2018GRUD` -Classification Neural Net Raindrop 2022 :cite:`zhang2022Raindrop` -Clustering Neural Net CRLI (Clustering Representation Learning on Incomplete time-series data) 2021 :cite:`ma2021CRLI` -Clustering Neural Net VaDER (Variational Deep Embedding with Recurrence) 2019 :cite:`dejong2019VaDER` -Forecasting Probabilistic BTTF (Bayesian Temporal Tensor Factorization) 2021 :cite:`chen2021BTMF` -============================== ================ ======================================================================================== ====== ========= +============================== ================ ========================================================================================================= ====== ========= +Task Type Algorithm Year Reference +============================== ================ ========================================================================================================= ====== ========= +Imputation Neural Net SAITS (Self-Attention-based Imputation for Time Series) 2023 :cite:`du2023SAITS` +Imputation Neural Net Transformer 2017 :cite:`vaswani2017Transformer`, :cite:`du2023SAITS` +Imputation Neural Net Crossformer (Transformer Utilizing Cross-Dimension Dependency for Multivariate Time Series Forecasting) 2023 :cite:`nie2023patchtst` +Imputation Neural Net TimesNet (Temporal 2D-Variation Modeling for General Time Series Analysis) 2023 :cite:`wu2023timesnet` +Imputation Neural Net PatchTST (A Time Series is Worth 64 Words: Long-Term Forecasting with Transformers) 2023 :cite:`nie2023patchtst` +Imputation Neural Net DLinear (Are transformers effective for time series forecasting?) 2023 :cite:`zeng2023dlinear` +Imputation Neural Net ETSformer (Exponential Smoothing Transformers for Time-series Forecasting) 2023 :cite:`woo2023etsformer` +Imputation Neural Net FEDformer (Frequency Enhanced Decomposed Transformer for Long-term Series Forecasting) 2022 :cite:`zhou2022fedformer` +Imputation Neural Net Autoformer (Decomposition Transformers with Auto-Correlation for Long-Term Series Forecasting) 2021 :cite:`wu2021autoformer` +Imputation Neural Net US-GAN (Unsupervised GAN for Multivariate Time Series Imputation) 2021 :cite:`miao2021SSGAN` +Imputation Neural Net CSDI (Conditional Score-based Diffusion Models for Probabilistic Time Series Imputation) 2021 :cite:`tashiro2021csdi` +Imputation Neural Net GP-VAE (Gaussian Process Variational Autoencoder) 2020 :cite:`fortuin2020GPVAEDeep` +Imputation, Classification Neural Net BRITS (Bidirectional Recurrent Imputation for Time Series) 2018 :cite:`cao2018BRITS` +Imputation Neural Net M-RNN (Multi-directional Recurrent Neural Network) 2019 :cite:`yoon2019MRNN` +Imputation Naive LOCF (Last Observation Carried Forward) / / +Imputation Naive NOCB (Next Observation Carried Backward) / / +Imputation Naive Median (Median Value Imputation) / / +Imputation Naive Mean (Mean Value Imputation) / / +Classification Neural Net GRU-D 2018 :cite:`che2018GRUD` +Classification Neural Net Raindrop 2022 :cite:`zhang2022Raindrop` +Clustering Neural Net CRLI (Clustering Representation Learning on Incomplete time-series data) 2021 :cite:`ma2021CRLI` +Clustering Neural Net VaDER (Variational Deep Embedding with Recurrence) 2019 :cite:`dejong2019VaDER` +Forecasting Probabilistic BTTF (Bayesian Temporal Tensor Factorization) 2021 :cite:`chen2021BTMF` +============================== ================ ========================================================================================================= ====== ========= ❖ Citing PyPOTS diff --git a/docs/pypots.imputation.rst b/docs/pypots.imputation.rst index 8aa4742e..176322ef 100644 --- a/docs/pypots.imputation.rst +++ b/docs/pypots.imputation.rst @@ -19,6 +19,15 @@ pypots.imputation.transformer :show-inheritance: :inherited-members: +pypots.imputation.crossformer +------------------------------ + +.. automodule:: pypots.imputation.crossformer + :members: + :undoc-members: + :show-inheritance: + :inherited-members: + pypots.imputation.timesnet ------------------------------ @@ -28,6 +37,42 @@ pypots.imputation.timesnet :show-inheritance: :inherited-members: +pypots.imputation.patchtst +------------------------------ + +.. automodule:: pypots.imputation.patchtst + :members: + :undoc-members: + :show-inheritance: + :inherited-members: + +pypots.imputation.dlinear +------------------------------ + +.. automodule:: pypots.imputation.dlinear + :members: + :undoc-members: + :show-inheritance: + :inherited-members: + +pypots.imputation.etsformer +------------------------------ + +.. automodule:: pypots.imputation.etsformer + :members: + :undoc-members: + :show-inheritance: + :inherited-members: + +pypots.imputation.fedformer +------------------------------ + +.. automodule:: pypots.imputation.fedformer + :members: + :undoc-members: + :show-inheritance: + :inherited-members: + pypots.imputation.autoformer ------------------------------ diff --git a/docs/references.bib b/docs/references.bib index 52f1c59e..fd0cd53f 100644 --- a/docs/references.bib +++ b/docs/references.bib @@ -484,3 +484,55 @@ @inproceedings{wu2021autoformer volume = {34}, year = {2021} } + +@inproceedings{zhang2023crossformer, +title={Crossformer: Transformer Utilizing Cross-Dimension Dependency for Multivariate Time Series Forecasting}, +author={Yunhao Zhang and Junchi Yan}, +booktitle={The Eleventh International Conference on Learning Representations}, +year={2023}, +url={https://openreview.net/forum?id=vSVLM2j9eie} +} + +@inproceedings{zeng2023dlinear, +title={Are Transformers Effective for Time Series Forecasting?}, +volume={37}, +url={https://ojs.aaai.org/index.php/AAAI/article/view/26317}, +DOI={10.1609/aaai.v37i9.26317}, +number={9}, +journal={Proceedings of the AAAI Conference on Artificial Intelligence}, +author={Zeng, Ailing and Chen, Muxi and Zhang, Lei and Xu, Qiang}, +year={2023}, +month={Jun.}, +pages={11121-11128} +} + +@inproceedings{nie2023patchtst, +title={A Time Series is Worth 64 Words: Long-term Forecasting with Transformers}, +author={Yuqi Nie and Nam H Nguyen and Phanwadee Sinthong and Jayant Kalagnanam}, +booktitle={The Eleventh International Conference on Learning Representations }, +year={2023}, +url={https://openreview.net/forum?id=Jbdc0vTOcol} +} + +@inproceedings{woo2023etsformer, +title={{ETS}former: Exponential Smoothing Transformers for Time-series Forecasting}, +author={Gerald Woo and Chenghao Liu and Doyen Sahoo and Akshat Kumar and Steven Hoi}, +booktitle={The Eleventh International Conference on Learning Representations}, +year={2023}, +url={https://openreview.net/forum?id=5m_3whfo483} +} + +@inproceedings{zhou2022fedformer, +title = {{FED}former: Frequency Enhanced Decomposed Transformer for Long-term Series Forecasting}, +author = {Zhou, Tian and Ma, Ziqing and Wen, Qingsong and Wang, Xue and Sun, Liang and Jin, Rong}, +booktitle = {Proceedings of the 39th International Conference on Machine Learning}, +pages = {27268--27286}, +year = {2022}, +editor = {Chaudhuri, Kamalika and Jegelka, Stefanie and Song, Le and Szepesvari, Csaba and Niu, Gang and Sabato, Sivan}, +volume = {162}, +series = {Proceedings of Machine Learning Research}, +month = {17--23 Jul}, +publisher = {PMLR}, +pdf = {https://proceedings.mlr.press/v162/zhou22g/zhou22g.pdf}, +url = {https://proceedings.mlr.press/v162/zhou22g.html}, +} From 0fbb5ee38ac47db454deed54e301e80950339e2c Mon Sep 17 00:00:00 2001 From: Wenjie Du Date: Mon, 1 Apr 2024 15:33:41 +0800 Subject: [PATCH 12/12] test: make the FEDformer model for testing smaller to speed up; --- tests/imputation/fedformer.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/imputation/fedformer.py b/tests/imputation/fedformer.py index 5d32fc6a..7c8b4b04 100644 --- a/tests/imputation/fedformer.py +++ b/tests/imputation/fedformer.py @@ -45,10 +45,10 @@ class TestFEDformer(unittest.TestCase): fedformer = FEDformer( DATA["n_steps"], DATA["n_features"], - n_layers=2, + n_layers=1, n_heads=2, - d_model=128, - d_ffn=256, + d_model=64, + d_ffn=64, moving_avg_window_size=3, dropout=0, version="Wavelets", # TODO: Fourier version does not work