diff --git a/pypots/imputation/__init__.py b/pypots/imputation/__init__.py index 4272f36e..8136b0f8 100644 --- a/pypots/imputation/__init__.py +++ b/pypots/imputation/__init__.py @@ -37,6 +37,7 @@ from .stemgnn import StemGNN from .imputeformer import ImputeFormer from .timemixer import TimeMixer +from .moderntcn import ModernTCN # naive imputation methods from .locf import LOCF @@ -77,6 +78,7 @@ "StemGNN", "ImputeFormer", "TimeMixer", + "ModernTCN", # naive imputation methods "LOCF", "Mean", diff --git a/pypots/imputation/moderntcn/__init__.py b/pypots/imputation/moderntcn/__init__.py new file mode 100644 index 00000000..f82da16b --- /dev/null +++ b/pypots/imputation/moderntcn/__init__.py @@ -0,0 +1,24 @@ +""" +The package of the partially-observed time-series imputation model ModernTCN. + +Refer to the paper +`Donghao Luo, and Xue Wang. +ModernTCN: A Modern Pure Convolution Structure for General Time Series Analysis. +In The Twelfth International Conference on Learning Representations. 2024. +`_ + +Notes +----- +This implementation is inspired by the official one https://github.com/luodhhh/ModernTCN + +""" + +# Created by Wenjie Du +# License: BSD-3-Clause + + +from .model import ModernTCN + +__all__ = [ + "ModernTCN", +] diff --git a/pypots/imputation/moderntcn/core.py b/pypots/imputation/moderntcn/core.py new file mode 100644 index 00000000..3ca8e8f5 --- /dev/null +++ b/pypots/imputation/moderntcn/core.py @@ -0,0 +1,95 @@ +""" +The core wrapper assembles the submodules of ModernTCN imputation model +and takes over the forward progress of the algorithm. +""" + +# Created by Wenjie Du +# License: BSD-3-Clause + +import torch.nn as nn + +from ...nn.functional import nonstationary_norm, nonstationary_denorm +from ...nn.modules.moderntcn import BackboneModernTCN +from ...nn.modules.patchtst.layers import FlattenHead +from ...utils.metrics import calc_mse + + +class _ModernTCN(nn.Module): + def __init__( + self, + n_steps, + n_features, + patch_size, + patch_stride, + downsampling_ratio, + ffn_ratio, + num_blocks: list, + large_size: list, + small_size: list, + dims: list, + small_kernel_merged: bool = False, + backbone_dropout: float = 0.1, + head_dropout: float = 0.1, + use_multi_scale: bool = True, + individual: bool = False, + apply_nonstationary_norm: bool = False, + ): + super().__init__() + + self.apply_nonstationary_norm = apply_nonstationary_norm + + self.backbone = BackboneModernTCN( + n_steps, + n_features, + n_features, + patch_size, + patch_stride, + downsampling_ratio, + ffn_ratio, + num_blocks, + large_size, + small_size, + dims, + small_kernel_merged, + backbone_dropout, + head_dropout, + use_multi_scale, + individual, + ) + + # for the imputation task, the output dim is the same as input dim + self.projection = FlattenHead( + self.backbone.head_nf, + n_steps, + n_features, + head_dropout, + individual, + ) + + def forward(self, inputs: dict, training: bool = True) -> dict: + X, missing_mask = inputs["X"], inputs["missing_mask"] + + if self.apply_nonstationary_norm: + # Normalization from Non-stationary Transformer + X, means, stdev = nonstationary_norm(X, missing_mask) + + in_X = X.permute(0, 2, 1) + in_X = self.backbone(in_X) + reconstruction = self.projection(in_X) + reconstruction = reconstruction.permute(0, 2, 1) + + if self.apply_nonstationary_norm: + # De-Normalization from Non-stationary Transformer + reconstruction = nonstationary_denorm(reconstruction, means, stdev) + + imputed_data = missing_mask * X + (1 - missing_mask) * reconstruction + results = { + 
"imputed_data": imputed_data, + } + + # if in training mode, return results with losses + if training: + loss = calc_mse(reconstruction, inputs["X_ori"], inputs["indicating_mask"]) + results["loss"] = loss + + return results diff --git a/pypots/imputation/moderntcn/data.py b/pypots/imputation/moderntcn/data.py new file mode 100644 index 00000000..c296728a --- /dev/null +++ b/pypots/imputation/moderntcn/data.py @@ -0,0 +1,24 @@ +""" +Dataset class for ModernTCN. +""" + +# Created by Wenjie Du +# License: BSD-3-Clause + +from typing import Union + +from ..saits.data import DatasetForSAITS + + +class DatasetForModernTCN(DatasetForSAITS): + """Actually ModernTCN uses the same data strategy as SAITS, needs MIT for training.""" + + def __init__( + self, + data: Union[dict, str], + return_X_ori: bool, + return_y: bool, + file_type: str = "hdf5", + rate: float = 0.2, + ): + super().__init__(data, return_X_ori, return_y, file_type, rate) diff --git a/pypots/imputation/moderntcn/model.py b/pypots/imputation/moderntcn/model.py new file mode 100644 index 00000000..2efb3fed --- /dev/null +++ b/pypots/imputation/moderntcn/model.py @@ -0,0 +1,342 @@ +""" +The implementation of ModernTCN for the partially-observed time-series imputation task. + +""" + +# Created by Wenjie Du +# License: BSD-3-Clause + +from typing import Union, Optional + +import numpy as np +import torch +from torch.utils.data import DataLoader + +from .core import _ModernTCN +from .data import DatasetForModernTCN +from ..base import BaseNNImputer +from ...data.checking import key_in_data_set +from ...data.dataset import BaseDataset +from ...optim.adam import Adam +from ...optim.base import Optimizer + + +class ModernTCN(BaseNNImputer): + """The PyTorch implementation of the ModernTCN model. + ModernTCN is originally proposed by Luo et al. in :cite:`luo2024moderntcn`. + + Parameters + ---------- + n_steps : + The number of time steps in the time-series data sample. + + n_features : + The number of features in the time-series data sample. + + patch_size : + The size of the patch for the patching mechanism. + + patch_stride : + The stride for the patching mechanism. + + downsampling_ratio : + The downsampling ratio for the downsampling mechanism. + + ffn_ratio : + The ratio for the feed-forward neural network in the model. + + num_blocks : + The number of blocks for the model. It should be a list of integers. + + large_size : + The size of the large kernel. It should be a list of odd integers. + + small_size : + The size of the small kernel. It should be a list of odd integers. + + dims : + The dimensions for the model. It should be a list of integers. + + small_kernel_merged : + Whether the small kernel is merged. + + backbone_dropout : + The dropout rate for the backbone of the model. + + head_dropout : + The dropout rate for the head of the model. + + use_multi_scale : + Whether to use multi-scale fusing. + + individual : + Whether to make a linear layer for each variate/channel/feature individually. + + apply_nonstationary_norm : + Whether to apply non-stationary normalization. + + batch_size : + The batch size for training and evaluating the model. + + epochs : + The number of epochs for training the model. + + patience : + The patience for the early-stopping mechanism. Given a positive integer, the training process will be + stopped when the model does not perform better after that number of epochs. + Leaving it default as None will disable the early-stopping. + + optimizer : + The optimizer for model training. 
+ If not given, will use a default Adam optimizer. + + num_workers : + The number of subprocesses to use for data loading. + `0` means data loading will be in the main process, i.e. there won't be subprocesses. + + device : + The device for the model to run on. It can be a string, a :class:`torch.device` object, or a list of them. + If not given, will try to use CUDA devices first (will use the default CUDA device if there are multiple), + then CPUs, considering CUDA and CPU are so far the main devices for people to train ML models. + If given a list of devices, e.g. ['cuda:0', 'cuda:1'], or [torch.device('cuda:0'), torch.device('cuda:1')], the + model will be trained in parallel on the multiple devices (so far only parallel training on CUDA devices is supported). + Other devices like Google TPU and Apple Silicon accelerator MPS may be added in the future. + + saving_path : + The path for automatically saving model checkpoints and tensorboard files (i.e. loss values recorded during + training into a tensorboard file). Will not save if not given. + + model_saving_strategy : + The strategy to save model checkpoints. It has to be one of [None, "best", "better", "all"]. + No model will be saved when it is set as None. + The "best" strategy will only automatically save the best model after the training is finished. + The "better" strategy will automatically save the model during training whenever the model performs + better than in previous epochs. + The "all" strategy will save every model after each training epoch. + + verbose : + Whether to print out the training logs during the training process. + """ + + def __init__( + self, + n_steps: int, + n_features: int, + patch_size: int, + patch_stride: int, + downsampling_ratio: float, + ffn_ratio: float, + num_blocks: list, + large_size: list, + small_size: list, + dims: list, + small_kernel_merged: bool = False, + backbone_dropout: float = 0.1, + head_dropout: float = 0.1, + use_multi_scale: bool = True, + individual: bool = False, + apply_nonstationary_norm: bool = False, + batch_size: int = 32, + epochs: int = 100, + patience: int = None, + optimizer: Optional[Optimizer] = Adam(), + num_workers: int = 0, + device: Optional[Union[str, torch.device, list]] = None, + saving_path: str = None, + model_saving_strategy: Optional[str] = "best", + verbose: bool = True, + ): + super().__init__( + batch_size, + epochs, + patience, + num_workers, + device, + saving_path, + model_saving_strategy, + verbose, + ) + assert ( + len(num_blocks) == len(dims) == len(large_size) == len(small_size) + ), "The length of num_blocks, dims, large_size, and small_size should be the same."
+ + self.n_steps = n_steps + self.n_features = n_features + + # set up the model + self.model = _ModernTCN( + n_steps, + n_features, + patch_size, + patch_stride, + downsampling_ratio, + ffn_ratio, + num_blocks, + large_size, + small_size, + dims, + small_kernel_merged, + backbone_dropout, + head_dropout, + use_multi_scale, + individual, + apply_nonstationary_norm, + ) + self._send_model_to_given_device() + self._print_model_size() + + # set up the optimizer + self.optimizer = optimizer + self.optimizer.init_optimizer(self.model.parameters()) + + def _assemble_input_for_training(self, data: list) -> dict: + ( + indices, + X, + missing_mask, + X_ori, + indicating_mask, + ) = self._send_data_to_given_device(data) + + inputs = { + "X": X, + "missing_mask": missing_mask, + "X_ori": X_ori, + "indicating_mask": indicating_mask, + } + + return inputs + + def _assemble_input_for_validating(self, data: list) -> dict: + return self._assemble_input_for_training(data) + + def _assemble_input_for_testing(self, data: list) -> dict: + indices, X, missing_mask = self._send_data_to_given_device(data) + + inputs = { + "X": X, + "missing_mask": missing_mask, + } + + return inputs + + def fit( + self, + train_set: Union[dict, str], + val_set: Optional[Union[dict, str]] = None, + file_type: str = "hdf5", + ) -> None: + # Step 1: wrap the input data with classes Dataset and DataLoader + training_set = DatasetForModernTCN( + train_set, return_X_ori=False, return_y=False, file_type=file_type + ) + training_loader = DataLoader( + training_set, + batch_size=self.batch_size, + shuffle=True, + num_workers=self.num_workers, + ) + val_loader = None + if val_set is not None: + if not key_in_data_set("X_ori", val_set): + raise ValueError("val_set must contain 'X_ori' for model validation.") + val_set = DatasetForModernTCN( + val_set, return_X_ori=True, return_y=False, file_type=file_type + ) + val_loader = DataLoader( + val_set, + batch_size=self.batch_size, + shuffle=False, + num_workers=self.num_workers, + ) + + # Step 2: train the model and freeze it + self._train_model(training_loader, val_loader) + self.model.load_state_dict(self.best_model_dict) + self.model.eval() # set the model as eval status to freeze it. + + # Step 3: save the model if necessary + self._auto_save_model_if_necessary(confirm_saving=True) + + def predict( + self, + test_set: Union[dict, str], + file_type: str = "hdf5", + ) -> dict: + """Make predictions for the input data with the trained model. + + Parameters + ---------- + test_set : dict or str + The dataset for model testing, should be a dictionary including the key 'X', + or a path string locating a data file supported by PyPOTS (e.g. h5 file). + If it is a dict, X should be array-like of shape [n_samples, sequence length (n_steps), n_features], + which is the time-series data for testing and can contain missing values. + If it is a path string, the path should point to a data file, e.g. a h5 file, which contains + key-value pairs like a dict, and it has to include the key 'X'. + + file_type : + The type of the given file if test_set is a path string. + + Returns + ------- + result_dict : + The dictionary containing the imputation results under the key 'imputation'. + + """ + # Step 1: wrap the input data with classes Dataset and DataLoader + self.model.eval() # set the model as eval status to freeze it.
+ test_set = BaseDataset( + test_set, + return_X_ori=False, + return_X_pred=False, + return_y=False, + file_type=file_type, + ) + test_loader = DataLoader( + test_set, + batch_size=self.batch_size, + shuffle=False, + num_workers=self.num_workers, + ) + imputation_collector = [] + + # Step 2: process the data with the model + with torch.no_grad(): + for idx, data in enumerate(test_loader): + inputs = self._assemble_input_for_testing(data) + results = self.model.forward(inputs, training=False) + imputation_collector.append(results["imputed_data"]) + + # Step 3: output collection and return + imputation = torch.cat(imputation_collector).cpu().detach().numpy() + result_dict = { + "imputation": imputation, + } + return result_dict + + def impute( + self, + test_set: Union[dict, str], + file_type: str = "hdf5", + ) -> np.ndarray: + """Impute missing values in the given data with the trained model. + + Parameters + ---------- + test_set : + The data samples for testing, should be array-like of shape [n_samples, sequence length (n_steps), + n_features], or a path string locating a data file, e.g. h5 file. + + file_type : + The type of the given file if X is a path string. + + Returns + ------- + array-like, shape [n_samples, sequence length (n_steps), n_features], + Imputed data. + """ + + result_dict = self.predict(test_set, file_type=file_type) + return result_dict["imputation"] diff --git a/pypots/nn/modules/moderntcn/__init__.py b/pypots/nn/modules/moderntcn/__init__.py new file mode 100644 index 00000000..bfca4e48 --- /dev/null +++ b/pypots/nn/modules/moderntcn/__init__.py @@ -0,0 +1,24 @@ +""" +The package including the modules of ModernTCN. + +Refer to the paper +`Donghao Luo, and Xue Wang. +ModernTCN: A Modern Pure Convolution Structure for General Time Series Analysis. +In The Twelfth International Conference on Learning Representations. 2024. 
+`_ + +Notes +----- +This implementation is inspired by the official one https://github.com/luodhhh/ModernTCN + +""" + +# Created by Wenjie Du +# License: BSD-3-Clause + + +from .backbone import BackboneModernTCN + +__all__ = [ + "BackboneModernTCN", +] diff --git a/pypots/nn/modules/moderntcn/backbone.py b/pypots/nn/modules/moderntcn/backbone.py new file mode 100644 index 00000000..a9e3b388 --- /dev/null +++ b/pypots/nn/modules/moderntcn/backbone.py @@ -0,0 +1,184 @@ +""" + +""" + +# Created by Wenjie Du +# License: BSD-3-Clause + +import torch +import torch.nn as nn + +from .layers import Stage +from ..patchtst.layers import FlattenHead + + +class BackboneModernTCN(nn.Module): + def __init__( + self, + n_steps, + n_features, + n_predict_features, + patch_size, + patch_stride, + downsampling_ratio, + ffn_ratio, + num_blocks: list, + large_size: list, + small_size: list, + dims: list, + small_kernel_merged: bool = False, + backbone_dropout: float = 0.1, + head_dropout: float = 0.1, + use_multi_scale: bool = True, + individual: bool = False, + freq: str = "h", + ): + super().__init__() + + # stem layer & down sampling layers + self.downsample_layers = nn.ModuleList() + stem = nn.Linear(patch_size, dims[0]) + self.downsample_layers.append(stem) + + self.num_stage = len(num_blocks) + if self.num_stage > 1: + for i in range(self.num_stage - 1): + downsample_layer = nn.Sequential( + nn.BatchNorm1d(dims[i]), + nn.Conv1d( + dims[i], + dims[i + 1], + kernel_size=downsampling_ratio, + stride=downsampling_ratio, + ), + ) + self.downsample_layers.append(downsample_layer) + + self.patch_size = patch_size + self.patch_stride = patch_stride + self.downsample_ratio = downsampling_ratio + + if freq == "h": + time_feature_num = 4 + elif freq == "t": + time_feature_num = 5 + else: + raise NotImplementedError("time_feature_num should be 4 or 5") + + self.te_patch = nn.Sequential( + nn.Conv1d( + time_feature_num, + time_feature_num, + kernel_size=patch_size, + stride=patch_stride, + groups=time_feature_num, + ), + nn.Conv1d(time_feature_num, dims[0], kernel_size=1, stride=1, groups=1), + nn.BatchNorm1d(dims[0]), + ) + + # backbone + self.stages = nn.ModuleList() + for stage_idx in range(self.num_stage): + layer = Stage( + ffn_ratio, + num_blocks[stage_idx], + large_size[stage_idx], + small_size[stage_idx], + dmodel=dims[stage_idx], + nvars=n_features, + small_kernel_merged=small_kernel_merged, + drop=backbone_dropout, + ) + self.stages.append(layer) + + # Multi scale fusing + self.use_multi_scale = use_multi_scale + self.up_sample_ratio = downsampling_ratio + + self.lat_layer = nn.ModuleList() + self.smooth_layer = nn.ModuleList() + self.up_sample_conv = nn.ModuleList() + for i in range(self.num_stage): + align_dim = dims[-1] + lat = nn.Conv1d(dims[i], align_dim, kernel_size=1, stride=1) + self.lat_layer.append(lat) + smooth = nn.Conv1d(align_dim, align_dim, kernel_size=3, stride=1, padding=1) + self.smooth_layer.append(smooth) + up_conv = nn.Sequential( + nn.ConvTranspose1d( + align_dim, + align_dim, + kernel_size=self.up_sample_ratio, + stride=self.up_sample_ratio, + ), + nn.BatchNorm1d(align_dim), + ) + self.up_sample_conv.append(up_conv) + + # head + patch_num = n_steps // patch_stride + + self.n_features = n_features + self.individual = individual + d_model = dims[self.num_stage - 1] + + if use_multi_scale: + self.head_nf = d_model * patch_num + self.head = FlattenHead( + self.head_nf, + n_predict_features, + n_features, + head_dropout, + individual, + ) + else: + if patch_num % pow(downsampling_ratio, 
(self.num_stage - 1)) == 0: + self.head_nf = ( + d_model * patch_num // pow(downsampling_ratio, (self.num_stage - 1)) + ) + else: + self.head_nf = d_model * ( + patch_num // pow(downsampling_ratio, (self.num_stage - 1)) + 1 + ) + + self.head = FlattenHead( + self.head_nf, + n_predict_features, + n_features, + head_dropout, + individual, + ) + + def structural_reparam(self): + for m in self.modules(): + if hasattr(m, "merge_kernel"): + m.merge_kernel() + + def forward(self, x): + x = x.unsqueeze(-2) + + for i in range(self.num_stage): + B, M, D, N = x.shape + x = x.reshape(B * M, D, N) + + if i == 0: + if self.patch_size != self.patch_stride: + pad_len = self.patch_size - self.patch_stride + pad = x[:, :, -1:].repeat(1, 1, pad_len) + x = torch.cat([x, pad], dim=-1) + x = x.reshape(B, M, 1, -1).squeeze(-2) + x = x.unfold(dimension=-1, size=self.patch_size, step=self.patch_stride) + x = self.downsample_layers[i](x) + x = x.permute(0, 1, 3, 2) + + else: + if N % self.downsample_ratio != 0: + pad_len = self.downsample_ratio - (N % self.downsample_ratio) + x = torch.cat([x, x[:, :, -pad_len:]], dim=-1) + x = self.downsample_layers[i](x) + _, D_, N_ = x.shape + x = x.reshape(B, M, D_, N_) + + x = self.stages[i](x) + return x diff --git a/pypots/nn/modules/moderntcn/layers.py b/pypots/nn/modules/moderntcn/layers.py new file mode 100644 index 00000000..b7c21058 --- /dev/null +++ b/pypots/nn/modules/moderntcn/layers.py @@ -0,0 +1,328 @@ +""" + +""" + +# Created by Wenjie Du +# License: BSD-3-Clause + +import torch +from torch import nn + + +def get_conv1d( + in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias +): + return nn.Conv1d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + groups=groups, + bias=bias, + ) + + +def get_bn(channels): + return nn.BatchNorm1d(channels) + + +def conv_bn( + in_channels, + out_channels, + kernel_size, + stride, + padding, + groups, + dilation=1, + bias=False, +): + if padding is None: + padding = kernel_size // 2 + result = nn.Sequential() + result.add_module( + "conv", + get_conv1d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + groups=groups, + bias=bias, + ), + ) + result.add_module("bn", get_bn(out_channels)) + return result + + +def fuse_bn(conv, bn): + kernel = conv.weight + running_mean = bn.running_mean + running_var = bn.running_var + gamma = bn.weight + beta = bn.bias + eps = bn.eps + std = (running_var + eps).sqrt() + t = (gamma / std).reshape(-1, 1, 1) + return kernel * t, beta - running_mean * gamma / std + + +class LayerNorm(nn.Module): + def __init__(self, channels, eps=1e-6): + super().__init__() + self.norm = nn.LayerNorm(channels, eps=eps) + + def forward(self, x): + B, M, D, N = x.shape + x = x.permute(0, 1, 3, 2) + x = x.reshape(B * M, N, D) + x = self.norm(x) + x = x.reshape(B, M, N, D) + x = x.permute(0, 1, 3, 2) + return x + + +class ReparamLargeKernelConv(nn.Module): + def __init__( + self, + in_channels, + out_channels, + kernel_size, + stride, + groups, + small_kernel, + small_kernel_merged=False, + nvars=7, + ): + super().__init__() + self.kernel_size = kernel_size + self.small_kernel = small_kernel + # We assume the conv does not change the feature map size, so padding = k//2. + # Otherwise, you may configure padding as you wish, and change the padding of small_conv accordingly. 
+ padding = kernel_size // 2 + if small_kernel_merged: + self.lkb_reparam = nn.Conv1d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + dilation=1, + groups=groups, + bias=True, + ) + else: + self.lkb_origin = conv_bn( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + dilation=1, + groups=groups, + bias=False, + ) + if small_kernel is not None: + assert ( + small_kernel <= kernel_size + ), "The kernel size for re-param cannot be larger than the large kernel!" + self.small_conv = conv_bn( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=small_kernel, + stride=stride, + padding=small_kernel // 2, + groups=groups, + dilation=1, + bias=False, + ) + + def forward(self, inputs): + if hasattr(self, "lkb_reparam"): + out = self.lkb_reparam(inputs) + else: + out = self.lkb_origin(inputs) + if hasattr(self, "small_conv"): + out += self.small_conv(inputs) + return out + + def PaddingTwoEdge1d(self, x, pad_length_left, pad_length_right, pad_values=0): + D_out, D_in, ks = x.shape + # create the padding tensors on the same device as x to avoid device mismatch during re-param + if pad_values == 0: + pad_left = torch.zeros(D_out, D_in, pad_length_left, device=x.device) + pad_right = torch.zeros(D_out, D_in, pad_length_right, device=x.device) + else: + pad_left = torch.ones(D_out, D_in, pad_length_left, device=x.device) * pad_values + pad_right = torch.ones(D_out, D_in, pad_length_right, device=x.device) * pad_values + x = torch.cat([pad_left, x], dim=-1) + x = torch.cat([x, pad_right], dim=-1) + return x + + def get_equivalent_kernel_bias(self): + eq_k, eq_b = fuse_bn(self.lkb_origin.conv, self.lkb_origin.bn) + if hasattr(self, "small_conv"): + small_k, small_b = fuse_bn(self.small_conv.conv, self.small_conv.bn) + eq_b += small_b + eq_k += self.PaddingTwoEdge1d( + small_k, + (self.kernel_size - self.small_kernel) // 2, + (self.kernel_size - self.small_kernel) // 2, + 0, + ) + return eq_k, eq_b + + def merge_kernel(self): + eq_k, eq_b = self.get_equivalent_kernel_bias() + self.lkb_reparam = nn.Conv1d( + in_channels=self.lkb_origin.conv.in_channels, + out_channels=self.lkb_origin.conv.out_channels, + kernel_size=self.lkb_origin.conv.kernel_size, + stride=self.lkb_origin.conv.stride, + padding=self.lkb_origin.conv.padding, + dilation=self.lkb_origin.conv.dilation, + groups=self.lkb_origin.conv.groups, + bias=True, + ) + self.lkb_reparam.weight.data = eq_k + self.lkb_reparam.bias.data = eq_b + self.__delattr__("lkb_origin") + if hasattr(self, "small_conv"): + self.__delattr__("small_conv") + + +class Block(nn.Module): + def __init__( + self, + large_size, + small_size, + dmodel, + dff, + nvars, + small_kernel_merged=False, + drop=0.1, + ): + super().__init__() + self.dw = ReparamLargeKernelConv( + in_channels=nvars * dmodel, + out_channels=nvars * dmodel, + kernel_size=large_size, + stride=1, + groups=nvars * dmodel, + small_kernel=small_size, + small_kernel_merged=small_kernel_merged, + nvars=nvars, + ) + self.norm = nn.BatchNorm1d(dmodel) + + # convffn1 + self.ffn1pw1 = nn.Conv1d( + in_channels=nvars * dmodel, + out_channels=nvars * dff, + kernel_size=1, + stride=1, + padding=0, + dilation=1, + groups=nvars, + ) + self.ffn1act = nn.GELU() + self.ffn1pw2 = nn.Conv1d( + in_channels=nvars * dff, + out_channels=nvars * dmodel, + kernel_size=1, + stride=1, + padding=0, + dilation=1, + groups=nvars, + ) + self.ffn1drop1 = nn.Dropout(drop) + self.ffn1drop2 = nn.Dropout(drop) + + # convffn2 + self.ffn2pw1 = nn.Conv1d( + in_channels=nvars * dmodel, + out_channels=nvars * dff, + kernel_size=1, + stride=1, +
padding=0, + dilation=1, + groups=dmodel, + ) + self.ffn2act = nn.GELU() + self.ffn2pw2 = nn.Conv1d( + in_channels=nvars * dff, + out_channels=nvars * dmodel, + kernel_size=1, + stride=1, + padding=0, + dilation=1, + groups=dmodel, + ) + self.ffn2drop1 = nn.Dropout(drop) + self.ffn2drop2 = nn.Dropout(drop) + + self.ffn_ratio = dff // dmodel + + def forward(self, x): + input = x + B, M, D, N = x.shape + x = x.reshape(B, M * D, N) + x = self.dw(x) + x = x.reshape(B, M, D, N) + x = x.reshape(B * M, D, N) + x = self.norm(x) + x = x.reshape(B, M, D, N) + x = x.reshape(B, M * D, N) + + x = self.ffn1drop1(self.ffn1pw1(x)) + x = self.ffn1act(x) + x = self.ffn1drop2(self.ffn1pw2(x)) + x = x.reshape(B, M, D, N) + + x = x.permute(0, 2, 1, 3) + x = x.reshape(B, D * M, N) + x = self.ffn2drop1(self.ffn2pw1(x)) + x = self.ffn2act(x) + x = self.ffn2drop2(self.ffn2pw2(x)) + x = x.reshape(B, D, M, N) + x = x.permute(0, 2, 1, 3) + + x = input + x + return x + + +class Stage(nn.Module): + def __init__( + self, + ffn_ratio, + num_blocks, + large_size, + small_size, + dmodel, + nvars, + small_kernel_merged=False, + drop=0.1, + ): + super().__init__() + d_ffn = dmodel * ffn_ratio + blks = [] + for i in range(num_blocks): + blk = Block( + large_size=large_size, + small_size=small_size, + dmodel=dmodel, + dff=d_ffn, + nvars=nvars, + small_kernel_merged=small_kernel_merged, + drop=drop, + ) + blks.append(blk) + self.blocks = nn.ModuleList(blks) + + def forward(self, x): + for blk in self.blocks: + x = blk(x) + + return x diff --git a/pypots/nn/modules/patchtst/layers.py b/pypots/nn/modules/patchtst/layers.py index 083c368a..3990954b 100644 --- a/pypots/nn/modules/patchtst/layers.py +++ b/pypots/nn/modules/patchtst/layers.py @@ -106,13 +106,13 @@ def __init__( head_dim = d_model * n_patches self.individual = individual - self.n_vars = n_features + self.n_features = n_features if self.individual: self.linears = nn.ModuleList() self.dropouts = nn.ModuleList() self.flattens = nn.ModuleList() - for i in range(self.n_vars): + for i in range(self.n_features): self.flattens.append(nn.Flatten(start_dim=-2)) self.linears.append(nn.Linear(head_dim, n_steps_forecast)) self.dropouts.append(nn.Dropout(head_dropout)) @@ -128,7 +128,7 @@ def forward(self, x): """ if self.individual: x_out = [] - for i in range(self.n_vars): + for i in range(self.n_features): z = self.flattens[i](x[:, i, :, :]) # z: [bs x d_model * num_patch] z = self.linears[i](z) # z: [bs x forecast_len] z = self.dropouts[i](z) @@ -139,3 +139,46 @@ def forward(self, x): x = self.dropout(x) x = self.linear(x) # x: [bs x nvars x forecast_len] return x.transpose(2, 1) # [bs x forecast_len x nvars] + + +class FlattenHead(nn.Module): + def __init__( + self, + d_input, + d_output, + n_features, + head_dropout=0, + individual=False, + ): + super().__init__() + + self.individual = individual + self.n_features = n_features + + if self.individual: + self.linears = nn.ModuleList() + self.dropouts = nn.ModuleList() + self.flattens = nn.ModuleList() + for i in range(self.n_features): + self.flattens.append(nn.Flatten(start_dim=-2)) + self.linears.append(nn.Linear(d_input, d_output)) + self.dropouts.append(nn.Dropout(head_dropout)) + else: + self.flatten = nn.Flatten(start_dim=-2) + self.linear = nn.Linear(d_input, d_output) + self.dropout = nn.Dropout(head_dropout) + + def forward(self, x): + if self.individual: + x_out = [] + for i in range(self.n_features): + z = self.flattens[i](x[:, i, :, :]) # z: [bs x d_model * patch_num] + z = self.linears[i](z) # z: [bs x 
target_window] + z = self.dropouts[i](z) + x_out.append(z) + x = torch.stack(x_out, dim=1) # x: [bs x nvars x target_window] + else: + x = self.flatten(x) + x = self.linear(x) + x = self.dropout(x) + return x diff --git a/tests/imputation/moderntcn.py b/tests/imputation/moderntcn.py new file mode 100644 index 00000000..33b41269 --- /dev/null +++ b/tests/imputation/moderntcn.py @@ -0,0 +1,137 @@ +""" +Test cases for ModernTCN imputation model. +""" + +# Created by Wenjie Du +# License: BSD-3-Clause + + +import os.path +import unittest + +import numpy as np +import pytest + +from pypots.imputation import ModernTCN +from pypots.optim import Adam +from pypots.utils.logging import logger +from pypots.utils.metrics import calc_mse +from tests.global_test_config import ( + DATA, + EPOCHS, + DEVICE, + TRAIN_SET, + VAL_SET, + TEST_SET, + GENERAL_H5_TRAIN_SET_PATH, + GENERAL_H5_VAL_SET_PATH, + GENERAL_H5_TEST_SET_PATH, + RESULT_SAVING_DIR_FOR_IMPUTATION, + check_tb_and_model_checkpoints_existence, +) + + +class TestModernTCN(unittest.TestCase): + logger.info("Running tests for an imputation model ModernTCN...") + + # set the log and model saving path + saving_path = os.path.join(RESULT_SAVING_DIR_FOR_IMPUTATION, "ModernTCN") + model_save_name = "saved_moderntcn_model.pypots" + + # initialize an Adam optimizer + optimizer = Adam(lr=0.001, weight_decay=1e-5) + + # initialize a ModernTCN model + moderntcn = ModernTCN( + DATA["n_steps"], + DATA["n_features"], + patch_size=3, + patch_stride=2, + downsampling_ratio=2, + ffn_ratio=1, + num_blocks=[1], + large_size=[5], + small_size=[3], + dims=[32], + small_kernel_merged=False, + backbone_dropout=0.1, + head_dropout=0.1, + use_multi_scale=False, + individual=False, + apply_nonstationary_norm=False, + epochs=EPOCHS, + saving_path=saving_path, + optimizer=optimizer, + device=DEVICE, + ) + + @pytest.mark.xdist_group(name="imputation-moderntcn") + def test_0_fit(self): + self.moderntcn.fit(TRAIN_SET, VAL_SET) + + @pytest.mark.xdist_group(name="imputation-moderntcn") + def test_1_impute(self): + imputation_results = self.moderntcn.predict(TEST_SET) + assert not np.isnan( + imputation_results["imputation"] + ).any(), "Output still has missing values after running impute()." 
+ + test_MSE = calc_mse( + imputation_results["imputation"], + DATA["test_X_ori"], + DATA["test_X_indicating_mask"], + ) + logger.info(f"ModernTCN test_MSE: {test_MSE}") + + @pytest.mark.xdist_group(name="imputation-moderntcn") + def test_2_parameters(self): + assert hasattr(self.moderntcn, "model") and self.moderntcn.model is not None + + assert ( + hasattr(self.moderntcn, "optimizer") + and self.moderntcn.optimizer is not None + ) + + assert hasattr(self.moderntcn, "best_loss") + self.assertNotEqual(self.moderntcn.best_loss, float("inf")) + + assert ( + hasattr(self.moderntcn, "best_model_dict") + and self.moderntcn.best_model_dict is not None + ) + + @pytest.mark.xdist_group(name="imputation-moderntcn") + def test_3_saving_path(self): + # whether the root saving dir exists, which should be created by save_log_into_tb_file + assert os.path.exists( + self.saving_path + ), f"file {self.saving_path} does not exist" + + # check if the tensorboard file and model checkpoints exist + check_tb_and_model_checkpoints_existence(self.moderntcn) + + # save the trained model into file, and check if the path exists + saved_model_path = os.path.join(self.saving_path, self.model_save_name) + self.moderntcn.save(saved_model_path) + + # test loading the saved model, not necessary, but need to test + self.moderntcn.load(saved_model_path) + + @pytest.mark.xdist_group(name="imputation-moderntcn") + def test_4_lazy_loading(self): + self.moderntcn.fit(GENERAL_H5_TRAIN_SET_PATH, GENERAL_H5_VAL_SET_PATH) + imputation_results = self.moderntcn.predict(GENERAL_H5_TEST_SET_PATH) + assert not np.isnan( + imputation_results["imputation"] + ).any(), "Output still has missing values after running impute()." + + test_MSE = calc_mse( + imputation_results["imputation"], + DATA["test_X_ori"], + DATA["test_X_indicating_mask"], + ) + logger.info(f"Lazy-loading ModernTCN test_MSE: {test_MSE}") + + +if __name__ == "__main__": + unittest.main()
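A minimal usage sketch of the newly added imputer follows, mirroring the hyperparameters used in the test case above. The toy array X, its shape, and the chosen n_steps/n_features values are illustrative assumptions for demonstration only; the ModernTCN class and its fit()/impute() API come from this PR.

import numpy as np
from pypots.imputation import ModernTCN

# build a toy, partially-observed dataset (shapes and missing pattern are illustrative assumptions)
X = np.random.randn(64, 24, 5)
X[X < -1.5] = np.nan  # artificially introduce missing values
train_set = {"X": X}

# hyperparameters below mirror tests/imputation/moderntcn.py
model = ModernTCN(
    n_steps=24,
    n_features=5,
    patch_size=3,
    patch_stride=2,
    downsampling_ratio=2,
    ffn_ratio=1,
    num_blocks=[1],
    large_size=[5],
    small_size=[3],
    dims=[32],
    use_multi_scale=False,
    epochs=10,
)
model.fit(train_set)  # trains with MIT, i.e. randomly re-masking observed values and reconstructing them
imputation = model.impute(train_set)  # ndarray of shape (64, 24, 5) with the missing entries filled in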