From 7623272ba1206049f2c8c2eb678b8c4c0b96391f Mon Sep 17 00:00:00 2001 From: Wenjie Du Date: Wed, 21 Aug 2024 15:51:27 +0800 Subject: [PATCH 1/6] feat: add ModernTCN modules; --- pypots/nn/modules/moderntcn/__init__.py | 24 ++ pypots/nn/modules/moderntcn/backbone.py | 184 +++++++++++++ pypots/nn/modules/moderntcn/layers.py | 328 ++++++++++++++++++++++++ 3 files changed, 536 insertions(+) create mode 100644 pypots/nn/modules/moderntcn/__init__.py create mode 100644 pypots/nn/modules/moderntcn/backbone.py create mode 100644 pypots/nn/modules/moderntcn/layers.py diff --git a/pypots/nn/modules/moderntcn/__init__.py b/pypots/nn/modules/moderntcn/__init__.py new file mode 100644 index 00000000..bfca4e48 --- /dev/null +++ b/pypots/nn/modules/moderntcn/__init__.py @@ -0,0 +1,24 @@ +""" +The package including the modules of ModernTCN. + +Refer to the paper +`Donghao Luo, and Xue Wang. +ModernTCN: A Modern Pure Convolution Structure for General Time Series Analysis. +In The Twelfth International Conference on Learning Representations. 2024. +`_ + +Notes +----- +This implementation is inspired by the official one https://github.com/luodhhh/ModernTCN + +""" + +# Created by Wenjie Du +# License: BSD-3-Clause + + +from .backbone import BackboneModernTCN + +__all__ = [ + "BackboneModernTCN", +] diff --git a/pypots/nn/modules/moderntcn/backbone.py b/pypots/nn/modules/moderntcn/backbone.py new file mode 100644 index 00000000..a9e3b388 --- /dev/null +++ b/pypots/nn/modules/moderntcn/backbone.py @@ -0,0 +1,184 @@ +""" + +""" + +# Created by Wenjie Du +# License: BSD-3-Clause + +import torch +import torch.nn as nn + +from .layers import Stage +from ..patchtst.layers import FlattenHead + + +class BackboneModernTCN(nn.Module): + def __init__( + self, + n_steps, + n_features, + n_predict_features, + patch_size, + patch_stride, + downsampling_ratio, + ffn_ratio, + num_blocks: list, + large_size: list, + small_size: list, + dims: list, + small_kernel_merged: bool = False, + backbone_dropout: float = 0.1, + head_dropout: float = 0.1, + use_multi_scale: bool = True, + individual: bool = False, + freq: str = "h", + ): + super().__init__() + + # stem layer & down sampling layers + self.downsample_layers = nn.ModuleList() + stem = nn.Linear(patch_size, dims[0]) + self.downsample_layers.append(stem) + + self.num_stage = len(num_blocks) + if self.num_stage > 1: + for i in range(self.num_stage - 1): + downsample_layer = nn.Sequential( + nn.BatchNorm1d(dims[i]), + nn.Conv1d( + dims[i], + dims[i + 1], + kernel_size=downsampling_ratio, + stride=downsampling_ratio, + ), + ) + self.downsample_layers.append(downsample_layer) + + self.patch_size = patch_size + self.patch_stride = patch_stride + self.downsample_ratio = downsampling_ratio + + if freq == "h": + time_feature_num = 4 + elif freq == "t": + time_feature_num = 5 + else: + raise NotImplementedError("time_feature_num should be 4 or 5") + + self.te_patch = nn.Sequential( + nn.Conv1d( + time_feature_num, + time_feature_num, + kernel_size=patch_size, + stride=patch_stride, + groups=time_feature_num, + ), + nn.Conv1d(time_feature_num, dims[0], kernel_size=1, stride=1, groups=1), + nn.BatchNorm1d(dims[0]), + ) + + # backbone + self.stages = nn.ModuleList() + for stage_idx in range(self.num_stage): + layer = Stage( + ffn_ratio, + num_blocks[stage_idx], + large_size[stage_idx], + small_size[stage_idx], + dmodel=dims[stage_idx], + nvars=n_features, + small_kernel_merged=small_kernel_merged, + drop=backbone_dropout, + ) + self.stages.append(layer) + + # Multi scale fusing + 
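# Descriptive comments on the otherwise-undocumented fusion layers below:
+        # lateral 1x1 convs align each stage's channels to dims[-1], 3x3 convs
+        # smooth the fused maps, and transposed convs upsample coarser stages
+        # back to a finer temporal resolution.
+        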
self.use_multi_scale = use_multi_scale + self.up_sample_ratio = downsampling_ratio + + self.lat_layer = nn.ModuleList() + self.smooth_layer = nn.ModuleList() + self.up_sample_conv = nn.ModuleList() + for i in range(self.num_stage): + align_dim = dims[-1] + lat = nn.Conv1d(dims[i], align_dim, kernel_size=1, stride=1) + self.lat_layer.append(lat) + smooth = nn.Conv1d(align_dim, align_dim, kernel_size=3, stride=1, padding=1) + self.smooth_layer.append(smooth) + up_conv = nn.Sequential( + nn.ConvTranspose1d( + align_dim, + align_dim, + kernel_size=self.up_sample_ratio, + stride=self.up_sample_ratio, + ), + nn.BatchNorm1d(align_dim), + ) + self.up_sample_conv.append(up_conv) + + # head + patch_num = n_steps // patch_stride + + self.n_features = n_features + self.individual = individual + d_model = dims[self.num_stage - 1] + + if use_multi_scale: + self.head_nf = d_model * patch_num + self.head = FlattenHead( + self.head_nf, + n_predict_features, + n_features, + head_dropout, + individual, + ) + else: + if patch_num % pow(downsampling_ratio, (self.num_stage - 1)) == 0: + self.head_nf = ( + d_model * patch_num // pow(downsampling_ratio, (self.num_stage - 1)) + ) + else: + self.head_nf = d_model * ( + patch_num // pow(downsampling_ratio, (self.num_stage - 1)) + 1 + ) + + self.head = FlattenHead( + self.head_nf, + n_predict_features, + n_features, + head_dropout, + individual, + ) + + def structural_reparam(self): + for m in self.modules(): + if hasattr(m, "merge_kernel"): + m.merge_kernel() + + def forward(self, x): + x = x.unsqueeze(-2) + + for i in range(self.num_stage): + B, M, D, N = x.shape + x = x.reshape(B * M, D, N) + + if i == 0: + if self.patch_size != self.patch_stride: + pad_len = self.patch_size - self.patch_stride + pad = x[:, :, -1:].repeat(1, 1, pad_len) + x = torch.cat([x, pad], dim=-1) + x = x.reshape(B, M, 1, -1).squeeze(-2) + x = x.unfold(dimension=-1, size=self.patch_size, step=self.patch_stride) + x = self.downsample_layers[i](x) + x = x.permute(0, 1, 3, 2) + + else: + if N % self.downsample_ratio != 0: + pad_len = self.downsample_ratio - (N % self.downsample_ratio) + x = torch.cat([x, x[:, :, -pad_len:]], dim=-1) + x = self.downsample_layers[i](x) + _, D_, N_ = x.shape + x = x.reshape(B, M, D_, N_) + + x = self.stages[i](x) + return x diff --git a/pypots/nn/modules/moderntcn/layers.py b/pypots/nn/modules/moderntcn/layers.py new file mode 100644 index 00000000..b7c21058 --- /dev/null +++ b/pypots/nn/modules/moderntcn/layers.py @@ -0,0 +1,328 @@ +""" + +""" + +# Created by Wenjie Du +# License: BSD-3-Clause + +import torch +from torch import nn + + +def get_conv1d( + in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias +): + return nn.Conv1d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + groups=groups, + bias=bias, + ) + + +def get_bn(channels): + return nn.BatchNorm1d(channels) + + +def conv_bn( + in_channels, + out_channels, + kernel_size, + stride, + padding, + groups, + dilation=1, + bias=False, +): + if padding is None: + padding = kernel_size // 2 + result = nn.Sequential() + result.add_module( + "conv", + get_conv1d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + groups=groups, + bias=bias, + ), + ) + result.add_module("bn", get_bn(out_channels)) + return result + + +def fuse_bn(conv, bn): + kernel = conv.weight + running_mean = bn.running_mean + 
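# Inference-time BN folding: gamma * (conv(x) - mean) / std + beta equals a
+    # single conv with weight W * (gamma / std) and bias beta - mean * gamma / std,
+    # where std = sqrt(running_var + eps); the code below returns exactly that pair.
+    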
running_var = bn.running_var
+    gamma = bn.weight
+    beta = bn.bias
+    eps = bn.eps
+    std = (running_var + eps).sqrt()
+    t = (gamma / std).reshape(-1, 1, 1)
+    return kernel * t, beta - running_mean * gamma / std
+
+
+class LayerNorm(nn.Module):
+    def __init__(self, channels, eps=1e-6):
+        super().__init__()
+        self.norm = nn.LayerNorm(channels, eps=eps)
+
+    def forward(self, x):
+        # normalize over the channel dim D of a [B, M, D, N] tensor
+        B, M, D, N = x.shape
+        x = x.permute(0, 1, 3, 2)
+        x = x.reshape(B * M, N, D)
+        x = self.norm(x)
+        x = x.reshape(B, M, N, D)
+        x = x.permute(0, 1, 3, 2)
+        return x
+
+
+class ReparamLargeKernelConv(nn.Module):
+    def __init__(
+        self,
+        in_channels,
+        out_channels,
+        kernel_size,
+        stride,
+        groups,
+        small_kernel,
+        small_kernel_merged=False,
+        nvars=7,
+    ):
+        super().__init__()
+        self.kernel_size = kernel_size
+        self.small_kernel = small_kernel
+        # We assume the conv does not change the feature map size, so padding = k//2.
+        # Otherwise, you may configure padding as you wish, and change the padding of small_conv accordingly.
+        padding = kernel_size // 2
+        if small_kernel_merged:
+            self.lkb_reparam = nn.Conv1d(
+                in_channels=in_channels,
+                out_channels=out_channels,
+                kernel_size=kernel_size,
+                stride=stride,
+                padding=padding,
+                dilation=1,
+                groups=groups,
+                bias=True,
+            )
+        else:
+            self.lkb_origin = conv_bn(
+                in_channels=in_channels,
+                out_channels=out_channels,
+                kernel_size=kernel_size,
+                stride=stride,
+                padding=padding,
+                dilation=1,
+                groups=groups,
+                bias=False,
+            )
+            if small_kernel is not None:
+                assert (
+                    small_kernel <= kernel_size
+                ), "The kernel size for re-param cannot be larger than the large kernel!"
+                self.small_conv = conv_bn(
+                    in_channels=in_channels,
+                    out_channels=out_channels,
+                    kernel_size=small_kernel,
+                    stride=stride,
+                    padding=small_kernel // 2,
+                    groups=groups,
+                    dilation=1,
+                    bias=False,
+                )
+
+    def forward(self, inputs):
+        # after merge_kernel() only the single re-parameterized conv is used;
+        # before that, the large- and small-kernel conv+BN branches are summed
+        if hasattr(self, "lkb_reparam"):
+            out = self.lkb_reparam(inputs)
+        else:
+            out = self.lkb_origin(inputs)
+            if hasattr(self, "small_conv"):
+                out += self.small_conv(inputs)
+        return out
+
+    def PaddingTwoEdge1d(self, x, pad_length_left, pad_length_right, pad_values=0):
+        # pad a [D_out, D_in, ks] kernel on both edges, keeping the pads on the
+        # kernel's device/dtype so merging also works for models on GPU
+        D_out, D_in, ks = x.shape
+        if pad_values == 0:
+            pad_left = torch.zeros(
+                D_out, D_in, pad_length_left, device=x.device, dtype=x.dtype
+            )
+            pad_right = torch.zeros(
+                D_out, D_in, pad_length_right, device=x.device, dtype=x.dtype
+            )
+        else:
+            pad_left = (
+                torch.ones(D_out, D_in, pad_length_left, device=x.device, dtype=x.dtype)
+                * pad_values
+            )
+            pad_right = (
+                torch.ones(D_out, D_in, pad_length_right, device=x.device, dtype=x.dtype)
+                * pad_values
+            )
+        x = torch.cat([pad_left, x], dim=-1)
+        x = torch.cat([x, pad_right], dim=-1)
+        return x
+
+    def get_equivalent_kernel_bias(self):
+        # fuse each conv+BN branch, then pad the small kernel up to the
+        # large kernel size so the two kernels can be summed
+        eq_k, eq_b = fuse_bn(self.lkb_origin.conv, self.lkb_origin.bn)
+        if hasattr(self, "small_conv"):
+            small_k, small_b = fuse_bn(self.small_conv.conv, self.small_conv.bn)
+            eq_b += small_b
+            eq_k += self.PaddingTwoEdge1d(
+                small_k,
+                (self.kernel_size - self.small_kernel) // 2,
+                (self.kernel_size - self.small_kernel) // 2,
+                0,
+            )
+        return eq_k, eq_b
+
+    def merge_kernel(self):
+        # create the single merged conv and drop the training-time branches
+        eq_k, eq_b = self.get_equivalent_kernel_bias()
+        self.lkb_reparam = nn.Conv1d(
+            in_channels=self.lkb_origin.conv.in_channels,
+            out_channels=self.lkb_origin.conv.out_channels,
+            kernel_size=self.lkb_origin.conv.kernel_size,
+            stride=self.lkb_origin.conv.stride,
+            padding=self.lkb_origin.conv.padding,
+            dilation=self.lkb_origin.conv.dilation,
+            groups=self.lkb_origin.conv.groups,
+            bias=True,
+        )
+        self.lkb_reparam.weight.data = eq_k
+        self.lkb_reparam.bias.data = eq_b
+        self.__delattr__("lkb_origin")
+        if hasattr(self, "small_conv"):
+            self.__delattr__("small_conv")
+
+
+class 
Block(nn.Module): + def __init__( + self, + large_size, + small_size, + dmodel, + dff, + nvars, + small_kernel_merged=False, + drop=0.1, + ): + super().__init__() + self.dw = ReparamLargeKernelConv( + in_channels=nvars * dmodel, + out_channels=nvars * dmodel, + kernel_size=large_size, + stride=1, + groups=nvars * dmodel, + small_kernel=small_size, + small_kernel_merged=small_kernel_merged, + nvars=nvars, + ) + self.norm = nn.BatchNorm1d(dmodel) + + # convffn1 + self.ffn1pw1 = nn.Conv1d( + in_channels=nvars * dmodel, + out_channels=nvars * dff, + kernel_size=1, + stride=1, + padding=0, + dilation=1, + groups=nvars, + ) + self.ffn1act = nn.GELU() + self.ffn1pw2 = nn.Conv1d( + in_channels=nvars * dff, + out_channels=nvars * dmodel, + kernel_size=1, + stride=1, + padding=0, + dilation=1, + groups=nvars, + ) + self.ffn1drop1 = nn.Dropout(drop) + self.ffn1drop2 = nn.Dropout(drop) + + # convffn2 + self.ffn2pw1 = nn.Conv1d( + in_channels=nvars * dmodel, + out_channels=nvars * dff, + kernel_size=1, + stride=1, + padding=0, + dilation=1, + groups=dmodel, + ) + self.ffn2act = nn.GELU() + self.ffn2pw2 = nn.Conv1d( + in_channels=nvars * dff, + out_channels=nvars * dmodel, + kernel_size=1, + stride=1, + padding=0, + dilation=1, + groups=dmodel, + ) + self.ffn2drop1 = nn.Dropout(drop) + self.ffn2drop2 = nn.Dropout(drop) + + self.ffn_ratio = dff // dmodel + + def forward(self, x): + input = x + B, M, D, N = x.shape + x = x.reshape(B, M * D, N) + x = self.dw(x) + x = x.reshape(B, M, D, N) + x = x.reshape(B * M, D, N) + x = self.norm(x) + x = x.reshape(B, M, D, N) + x = x.reshape(B, M * D, N) + + x = self.ffn1drop1(self.ffn1pw1(x)) + x = self.ffn1act(x) + x = self.ffn1drop2(self.ffn1pw2(x)) + x = x.reshape(B, M, D, N) + + x = x.permute(0, 2, 1, 3) + x = x.reshape(B, D * M, N) + x = self.ffn2drop1(self.ffn2pw1(x)) + x = self.ffn2act(x) + x = self.ffn2drop2(self.ffn2pw2(x)) + x = x.reshape(B, D, M, N) + x = x.permute(0, 2, 1, 3) + + x = input + x + return x + + +class Stage(nn.Module): + def __init__( + self, + ffn_ratio, + num_blocks, + large_size, + small_size, + dmodel, + nvars, + small_kernel_merged=False, + drop=0.1, + ): + super().__init__() + d_ffn = dmodel * ffn_ratio + blks = [] + for i in range(num_blocks): + blk = Block( + large_size=large_size, + small_size=small_size, + dmodel=dmodel, + dff=d_ffn, + nvars=nvars, + small_kernel_merged=small_kernel_merged, + drop=drop, + ) + blks.append(blk) + self.blocks = nn.ModuleList(blks) + + def forward(self, x): + for blk in self.blocks: + x = blk(x) + + return x From e45424765807def433721291b3117396191bb801 Mon Sep 17 00:00:00 2001 From: Wenjie Du Date: Thu, 22 Aug 2024 11:17:20 +0800 Subject: [PATCH 2/6] feat: add FlattenHead; --- pypots/nn/modules/patchtst/layers.py | 49 ++++++++++++++++++++++++++-- 1 file changed, 46 insertions(+), 3 deletions(-) diff --git a/pypots/nn/modules/patchtst/layers.py b/pypots/nn/modules/patchtst/layers.py index 083c368a..3990954b 100644 --- a/pypots/nn/modules/patchtst/layers.py +++ b/pypots/nn/modules/patchtst/layers.py @@ -106,13 +106,13 @@ def __init__( head_dim = d_model * n_patches self.individual = individual - self.n_vars = n_features + self.n_features = n_features if self.individual: self.linears = nn.ModuleList() self.dropouts = nn.ModuleList() self.flattens = nn.ModuleList() - for i in range(self.n_vars): + for i in range(self.n_features): self.flattens.append(nn.Flatten(start_dim=-2)) self.linears.append(nn.Linear(head_dim, n_steps_forecast)) self.dropouts.append(nn.Dropout(head_dropout)) @@ -128,7 +128,7 @@ def 
forward(self, x): """ if self.individual: x_out = [] - for i in range(self.n_vars): + for i in range(self.n_features): z = self.flattens[i](x[:, i, :, :]) # z: [bs x d_model * num_patch] z = self.linears[i](z) # z: [bs x forecast_len] z = self.dropouts[i](z) @@ -139,3 +139,46 @@ def forward(self, x): x = self.dropout(x) x = self.linear(x) # x: [bs x nvars x forecast_len] return x.transpose(2, 1) # [bs x forecast_len x nvars] + + +class FlattenHead(nn.Module): + def __init__( + self, + d_input, + d_output, + n_features, + head_dropout=0, + individual=False, + ): + super().__init__() + + self.individual = individual + self.n_features = n_features + + if self.individual: + self.linears = nn.ModuleList() + self.dropouts = nn.ModuleList() + self.flattens = nn.ModuleList() + for i in range(self.n_features): + self.flattens.append(nn.Flatten(start_dim=-2)) + self.linears.append(nn.Linear(d_input, d_output)) + self.dropouts.append(nn.Dropout(head_dropout)) + else: + self.flatten = nn.Flatten(start_dim=-2) + self.linear = nn.Linear(d_input, d_output) + self.dropout = nn.Dropout(head_dropout) + + def forward(self, x): + if self.individual: + x_out = [] + for i in range(self.n_features): + z = self.flattens[i](x[:, i, :, :]) # z: [bs x d_model * patch_num] + z = self.linears[i](z) # z: [bs x target_window] + z = self.dropouts[i](z) + x_out.append(z) + x = torch.stack(x_out, dim=1) # x: [bs x nvars x target_window] + else: + x = self.flatten(x) + x = self.linear(x) + x = self.dropout(x) + return x From cdfe6dc9e2298dc3a4059e6c784f6ed9c7d7f983 Mon Sep 17 00:00:00 2001 From: Wenjie Du Date: Thu, 22 Aug 2024 15:50:02 +0800 Subject: [PATCH 3/6] feat: implement ModernTCN as an imputation model; --- pypots/imputation/moderntcn/__init__.py | 24 ++ pypots/imputation/moderntcn/core.py | 95 +++++++ pypots/imputation/moderntcn/data.py | 24 ++ pypots/imputation/moderntcn/model.py | 342 ++++++++++++++++++++++++ 4 files changed, 485 insertions(+) create mode 100644 pypots/imputation/moderntcn/__init__.py create mode 100644 pypots/imputation/moderntcn/core.py create mode 100644 pypots/imputation/moderntcn/data.py create mode 100644 pypots/imputation/moderntcn/model.py diff --git a/pypots/imputation/moderntcn/__init__.py b/pypots/imputation/moderntcn/__init__.py new file mode 100644 index 00000000..f82da16b --- /dev/null +++ b/pypots/imputation/moderntcn/__init__.py @@ -0,0 +1,24 @@ +""" +The package of the partially-observed time-series imputation model ModernTCN. + +Refer to the paper +`Donghao Luo, and Xue Wang. +ModernTCN: A Modern Pure Convolution Structure for General Time Series Analysis. +In The Twelfth International Conference on Learning Representations. 2024. +`_ + +Notes +----- +This implementation is inspired by the official one https://github.com/luodhhh/ModernTCN + +""" + +# Created by Wenjie Du +# License: BSD-3-Clause + + +from .model import ModernTCN + +__all__ = [ + "ModernTCN", +] diff --git a/pypots/imputation/moderntcn/core.py b/pypots/imputation/moderntcn/core.py new file mode 100644 index 00000000..3ca8e8f5 --- /dev/null +++ b/pypots/imputation/moderntcn/core.py @@ -0,0 +1,95 @@ +""" +The core wrapper assembles the submodules of ModernTCN imputation model +and takes over the forward progress of the algorithm. 
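+Given ``X`` and ``missing_mask``, it reconstructs the full series with the ModernTCN
+backbone and keeps observed values untouched:
+imputed_data = missing_mask * X + (1 - missing_mask) * reconstruction.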
+""" + +# Created by Wenjie Du +# License: BSD-3-Clause + +import torch.nn as nn + +from ...nn.functional import nonstationary_norm, nonstationary_denorm +from ...nn.modules.moderntcn import BackboneModernTCN +from ...nn.modules.patchtst.layers import FlattenHead +from ...utils.metrics import calc_mse + + +class _ModernTCN(nn.Module): + def __init__( + self, + n_steps, + n_features, + patch_size, + patch_stride, + downsampling_ratio, + ffn_ratio, + num_blocks: list, + large_size: list, + small_size: list, + dims: list, + small_kernel_merged: bool = False, + backbone_dropout: float = 0.1, + head_dropout: float = 0.1, + use_multi_scale: bool = True, + individual: bool = False, + apply_nonstationary_norm: bool = False, + ): + super().__init__() + + self.apply_nonstationary_norm = apply_nonstationary_norm + + self.backbone = BackboneModernTCN( + n_steps, + n_features, + n_features, + patch_size, + patch_stride, + downsampling_ratio, + ffn_ratio, + num_blocks, + large_size, + small_size, + dims, + small_kernel_merged, + backbone_dropout, + head_dropout, + use_multi_scale, + individual, + ) + + # for the imputation task, the output dim is the same as input dim + self.projection = FlattenHead( + self.backbone.head_nf, + n_steps, + n_features, + head_dropout, + individual, + ) + + def forward(self, inputs: dict, training: bool = True) -> dict: + X, missing_mask = inputs["X"], inputs["missing_mask"] + + if self.apply_nonstationary_norm: + # Normalization from Non-stationary Transformer + X, means, stdev = nonstationary_norm(X, missing_mask) + + in_X = X.permute(0, 2, 1) + in_X = self.backbone(in_X) + reconstruction = self.projection(in_X) + reconstruction = reconstruction.permute(0, 2, 1) + + if self.apply_nonstationary_norm: + # De-Normalization from Non-stationary Transformer + reconstruction = nonstationary_denorm(reconstruction, means, stdev) + + imputed_data = missing_mask * X + (1 - missing_mask) * reconstruction + results = { + "imputed_data": imputed_data, + } + + # if in training mode, return results with losses + if training: + loss = calc_mse(reconstruction, inputs["X_ori"], inputs["indicating_mask"]) + results["loss"] = loss + + return results diff --git a/pypots/imputation/moderntcn/data.py b/pypots/imputation/moderntcn/data.py new file mode 100644 index 00000000..c296728a --- /dev/null +++ b/pypots/imputation/moderntcn/data.py @@ -0,0 +1,24 @@ +""" +Dataset class for ModernTCN. +""" + +# Created by Wenjie Du +# License: BSD-3-Clause + +from typing import Union + +from ..saits.data import DatasetForSAITS + + +class DatasetForModernTCN(DatasetForSAITS): + """Actually ModernTCN uses the same data strategy as SAITS, needs MIT for training.""" + + def __init__( + self, + data: Union[dict, str], + return_X_ori: bool, + return_y: bool, + file_type: str = "hdf5", + rate: float = 0.2, + ): + super().__init__(data, return_X_ori, return_y, file_type, rate) diff --git a/pypots/imputation/moderntcn/model.py b/pypots/imputation/moderntcn/model.py new file mode 100644 index 00000000..2efb3fed --- /dev/null +++ b/pypots/imputation/moderntcn/model.py @@ -0,0 +1,342 @@ +""" +The implementation of ModernTCN for the partially-observed time-series imputation task. 
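+Training uses the same masked-imputation (MIT) data strategy as SAITS,
+see DatasetForModernTCN in data.py.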
+ +""" + +# Created by Wenjie Du +# License: BSD-3-Clause + +from typing import Union, Optional + +import numpy as np +import torch +from torch.utils.data import DataLoader + +from .core import _ModernTCN +from .data import DatasetForModernTCN +from ..base import BaseNNImputer +from ...data.checking import key_in_data_set +from ...data.dataset import BaseDataset +from ...optim.adam import Adam +from ...optim.base import Optimizer + + +class ModernTCN(BaseNNImputer): + """The PyTorch implementation of the ModernTCN model. + ModernTCN is originally proposed by Luo et al. in :cite:`luo2024moderntcn`. + + Parameters + ---------- + n_steps : + The number of time steps in the time-series data sample. + + n_features : + The number of features in the time-series data sample. + + patch_size : + The size of the patch for the patching mechanism. + + patch_stride : + The stride for the patching mechanism. + + downsampling_ratio : + The downsampling ratio for the downsampling mechanism. + + ffn_ratio : + The ratio for the feed-forward neural network in the model. + + num_blocks : + The number of blocks for the model. It should be a list of integers. + + large_size : + The size of the large kernel. It should be a list of odd integers. + + small_size : + The size of the small kernel. It should be a list of odd integers. + + dims : + The dimensions for the model. It should be a list of integers. + + small_kernel_merged : + Whether the small kernel is merged. + + backbone_dropout : + The dropout rate for the backbone of the model. + + head_dropout : + The dropout rate for the head of the model. + + use_multi_scale : + Whether to use multi-scale fusing. + + individual : + Whether to make a linear layer for each variate/channel/feature individually. + + apply_nonstationary_norm : + Whether to apply non-stationary normalization. + + batch_size : + The batch size for training and evaluating the model. + + epochs : + The number of epochs for training the model. + + patience : + The patience for the early-stopping mechanism. Given a positive integer, the training process will be + stopped when the model does not perform better after that number of epochs. + Leaving it default as None will disable the early-stopping. + + optimizer : + The optimizer for model training. + If not given, will use a default Adam optimizer. + + num_workers : + The number of subprocesses to use for data loading. + `0` means data loading will be in the main process, i.e. there won't be subprocesses. + + device : + The device for the model to run on. It can be a string, a :class:`torch.device` object, or a list of them. + If not given, will try to use CUDA devices first (will use the default CUDA device if there are multiple), + then CPUs, considering CUDA and CPU are so far the main devices for people to train ML models. + If given a list of devices, e.g. ['cuda:0', 'cuda:1'], or [torch.device('cuda:0'), torch.device('cuda:1')] , the + model will be parallely trained on the multiple devices (so far only support parallel training on CUDA devices). + Other devices like Google TPU and Apple Silicon accelerator MPS may be added in the future. + + saving_path : + The path for automatically saving model checkpoints and tensorboard files (i.e. loss values recorded during + training into a tensorboard file). Will not save if not given. + + model_saving_strategy : + The strategy to save model checkpoints. It has to be one of [None, "best", "better", "all"]. + No model will be saved when it is set as None. 
+ The "best" strategy will only automatically save the best model after the training finished. + The "better" strategy will automatically save the model during training whenever the model performs + better than in previous epochs. + The "all" strategy will save every model after each epoch training. + + verbose : + Whether to print out the training logs during the training process. + """ + + def __init__( + self, + n_steps: int, + n_features: int, + patch_size: int, + patch_stride: int, + downsampling_ratio: float, + ffn_ratio: float, + num_blocks: list, + large_size: list, + small_size: list, + dims: list, + small_kernel_merged: bool = False, + backbone_dropout: float = 0.1, + head_dropout: float = 0.1, + use_multi_scale: bool = True, + individual: bool = False, + apply_nonstationary_norm: bool = False, + batch_size: int = 32, + epochs: int = 100, + patience: int = None, + optimizer: Optional[Optimizer] = Adam(), + num_workers: int = 0, + device: Optional[Union[str, torch.device, list]] = None, + saving_path: str = None, + model_saving_strategy: Optional[str] = "best", + verbose: bool = True, + ): + super().__init__( + batch_size, + epochs, + patience, + num_workers, + device, + saving_path, + model_saving_strategy, + verbose, + ) + assert ( + len(num_blocks) == len(dims) == len(large_size) == len(small_size) + ), "The length of num_blocks, dims, large_size, and small_size should be the same." + + self.n_steps = n_steps + self.n_features = n_features + + # set up the model + self.model = _ModernTCN( + n_steps, + n_features, + patch_size, + patch_stride, + downsampling_ratio, + ffn_ratio, + num_blocks, + large_size, + small_size, + dims, + small_kernel_merged, + backbone_dropout, + head_dropout, + use_multi_scale, + individual, + apply_nonstationary_norm, + ) + self._send_model_to_given_device() + self._print_model_size() + + # set up the optimizer + self.optimizer = optimizer + self.optimizer.init_optimizer(self.model.parameters()) + + def _assemble_input_for_training(self, data: list) -> dict: + ( + indices, + X, + missing_mask, + X_ori, + indicating_mask, + ) = self._send_data_to_given_device(data) + + inputs = { + "X": X, + "missing_mask": missing_mask, + "X_ori": X_ori, + "indicating_mask": indicating_mask, + } + + return inputs + + def _assemble_input_for_validating(self, data: list) -> dict: + return self._assemble_input_for_training(data) + + def _assemble_input_for_testing(self, data: list) -> dict: + indices, X, missing_mask = self._send_data_to_given_device(data) + + inputs = { + "X": X, + "missing_mask": missing_mask, + } + + return inputs + + def fit( + self, + train_set: Union[dict, str], + val_set: Optional[Union[dict, str]] = None, + file_type: str = "hdf5", + ) -> None: + # Step 1: wrap the input data with classes Dataset and DataLoader + training_set = DatasetForModernTCN( + train_set, return_X_ori=False, return_y=False, file_type=file_type + ) + training_loader = DataLoader( + training_set, + batch_size=self.batch_size, + shuffle=True, + num_workers=self.num_workers, + ) + val_loader = None + if val_set is not None: + if not key_in_data_set("X_ori", val_set): + raise ValueError("val_set must contain 'X_ori' for model validation.") + val_set = DatasetForModernTCN( + val_set, return_X_ori=True, return_y=False, file_type=file_type + ) + val_loader = DataLoader( + val_set, + batch_size=self.batch_size, + shuffle=False, + num_workers=self.num_workers, + ) + + # Step 2: train the model and freeze it + self._train_model(training_loader, val_loader) + 
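# _train_model() keeps the weights of the best-performing epoch in
+        # self.best_model_dict; restore them here before freezing the model.
+        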
self.model.load_state_dict(self.best_model_dict)
+        self.model.eval()  # set the model as eval status to freeze it.
+
+        # Step 3: save the model if necessary
+        self._auto_save_model_if_necessary(confirm_saving=True)
+
+    def predict(
+        self,
+        test_set: Union[dict, str],
+        file_type: str = "hdf5",
+    ) -> dict:
+        """Make predictions for the input data with the trained model.
+
+        Parameters
+        ----------
+        test_set : dict or str
+            The dataset for model testing, should be a dictionary including the key 'X',
+            or a path string locating a data file supported by PyPOTS (e.g. h5 file).
+            If it is a dict, X should be array-like of shape [n_samples, sequence length (n_steps), n_features],
+            which is the time-series data for testing and can contain missing values.
+            If it is a path string, the path should point to a data file, e.g. a h5 file, which contains
+            key-value pairs like a dict, and it has to include the key 'X'.
+
+        file_type :
+            The type of the given file if test_set is a path string.
+
+        Returns
+        -------
+        result_dict :
+            The dictionary containing the imputation results under the key 'imputation',
+            and latent variables if necessary.
+
+        """
+        # Step 1: wrap the input data with classes Dataset and DataLoader
+        self.model.eval()  # set the model as eval status to freeze it.
+        test_set = BaseDataset(
+            test_set,
+            return_X_ori=False,
+            return_X_pred=False,
+            return_y=False,
+            file_type=file_type,
+        )
+        test_loader = DataLoader(
+            test_set,
+            batch_size=self.batch_size,
+            shuffle=False,
+            num_workers=self.num_workers,
+        )
+        imputation_collector = []
+
+        # Step 2: process the data with the model
+        with torch.no_grad():
+            for idx, data in enumerate(test_loader):
+                inputs = self._assemble_input_for_testing(data)
+                results = self.model.forward(inputs, training=False)
+                imputation_collector.append(results["imputed_data"])
+
+        # Step 3: output collection and return
+        imputation = torch.cat(imputation_collector).cpu().detach().numpy()
+        result_dict = {
+            "imputation": imputation,
+        }
+        return result_dict
+
+    def impute(
+        self,
+        test_set: Union[dict, str],
+        file_type: str = "hdf5",
+    ) -> np.ndarray:
+        """Impute missing values in the given data with the trained model.
+
+        Parameters
+        ----------
+        test_set :
+            The data samples for testing, should be array-like of shape [n_samples, sequence length (n_steps),
+            n_features], or a path string locating a data file, e.g. h5 file.
+
+        file_type :
+            The type of the given file if test_set is a path string.
+
+        Returns
+        -------
+        array-like, shape [n_samples, sequence length (n_steps), n_features],
+            Imputed data.
+        """
+
+        result_dict = self.predict(test_set, file_type=file_type)
+        return result_dict["imputation"]

From dd8c553b3ce0d27423854d60668c5bbb1026c7f3 Mon Sep 17 00:00:00 2001
From: Wenjie Du
Date: Fri, 30 Aug 2024 01:08:42 +0800
Subject: [PATCH 4/6] Update docs (#498)

* docs: update readme files;
---
 README.md    | 83 ++++++++++++++++++++++++++--------------------------
 README_zh.md | 83 ++++++++++++++++++++++++++--------------------------
 2 files changed, 84 insertions(+), 82 deletions(-)

diff --git a/README.md b/README.md
index fb77cedd..c11fe903 100644
--- a/README.md
+++ b/README.md
@@ -113,47 +113,47 @@ The task types are abbreviated as follows:
 **`IMPU`**: Imputation,
 **`FORE`**: Forecasting,
 **`CLAS`**: Classification,
 **`CLUS`**: Clustering,
 **`ANOD`**: Anomaly Detection.
 The paper references and links are all listed at the bottom of this file.
-| **Type** | **Algo** | **IMPU** | **FORE** | **CLAS** | **CLUS** | **ANOD** | **Year - Venue** | -|:--------------|:----------------------------|:--------:|:--------:|:--------:|:--------:|:--------:|:-------------------| -| LLM | Gungnir 🚀 [^36] | ✅ | ✅ | ✅ | ✅ | ✅ | `Later in 2024` | -| Neural Net | ImputeFormer🧑‍🔧[^34] | ✅ | | | | | `2024 - KDD` | -| Neural Net | iTransformer🧑‍🔧[^24] | ✅ | | | | | `2024 - ICLR` | -| Neural Net | SAITS[^1] | ✅ | | | | | `2023 - ESWA` | -| Neural Net | FreTS🧑‍🔧[^23] | ✅ | | | | | `2023 - NeurIPS` | -| Neural Net | Koopa🧑‍🔧[^29] | ✅ | | | | | `2023 - NeurIPS` | -| Neural Net | Crossformer🧑‍🔧[^16] | ✅ | | | | | `2023 - ICLR` | -| Neural Net | TimesNet[^14] | ✅ | | | | | `2023 - ICLR` | -| Neural Net | PatchTST🧑‍🔧[^18] | ✅ | | | | | `2023 - ICLR` | -| Neural Net | ETSformer🧑‍🔧[^19] | ✅ | | | | | `2023 - ICLR` | -| Neural Net | MICN🧑‍🔧[^27] | ✅ | | | | | `2023 - ICLR` | -| Neural Net | DLinear🧑‍🔧[^17] | ✅ | | | | | `2023 - AAAI` | -| Neural Net | TiDE🧑‍🔧[^28] | ✅ | | | | | `2023 - TMLR` | -| Neural Net | SCINet🧑‍🔧[^30] | ✅ | | | | | `2022 - NeurIPS` | -| Neural Net | Nonstationary Tr.🧑‍🔧[^25] | ✅ | | | | | `2022 - NeurIPS` | -| Neural Net | FiLM🧑‍🔧[^22] | ✅ | | | | | `2022 - NeurIPS` | -| Neural Net | RevIN_SCINet🧑‍🔧[^31] | ✅ | | | | | `2022 - ICLR` | -| Neural Net | Pyraformer🧑‍🔧[^26] | ✅ | | | | | `2022 - ICLR` | -| Neural Net | Raindrop[^5] | | | ✅ | | | `2022 - ICLR` | -| Neural Net | FEDformer🧑‍🔧[^20] | ✅ | | | | | `2022 - ICML` | -| Neural Net | Autoformer🧑‍🔧[^15] | ✅ | | | | | `2021 - NeurIPS` | -| Neural Net | CSDI[^12] | ✅ | ✅ | | | | `2021 - NeurIPS` | -| Neural Net | Informer🧑‍🔧[^21] | ✅ | | | | | `2021 - AAAI` | -| Neural Net | US-GAN[^10] | ✅ | | | | | `2021 - AAAI` | -| Neural Net | CRLI[^6] | | | | ✅ | | `2021 - AAAI` | -| Probabilistic | BTTF[^8] | | ✅ | | | | `2021 - TPAMI` | -| Neural Net | StemGNN🧑‍🔧[^33] | ✅ | | | | | `2020 - NeurIPS` | -| Neural Net | Reformer🧑‍🔧[^32] | ✅ | | | | | `2020 - ICLR` | -| Neural Net | GP-VAE[^11] | ✅ | | | | | `2020 - AISTATS` | -| Neural Net | VaDER[^7] | | | | ✅ | | `2019 - GigaSci.` | -| Neural Net | M-RNN[^9] | ✅ | | | | | `2019 - TBME` | -| Neural Net | BRITS[^3] | ✅ | | ✅ | | | `2018 - NeurIPS` | -| Neural Net | GRU-D[^4] | ✅ | | ✅ | | | `2018 - Sci. 
Rep.` | -| Neural Net | TCN🧑‍🔧[^35] | ✅ | | | | | `2018 - arXiv` | -| Neural Net | Transformer🧑‍🔧[^2] | ✅ | | | | | `2017 - NeurIPS` | -| Naive | Lerp | ✅ | | | | | | -| Naive | LOCF/NOCB | ✅ | | | | | | -| Naive | Mean | ✅ | | | | | | -| Naive | Median | ✅ | | | | | | +| **Type** | **Algo** | **IMPU** | **FORE** | **CLAS** | **CLUS** | **ANOD** | **Year - Venue** | +|:--------------|:---------------------------------------------------------------------------------------------------------------------------------|:--------:|:--------:|:--------:|:--------:|:--------:|:-------------------| +| LLM | Time-Series.AI [^36] | ✅ | ✅ | ✅ | ✅ | ✅ | `Later in 2024` | +| Neural Net | ImputeFormer🧑‍🔧[^34] | ✅ | | | | | `2024 - KDD` | +| Neural Net | iTransformer🧑‍🔧[^24] | ✅ | | | | | `2024 - ICLR` | +| Neural Net | SAITS[^1] | ✅ | | | | | `2023 - ESWA` | +| Neural Net | FreTS🧑‍🔧[^23] | ✅ | | | | | `2023 - NeurIPS` | +| Neural Net | Koopa🧑‍🔧[^29] | ✅ | | | | | `2023 - NeurIPS` | +| Neural Net | Crossformer🧑‍🔧[^16] | ✅ | | | | | `2023 - ICLR` | +| Neural Net | TimesNet[^14] | ✅ | | | | | `2023 - ICLR` | +| Neural Net | PatchTST🧑‍🔧[^18] | ✅ | | | | | `2023 - ICLR` | +| Neural Net | ETSformer🧑‍🔧[^19] | ✅ | | | | | `2023 - ICLR` | +| Neural Net | MICN🧑‍🔧[^27] | ✅ | | | | | `2023 - ICLR` | +| Neural Net | DLinear🧑‍🔧[^17] | ✅ | | | | | `2023 - AAAI` | +| Neural Net | TiDE🧑‍🔧[^28] | ✅ | | | | | `2023 - TMLR` | +| Neural Net | SCINet🧑‍🔧[^30] | ✅ | | | | | `2022 - NeurIPS` | +| Neural Net | Nonstationary Tr.🧑‍🔧[^25] | ✅ | | | | | `2022 - NeurIPS` | +| Neural Net | FiLM🧑‍🔧[^22] | ✅ | | | | | `2022 - NeurIPS` | +| Neural Net | RevIN_SCINet🧑‍🔧[^31] | ✅ | | | | | `2022 - ICLR` | +| Neural Net | Pyraformer🧑‍🔧[^26] | ✅ | | | | | `2022 - ICLR` | +| Neural Net | Raindrop[^5] | | | ✅ | | | `2022 - ICLR` | +| Neural Net | FEDformer🧑‍🔧[^20] | ✅ | | | | | `2022 - ICML` | +| Neural Net | Autoformer🧑‍🔧[^15] | ✅ | | | | | `2021 - NeurIPS` | +| Neural Net | CSDI[^12] | ✅ | ✅ | | | | `2021 - NeurIPS` | +| Neural Net | Informer🧑‍🔧[^21] | ✅ | | | | | `2021 - AAAI` | +| Neural Net | US-GAN[^10] | ✅ | | | | | `2021 - AAAI` | +| Neural Net | CRLI[^6] | | | | ✅ | | `2021 - AAAI` | +| Probabilistic | BTTF[^8] | | ✅ | | | | `2021 - TPAMI` | +| Neural Net | StemGNN🧑‍🔧[^33] | ✅ | | | | | `2020 - NeurIPS` | +| Neural Net | Reformer🧑‍🔧[^32] | ✅ | | | | | `2020 - ICLR` | +| Neural Net | GP-VAE[^11] | ✅ | | | | | `2020 - AISTATS` | +| Neural Net | VaDER[^7] | | | | ✅ | | `2019 - GigaSci.` | +| Neural Net | M-RNN[^9] | ✅ | | | | | `2019 - TBME` | +| Neural Net | BRITS[^3] | ✅ | | ✅ | | | `2018 - NeurIPS` | +| Neural Net | GRU-D[^4] | ✅ | | ✅ | | | `2018 - Sci. Rep.` | +| Neural Net | TCN🧑‍🔧[^35] | ✅ | | | | | `2018 - arXiv` | +| Neural Net | Transformer🧑‍🔧[^2] | ✅ | | | | | `2017 - NeurIPS` | +| Naive | Lerp | ✅ | | | | | | +| Naive | LOCF/NOCB | ✅ | | | | | | +| Naive | Mean | ✅ | | | | | | +| Naive | Median | ✅ | | | | | | 💯 Contribute your model right now to increase your research impact! PyPOTS downloads are increasing rapidly (**[300K+ in total and 1K+ daily on PyPI so far](https://www.pepy.tech/projects/pypots)**), and your work will be widely used and cited by the community. @@ -394,3 +394,4 @@ PyPOTS community is open, transparent, and surely friendly. Let's work together [^35]: Bai, S., Kolter, J. Z., & Koltun, V. (2018). [An empirical evaluation of generic convolutional and recurrent networks for sequence modeling](https://arxiv.org/abs/1803.01271). *arXiv 2018*. 
[^36]: Project Gungnir, the world 1st LLM for time-series multitask modeling, will meet you soon. 🚀 Missing values and variable lengths in your datasets? Hard to perform multitask learning with your time series? Not problems no longer. We'll open application for public beta test recently ;-) Follow us, and stay tuned! + Time-Series.AI diff --git a/README_zh.md b/README_zh.md index b04dd28d..e0c6eb08 100644 --- a/README_zh.md +++ b/README_zh.md @@ -99,47 +99,47 @@ PyPOTS当前支持多变量POTS数据的插补,预测,分类,聚类以及 所以这些模型的输入中不能带有缺失值,无法接受POTS数据作为输入,更加不是插补算法。 **为了使上述模型能够适用于POTS数据,我们采用了与[SAITS论文](https://arxiv.org/pdf/2202.08516)[^1]中相同的embedding策略和训练方法(ORT+MIT)对它们进行改进**。 -| **类型** | **算法** | **插补** | **预测** | **分类** | **聚类** | **异常检测** | **年份 - 刊物** | -|:--------------|:----------------------------|:------:|:------:|:------:|:------:|:--------:|:-------------------| -| LLM | Gungnir 🚀 [^36] | ✅ | ✅ | ✅ | ✅ | ✅ | `Later in 2024` | -| Neural Net | ImputeFormer🧑‍🔧[^34] | ✅ | | | | | `2024 - KDD` | -| Neural Net | iTransformer🧑‍🔧[^24] | ✅ | | | | | `2024 - ICLR` | -| Neural Net | SAITS[^1] | ✅ | | | | | `2023 - ESWA` | -| Neural Net | FreTS🧑‍🔧[^23] | ✅ | | | | | `2023 - NeurIPS` | -| Neural Net | Koopa🧑‍🔧[^29] | ✅ | | | | | `2023 - NeurIPS` | -| Neural Net | Crossformer🧑‍🔧[^16] | ✅ | | | | | `2023 - ICLR` | -| Neural Net | TimesNet[^14] | ✅ | | | | | `2023 - ICLR` | -| Neural Net | PatchTST🧑‍🔧[^18] | ✅ | | | | | `2023 - ICLR` | -| Neural Net | ETSformer🧑‍🔧[^19] | ✅ | | | | | `2023 - ICLR` | -| Neural Net | MICN🧑‍🔧[^27] | ✅ | | | | | `2023 - ICLR` | -| Neural Net | DLinear🧑‍🔧[^17] | ✅ | | | | | `2023 - AAAI` | -| Neural Net | TiDE🧑‍🔧[^28] | ✅ | | | | | `2023 - TMLR` | -| Neural Net | SCINet🧑‍🔧[^30] | ✅ | | | | | `2022 - NeurIPS` | -| Neural Net | Nonstationary Tr.🧑‍🔧[^25] | ✅ | | | | | `2022 - NeurIPS` | -| Neural Net | FiLM🧑‍🔧[^22] | ✅ | | | | | `2022 - NeurIPS` | -| Neural Net | RevIN_SCINet🧑‍🔧[^31] | ✅ | | | | | `2022 - ICLR` | -| Neural Net | Pyraformer🧑‍🔧[^26] | ✅ | | | | | `2022 - ICLR` | -| Neural Net | Raindrop[^5] | | | ✅ | | | `2022 - ICLR` | -| Neural Net | FEDformer🧑‍🔧[^20] | ✅ | | | | | `2022 - ICML` | -| Neural Net | Autoformer🧑‍🔧[^15] | ✅ | | | | | `2021 - NeurIPS` | -| Neural Net | CSDI[^12] | ✅ | ✅ | | | | `2021 - NeurIPS` | -| Neural Net | Informer🧑‍🔧[^21] | ✅ | | | | | `2021 - AAAI` | -| Neural Net | US-GAN[^10] | ✅ | | | | | `2021 - AAAI` | -| Neural Net | CRLI[^6] | | | | ✅ | | `2021 - AAAI` | -| Probabilistic | BTTF[^8] | | ✅ | | | | `2021 - TPAMI` | -| Neural Net | StemGNN🧑‍🔧[^33] | ✅ | | | | | `2020 - NeurIPS` | -| Neural Net | Reformer🧑‍🔧[^32] | ✅ | | | | | `2020 - ICLR` | -| Neural Net | GP-VAE[^11] | ✅ | | | | | `2020 - AISTATS` | -| Neural Net | VaDER[^7] | | | | ✅ | | `2019 - GigaSci.` | -| Neural Net | M-RNN[^9] | ✅ | | | | | `2019 - TBME` | -| Neural Net | BRITS[^3] | ✅ | | ✅ | | | `2018 - NeurIPS` | -| Neural Net | GRU-D[^4] | ✅ | | ✅ | | | `2018 - Sci. 
Rep.` | -| Neural Net | TCN🧑‍🔧[^35] | ✅ | | | | | `2018 - arXiv` | -| Neural Net | Transformer🧑‍🔧[^2] | ✅ | | | | | `2017 - NeurIPS` | -| Naive | Lerp | ✅ | | | | | | -| Naive | LOCF/NOCB | ✅ | | | | | | -| Naive | Mean | ✅ | | | | | | -| Naive | Median | ✅ | | | | | | +| **类型** | **算法** | **插补** | **预测** | **分类** | **聚类** | **异常检测** | **年份 - 刊物** | +|:--------------|:--------------------------------------------------------------------------------------------------------------------------------------------------------|:------:|:------:|:------:|:------:|:--------:|:-------------------| +| LLM | Time-Series.AI [^36] | ✅ | ✅ | ✅ | ✅ | ✅ | `Later in 2024` | +| Neural Net | ImputeFormer🧑‍🔧[^34] | ✅ | | | | | `2024 - KDD` | +| Neural Net | iTransformer🧑‍🔧[^24] | ✅ | | | | | `2024 - ICLR` | +| Neural Net | SAITS[^1] | ✅ | | | | | `2023 - ESWA` | +| Neural Net | FreTS🧑‍🔧[^23] | ✅ | | | | | `2023 - NeurIPS` | +| Neural Net | Koopa🧑‍🔧[^29] | ✅ | | | | | `2023 - NeurIPS` | +| Neural Net | Crossformer🧑‍🔧[^16] | ✅ | | | | | `2023 - ICLR` | +| Neural Net | TimesNet[^14] | ✅ | | | | | `2023 - ICLR` | +| Neural Net | PatchTST🧑‍🔧[^18] | ✅ | | | | | `2023 - ICLR` | +| Neural Net | ETSformer🧑‍🔧[^19] | ✅ | | | | | `2023 - ICLR` | +| Neural Net | MICN🧑‍🔧[^27] | ✅ | | | | | `2023 - ICLR` | +| Neural Net | DLinear🧑‍🔧[^17] | ✅ | | | | | `2023 - AAAI` | +| Neural Net | TiDE🧑‍🔧[^28] | ✅ | | | | | `2023 - TMLR` | +| Neural Net | SCINet🧑‍🔧[^30] | ✅ | | | | | `2022 - NeurIPS` | +| Neural Net | Nonstationary Tr.🧑‍🔧[^25] | ✅ | | | | | `2022 - NeurIPS` | +| Neural Net | FiLM🧑‍🔧[^22] | ✅ | | | | | `2022 - NeurIPS` | +| Neural Net | RevIN_SCINet🧑‍🔧[^31] | ✅ | | | | | `2022 - ICLR` | +| Neural Net | Pyraformer🧑‍🔧[^26] | ✅ | | | | | `2022 - ICLR` | +| Neural Net | Raindrop[^5] | | | ✅ | | | `2022 - ICLR` | +| Neural Net | FEDformer🧑‍🔧[^20] | ✅ | | | | | `2022 - ICML` | +| Neural Net | Autoformer🧑‍🔧[^15] | ✅ | | | | | `2021 - NeurIPS` | +| Neural Net | CSDI[^12] | ✅ | ✅ | | | | `2021 - NeurIPS` | +| Neural Net | Informer🧑‍🔧[^21] | ✅ | | | | | `2021 - AAAI` | +| Neural Net | US-GAN[^10] | ✅ | | | | | `2021 - AAAI` | +| Neural Net | CRLI[^6] | | | | ✅ | | `2021 - AAAI` | +| Probabilistic | BTTF[^8] | | ✅ | | | | `2021 - TPAMI` | +| Neural Net | StemGNN🧑‍🔧[^33] | ✅ | | | | | `2020 - NeurIPS` | +| Neural Net | Reformer🧑‍🔧[^32] | ✅ | | | | | `2020 - ICLR` | +| Neural Net | GP-VAE[^11] | ✅ | | | | | `2020 - AISTATS` | +| Neural Net | VaDER[^7] | | | | ✅ | | `2019 - GigaSci.` | +| Neural Net | M-RNN[^9] | ✅ | | | | | `2019 - TBME` | +| Neural Net | BRITS[^3] | ✅ | | ✅ | | | `2018 - NeurIPS` | +| Neural Net | GRU-D[^4] | ✅ | | ✅ | | | `2018 - Sci. Rep.` | +| Neural Net | TCN🧑‍🔧[^35] | ✅ | | | | | `2018 - arXiv` | +| Neural Net | Transformer🧑‍🔧[^2] | ✅ | | | | | `2017 - NeurIPS` | +| Naive | Lerp | ✅ | | | | | | +| Naive | LOCF/NOCB | ✅ | | | | | | +| Naive | Mean | ✅ | | | | | | +| Naive | Median | ✅ | | | | | | 💯 现在贡献你的模型来增加你的研究影响力!PyPOTS的下载量正在迅速增长(**[目前PyPI上总共超过30万次且每日超1000的下载](https://www.pepy.tech/projects/pypots)**), 你的工作将被社区广泛使用和引用。请参阅[贡献指南](https://github.com/WenjieDu/PyPOTS/blob/main/README_zh.md#-%E8%B4%A1%E7%8C%AE%E5%A3%B0%E6%98%8E),了解如何将模型包含在PyPOTS中。 @@ -363,3 +363,4 @@ PyPOTS社区是一个开放、透明、友好的社区,让我们共同努力 [^34]: Nie, T., Qin, G., Mei, Y., & Sun, J. (2024). [ImputeFormer: Low Rankness-Induced Transformers for Generalizable Spatiotemporal Imputation](https://arxiv.org/abs/2312.01728). *KDD 2024*. [^35]: Bai, S., Kolter, J. Z., & Koltun, V. (2018). 
[An empirical evaluation of generic convolutional and recurrent networks for sequence modeling](https://arxiv.org/abs/1803.01271). *arXiv 2018*. [^36]: Gungnir项目,世界上第一个时间序列多任务大模型,将很快与大家见面。🚀 数据集存在缺少值且样本长短不一?多任务建模场景困难?都不再是问题,让我们的大模型来帮你解决。我们将在近期开放公测申请 ;-) 关注我们,敬请期待! + Time-Series.AI From 09751a74f68abb267746ad08769ceb9257d2d910 Mon Sep 17 00:00:00 2001 From: Wenjie Du Date: Tue, 3 Sep 2024 16:59:01 +0800 Subject: [PATCH 5/6] Update the docs for TimeMixer (#500) --- .idea/icon.png | Bin 0 -> 5373 bytes README.md | 87 +++++++++++++------------- README_zh.md | 87 +++++++++++++------------- docs/examples.rst | 16 +++-- docs/index.rst | 6 +- docs/pypots.imputation.rst | 9 +++ pypots/data/load_specific_datasets.py | 4 +- tests/imputation/timemixer.py | 4 +- 8 files changed, 114 insertions(+), 99 deletions(-) create mode 100644 .idea/icon.png diff --git a/.idea/icon.png b/.idea/icon.png new file mode 100644 index 0000000000000000000000000000000000000000..0a58e1cc5446786d166679f273486575e2b556ab GIT binary patch literal 5373 zcmZ`-XEYlC+l>(;LNzumLaf+Z%_4Th-qc=2sa@1eG-{^yC}LBqYIbOi5~ZZome{G; zmQvJe@%26D{m%LRd_SIZ@44sRd!ApWxxpH1_1y7Mk7OnB>+I7^KW$NsQxL@ z(#H({V9y+Bt#^5O`H!BTpYQJOzIyd4H8u6<=%}Wq#@E*u3WaWOZxe|`K0ZEeZS8@9 z0SpG?=H{lUsX0GC-__M+WMuT`&!5iDP9P9iP*5;6HO0WdproX9a&mI}_U+>0Vh98x zCMGsLJ-xrbPa=_=ot>AKmPSTK7#SInNaX$d_c=H?+S=N1I9yCj%(rjf$Yk=w!~_C? z=lphczE!5yrZL| zjg3unb8}f)nVFecR#sM0Qj)s5x~i(GgoK2Ffq}HN^y=!WmzS5TYf2pez(O!WXj@0m zZ7=xS&E4coB_Gp67zA$rU?zcSs7cCtxMVCt@8`?Q;F99XL^etwblBN-`&XxISQJEf z949ey<(j>}+IVOF)TpGwVFa?5o?1d7TmERO#)ejC#w^AJDNB zy1T-W9Cqb@k9fH=HKN^O!!&Ar z^LH`DN~r`@O`I@mfBwB5!Dv84)MlhfolP3wg{?aSInEo%mHv}+8}ZGWQxeA-op@fe z4|Ifd`?t}nz#MD5XZPF?+ZSC0k!{sW{tCY;JrJhVZAY6lROMSi7FCM%LW|zO-B0lQ zEdjt_k4ym)kYQnE@mXPiR>7N6RL-`Z8llqc_q1JFEb(}?X!}7>46#SowVEs0y$y9V z%nbs!W%R$12I@2zA4S}y8G40eD8!I|T_Md0fn%>a%4mV_=__#r>iQXVnj3)7Y8|f! 
zSomQKZ>fADz9KtcT?Sh!Ww6C z;l>8Au;bMIS|W26@g}!($wt8?lbw(hXPzl{j{H|A?)OGkmE*I2->EPsyO*XqOKotk zb1$rCET)itm>g%-_P&U+xx|nTGB_D=R>XULsCgpBtE8Ya|3%N98k4b#>xV7A3s$fs zcQ;ilBCHP(g|ceelyC1Z1yoyf(55DmwtQzEuU|r$M@r0jxHB7c{o*3QO_2AV31`o= z-CCCJn#b&Kp!EjvR~3y5cB7qPk*!$t zj5u?nN^vE$p4elQ>%G~Jibvlm6jrlzkG?Pp(D_MWwP=tvbq=^We0Lh54NQJ+J5`r8 zE_3Jyq4`;IQA#7FcF)z4`5n+v+>$Dls$|%=111|582_w`m>TjtewXWvt;h^VyLHQ7 z^=(Q*SxKuztsMGV3;x2$`P69j@y+9KnaS&n3~m`!q}30Ke}N9E6=GPOtK3gg((=AR zV*HySXoZ%>N0dMFMq(Zd;b$(hrBpaF1zlSHtbUL{t7aZhJ?FTz_wpEmkrzwUd0+Lm zL>P+vSrp;B{WxX5v@E1_q(mOyoNV$apRK7aNcLf7_(`tftzjyu_9`^(O5|8ajpZ4A zqGjvna(t|68@m>tXS5={1CQ$44wLWI)8NDhk6jD10dXI;H>v!S4%%Su?9Sy9M*Z;= zdCN9-9aQ`a*lcOW)~-ir$&i9FG_DM!Zs)!KfTH4sVQyJsSPMUX{JHO-F^}=iY;?kF zXIvfk`XUHoI}k+7TZ?|#&jzM3c`q33^|ztMLp^@@(y_xjk3`fucC8Cgjk@F#PCwYm zPYyQbNHP7CctE*)A4rB3KeFB`x69d!sH=C{W}_;frpzRd-rNdI!SQ>~>rSb|1R5V) zk(+bEQODI@tlOfh#6DG8&F-_DRhkg^it!(NOtC_8ln=JfCw}-AMOIH>W4=gfWpvzrtf*seLGX z4o@liR|Mg4b2b1LPr=LiG2% z7O5RBGsChx>%_1$L*9~Imz()%T^VfOgv<;N!$SC#ar^_y=aWK}^G-2?>fyX4H}ET# zNv-}mk5EbW!_ht7>0q{}Yz4GTwZ~{6FesCn@iw}b39eAf8(v&1q_Et5Pe(TGHsbAz zvewf4hGLe4eUl-2J!w7RrYnnG89d5Kg6~y%CUoFu>7O|C%6nnza}vP-xQW?p($DP# zmhwF`&X+T*6O7JLDv@!6At`glR<2JFvzpB;yM^X*P=kp<rzk<)SF+ zOObP18WbOHVe}!w6bK>NVPdQfaL){B_*;oL`!xF-f4JI^%8VesQ2MsiN&Jq%(gvwZ zbX8kl_?skjdaP8eTH*XGC&V*UXaA?}jqQZEC}$-T`!6#S?pvQKc3F0c)7m`x)6gRy z&!AK;c5;xqI+W;wSntfqRGdzoR~tGw&IyLK3H#8iXpZkeQ!(m4e>Xw@?d0lHN}T;U zbzl*3`oOYcjfNZ1ai!@<)_tpPEaRM0Gsitfigi}p!(-|hNNGDy#sQvx5gOI%ppx9N~WQeKdG={ZL3EC==!7kUh_LrGThlbo#3!>O_o5YNU}~BDE)1Ze24TiHv33Z zsIJ2a(r4%JTjxe^`gYW;^++eSdEQCqwcPxLYLc1FxGU*r*bNcuh-I zz38lks-jJFLnR2Lv^*PEcPm9fM(2iS&Nk_%rQ*LPy(H{lNAeM%SK2#5N!le1?Tye& zQ^@AEdqOr^($4((VQiHN`m4)=m0xTSj@ypOO`>1OHQN1^HA)AG-F=Va`>8_atkUAl zd$GujokO)a27COO^YrVFTxuh6=QkQMnsT^m-pq7M*+4Z7d2PH1dX~{BsV?L?uX3ivV;SI&tQno}DR^6)~*6m$` z=Gq~wM1!Et@pQ*JOlm_f1Z`J<+dz*NsypbS&6t9z)a32`n#Vf0@~T*fH3tMXP71 zzx6otx@S%7rguQS?KEx0mA0_a`dm8TZ0iL$ebBM{)kgl7xYt33!J6O+VX>Gs0yBNj z$8c#2wj!@);M2+0ac?B|jR^J4x+e&GeBj?vb-_($=B``eY>yrts84-5h&H)sk07ee ztC)0|VD?*SKfd*;e)$yl{4$S|Ecs`L0Ef*zi7xS-pkCTL3_$W}ll;!nDM*dypvGU; z{B{p%u5m6S_me^?nWR1w;D=2bsxu=hP;C8>o|}uQ68yP#@lZ*ZAaK_-r`Pl5lH8c}*E}MoA;lnJNR~;r<~2 z=4P{pp)F1Tr8NkAJ8fwPKZv2ovL>e&qvy(u$!_(k<55?hI~-V#E%ez0zF}_Ro;bCpV~p6P_tu4OF2jF)7$WM8B8E zFvjj_;cV?w-_y4w!Gt++w$aq3@u7^L-0%noa@5|N<)fbn@gRQi(B35w35GjL)(Fs5 zVRgK%#IW|P_@4`xi+S{O>mhSU_D97zo0R(Zvx`)WE)Km>4Z75J5D}y{-HkMsFHkf~ z8jCO#ojgphNH>t9LIw>p=PHwd!{&v`?o1wA?@B`$Jp2fyA#}2^b<=1{+0Hf7Xe!y^ zbyLw1DJo*2G8sM`!XG@SG;IDznJhMJ{x7UKTuCTx2SUNHC7&%Fx(XP|9?jehLm7Hw zSF|_K(#+b98)%QR&!5&x0?vEj)H%oAq6cJBnVK}DUmH-U2@6-XDAENE{5X}FnY*iV zLy4Oyu20fP^8E*Oqaytk_@m`@_?xlcO>6XxX_d9!X3^=1-W?qx){N9YO8=TAqnMi& zPg!^c&8OU>c~7~txiL_(24&t6P7(wP;9m~T`iPh2ci+tL+jK>}^8xMW)dsDYd|HthM zfqk9kf&{$QhfKN#H|IYUJ?+fp3)Rtv)!1=7>-Xfa$L?ddEsoxVrx--fxdZYRkauh~Z((3guY`|EX>oq{*8 z3oxY6Maykm&WKyT?-=m+3Hd9EdT5EnXWTNWDw4izI<_c zQq_mo`jJ6EVsUJin&A9e*}yaT$I)*mD&ga}(~u0VYQ4+aB48__&#pE7>06GM;(iVg zS5Iuzy5^FFQgn9^ctyD8W-*Gpz>X8 z-$BXZc(w9cR#qb=DG$CQ97W9%9Pa6rO(SgjGccKiuRWaj1(VvvUWhpAtGR`)ePKpE zDj@O}LHfFAHf#zQsjn=i^a&r|db2&6lFQBS!gBsFs#G(fwT^A;6eNp2uCY zLG({iutl__qMSH z9J1nA*o^BhJ_p7ZzRPU;qnh=!8aDVQd{^U7Oh(3z7S#D&pXO=H#j0h&iyR4!o+wLJ zI`630_p5Y$)?c;chi4=gJbcxabl6?%wBkPhFe{1nhci`X5X0S-JoHogYdN(lQ@^(; z!>@dC0D5?lWsEt6G9dc&)o6=BQTIFEm z=a5ZAiK_px?)^ag3lFrh{HvjPjr!NX`B(L}oP)?1U4gm2vC7F) Time-Series.AI [^36] | ✅ | ✅ | ✅ | ✅ | ✅ | `Later in 2024` | +| Neural Net | TimeMixer[^37] | ✅ | | | | | `2024 - ICLR` | +| Neural Net | iTransformer🧑‍🔧[^24] | ✅ | | | | | `2024 - ICLR` | +| Neural Net | ImputeFormer🧑‍🔧[^34] | ✅ | 
| | | | `2024 - KDD` | +| Neural Net | SAITS[^1] | ✅ | | | | | `2023 - ESWA` | +| Neural Net | FreTS🧑‍🔧[^23] | ✅ | | | | | `2023 - NeurIPS` | +| Neural Net | Koopa🧑‍🔧[^29] | ✅ | | | | | `2023 - NeurIPS` | +| Neural Net | Crossformer🧑‍🔧[^16] | ✅ | | | | | `2023 - ICLR` | +| Neural Net | TimesNet[^14] | ✅ | | | | | `2023 - ICLR` | +| Neural Net | PatchTST🧑‍🔧[^18] | ✅ | | | | | `2023 - ICLR` | +| Neural Net | ETSformer🧑‍🔧[^19] | ✅ | | | | | `2023 - ICLR` | +| Neural Net | MICN🧑‍🔧[^27] | ✅ | | | | | `2023 - ICLR` | +| Neural Net | DLinear🧑‍🔧[^17] | ✅ | | | | | `2023 - AAAI` | +| Neural Net | TiDE🧑‍🔧[^28] | ✅ | | | | | `2023 - TMLR` | +| Neural Net | SCINet🧑‍🔧[^30] | ✅ | | | | | `2022 - NeurIPS` | +| Neural Net | Nonstationary Tr.🧑‍🔧[^25] | ✅ | | | | | `2022 - NeurIPS` | +| Neural Net | FiLM🧑‍🔧[^22] | ✅ | | | | | `2022 - NeurIPS` | +| Neural Net | RevIN_SCINet🧑‍🔧[^31] | ✅ | | | | | `2022 - ICLR` | +| Neural Net | Pyraformer🧑‍🔧[^26] | ✅ | | | | | `2022 - ICLR` | +| Neural Net | Raindrop[^5] | | | ✅ | | | `2022 - ICLR` | +| Neural Net | FEDformer🧑‍🔧[^20] | ✅ | | | | | `2022 - ICML` | +| Neural Net | Autoformer🧑‍🔧[^15] | ✅ | | | | | `2021 - NeurIPS` | +| Neural Net | CSDI[^12] | ✅ | ✅ | | | | `2021 - NeurIPS` | +| Neural Net | Informer🧑‍🔧[^21] | ✅ | | | | | `2021 - AAAI` | +| Neural Net | US-GAN[^10] | ✅ | | | | | `2021 - AAAI` | +| Neural Net | CRLI[^6] | | | | ✅ | | `2021 - AAAI` | +| Probabilistic | BTTF[^8] | | ✅ | | | | `2021 - TPAMI` | +| Neural Net | StemGNN🧑‍🔧[^33] | ✅ | | | | | `2020 - NeurIPS` | +| Neural Net | Reformer🧑‍🔧[^32] | ✅ | | | | | `2020 - ICLR` | +| Neural Net | GP-VAE[^11] | ✅ | | | | | `2020 - AISTATS` | +| Neural Net | VaDER[^7] | | | | ✅ | | `2019 - GigaSci.` | +| Neural Net | M-RNN[^9] | ✅ | | | | | `2019 - TBME` | +| Neural Net | BRITS[^3] | ✅ | | ✅ | | | `2018 - NeurIPS` | +| Neural Net | GRU-D[^4] | ✅ | | ✅ | | | `2018 - Sci. Rep.` | +| Neural Net | TCN🧑‍🔧[^35] | ✅ | | | | | `2018 - arXiv` | +| Neural Net | Transformer🧑‍🔧[^2] | ✅ | | | | | `2017 - NeurIPS` | +| Naive | Lerp | ✅ | | | | | | +| Naive | LOCF/NOCB | ✅ | | | | | | +| Naive | Mean | ✅ | | | | | | +| Naive | Median | ✅ | | | | | | 💯 Contribute your model right now to increase your research impact! PyPOTS downloads are increasing rapidly (**[300K+ in total and 1K+ daily on PyPI so far](https://www.pepy.tech/projects/pypots)**), and your work will be widely used and cited by the community. @@ -320,7 +321,7 @@ By committing your code, you'll [pypots/imputation/template](https://github.com/WenjieDu/PyPOTS/tree/main/pypots/imputation/template)) to quickly start; 2. become one of [PyPOTS contributors](https://github.com/WenjieDu/PyPOTS/graphs/contributors) and be listed as a volunteer developer [on the PyPOTS website](https://pypots.com/about/#volunteer-developers); -3. get mentioned in our [release notes](https://github.com/WenjieDu/PyPOTS/releases); +3. get mentioned in PyPOTS [release notes](https://github.com/WenjieDu/PyPOTS/releases); You can also contribute to PyPOTS by simply staring🌟 this repo to help more people notice it. Your star is your recognition to PyPOTS, and it matters! @@ -394,3 +395,5 @@ PyPOTS community is open, transparent, and surely friendly. Let's work together [^35]: Bai, S., Kolter, J. Z., & Koltun, V. (2018). [An empirical evaluation of generic convolutional and recurrent networks for sequence modeling](https://arxiv.org/abs/1803.01271). *arXiv 2018*. [^36]: Project Gungnir, the world 1st LLM for time-series multitask modeling, will meet you soon. 
🚀 Missing values and variable lengths in your datasets? Hard to perform multitask learning with your time series? Not problems no longer. We'll open application for public beta test recently ;-) Follow us, and stay tuned! + Time-Series.AI +[^37]: Wang, S., Wu, H., Shi, X., Hu, T., Luo, H., Ma, L., ... & ZHOU, J. (2024). [TimeMixer: Decomposable Multiscale Mixing for Time Series Forecasting](https://openreview.net/forum?id=7oLshfEIC2). *ICLR 2024* diff --git a/README_zh.md b/README_zh.md index b04dd28d..d3f760cc 100644 --- a/README_zh.md +++ b/README_zh.md @@ -99,47 +99,48 @@ PyPOTS当前支持多变量POTS数据的插补,预测,分类,聚类以及 所以这些模型的输入中不能带有缺失值,无法接受POTS数据作为输入,更加不是插补算法。 **为了使上述模型能够适用于POTS数据,我们采用了与[SAITS论文](https://arxiv.org/pdf/2202.08516)[^1]中相同的embedding策略和训练方法(ORT+MIT)对它们进行改进**。 -| **类型** | **算法** | **插补** | **预测** | **分类** | **聚类** | **异常检测** | **年份 - 刊物** | -|:--------------|:----------------------------|:------:|:------:|:------:|:------:|:--------:|:-------------------| -| LLM | Gungnir 🚀 [^36] | ✅ | ✅ | ✅ | ✅ | ✅ | `Later in 2024` | -| Neural Net | ImputeFormer🧑‍🔧[^34] | ✅ | | | | | `2024 - KDD` | -| Neural Net | iTransformer🧑‍🔧[^24] | ✅ | | | | | `2024 - ICLR` | -| Neural Net | SAITS[^1] | ✅ | | | | | `2023 - ESWA` | -| Neural Net | FreTS🧑‍🔧[^23] | ✅ | | | | | `2023 - NeurIPS` | -| Neural Net | Koopa🧑‍🔧[^29] | ✅ | | | | | `2023 - NeurIPS` | -| Neural Net | Crossformer🧑‍🔧[^16] | ✅ | | | | | `2023 - ICLR` | -| Neural Net | TimesNet[^14] | ✅ | | | | | `2023 - ICLR` | -| Neural Net | PatchTST🧑‍🔧[^18] | ✅ | | | | | `2023 - ICLR` | -| Neural Net | ETSformer🧑‍🔧[^19] | ✅ | | | | | `2023 - ICLR` | -| Neural Net | MICN🧑‍🔧[^27] | ✅ | | | | | `2023 - ICLR` | -| Neural Net | DLinear🧑‍🔧[^17] | ✅ | | | | | `2023 - AAAI` | -| Neural Net | TiDE🧑‍🔧[^28] | ✅ | | | | | `2023 - TMLR` | -| Neural Net | SCINet🧑‍🔧[^30] | ✅ | | | | | `2022 - NeurIPS` | -| Neural Net | Nonstationary Tr.🧑‍🔧[^25] | ✅ | | | | | `2022 - NeurIPS` | -| Neural Net | FiLM🧑‍🔧[^22] | ✅ | | | | | `2022 - NeurIPS` | -| Neural Net | RevIN_SCINet🧑‍🔧[^31] | ✅ | | | | | `2022 - ICLR` | -| Neural Net | Pyraformer🧑‍🔧[^26] | ✅ | | | | | `2022 - ICLR` | -| Neural Net | Raindrop[^5] | | | ✅ | | | `2022 - ICLR` | -| Neural Net | FEDformer🧑‍🔧[^20] | ✅ | | | | | `2022 - ICML` | -| Neural Net | Autoformer🧑‍🔧[^15] | ✅ | | | | | `2021 - NeurIPS` | -| Neural Net | CSDI[^12] | ✅ | ✅ | | | | `2021 - NeurIPS` | -| Neural Net | Informer🧑‍🔧[^21] | ✅ | | | | | `2021 - AAAI` | -| Neural Net | US-GAN[^10] | ✅ | | | | | `2021 - AAAI` | -| Neural Net | CRLI[^6] | | | | ✅ | | `2021 - AAAI` | -| Probabilistic | BTTF[^8] | | ✅ | | | | `2021 - TPAMI` | -| Neural Net | StemGNN🧑‍🔧[^33] | ✅ | | | | | `2020 - NeurIPS` | -| Neural Net | Reformer🧑‍🔧[^32] | ✅ | | | | | `2020 - ICLR` | -| Neural Net | GP-VAE[^11] | ✅ | | | | | `2020 - AISTATS` | -| Neural Net | VaDER[^7] | | | | ✅ | | `2019 - GigaSci.` | -| Neural Net | M-RNN[^9] | ✅ | | | | | `2019 - TBME` | -| Neural Net | BRITS[^3] | ✅ | | ✅ | | | `2018 - NeurIPS` | -| Neural Net | GRU-D[^4] | ✅ | | ✅ | | | `2018 - Sci. 
Rep.` | -| Neural Net | TCN🧑‍🔧[^35] | ✅ | | | | | `2018 - arXiv` | -| Neural Net | Transformer🧑‍🔧[^2] | ✅ | | | | | `2017 - NeurIPS` | -| Naive | Lerp | ✅ | | | | | | -| Naive | LOCF/NOCB | ✅ | | | | | | -| Naive | Mean | ✅ | | | | | | -| Naive | Median | ✅ | | | | | | +| **类型** | **算法** | **插补** | **预测** | **分类** | **聚类** | **异常检测** | **年份 - 刊物** | +|:--------------|:--------------------------------------------------------------------------------------------------------------------------------------------------------|:------:|:------:|:------:|:------:|:--------:|:-------------------| +| LLM | Time-Series.AI [^36] | ✅ | ✅ | ✅ | ✅ | ✅ | `Later in 2024` | +| Neural Net | TimeMixer[^37] | ✅ | | | | | `2024 - ICLR` | +| Neural Net | iTransformer🧑‍🔧[^24] | ✅ | | | | | `2024 - ICLR` | +| Neural Net | ImputeFormer🧑‍🔧[^34] | ✅ | | | | | `2024 - KDD` | +| Neural Net | SAITS[^1] | ✅ | | | | | `2023 - ESWA` | +| Neural Net | FreTS🧑‍🔧[^23] | ✅ | | | | | `2023 - NeurIPS` | +| Neural Net | Koopa🧑‍🔧[^29] | ✅ | | | | | `2023 - NeurIPS` | +| Neural Net | Crossformer🧑‍🔧[^16] | ✅ | | | | | `2023 - ICLR` | +| Neural Net | TimesNet[^14] | ✅ | | | | | `2023 - ICLR` | +| Neural Net | PatchTST🧑‍🔧[^18] | ✅ | | | | | `2023 - ICLR` | +| Neural Net | ETSformer🧑‍🔧[^19] | ✅ | | | | | `2023 - ICLR` | +| Neural Net | MICN🧑‍🔧[^27] | ✅ | | | | | `2023 - ICLR` | +| Neural Net | DLinear🧑‍🔧[^17] | ✅ | | | | | `2023 - AAAI` | +| Neural Net | TiDE🧑‍🔧[^28] | ✅ | | | | | `2023 - TMLR` | +| Neural Net | SCINet🧑‍🔧[^30] | ✅ | | | | | `2022 - NeurIPS` | +| Neural Net | Nonstationary Tr.🧑‍🔧[^25] | ✅ | | | | | `2022 - NeurIPS` | +| Neural Net | FiLM🧑‍🔧[^22] | ✅ | | | | | `2022 - NeurIPS` | +| Neural Net | RevIN_SCINet🧑‍🔧[^31] | ✅ | | | | | `2022 - ICLR` | +| Neural Net | Pyraformer🧑‍🔧[^26] | ✅ | | | | | `2022 - ICLR` | +| Neural Net | Raindrop[^5] | | | ✅ | | | `2022 - ICLR` | +| Neural Net | FEDformer🧑‍🔧[^20] | ✅ | | | | | `2022 - ICML` | +| Neural Net | Autoformer🧑‍🔧[^15] | ✅ | | | | | `2021 - NeurIPS` | +| Neural Net | CSDI[^12] | ✅ | ✅ | | | | `2021 - NeurIPS` | +| Neural Net | Informer🧑‍🔧[^21] | ✅ | | | | | `2021 - AAAI` | +| Neural Net | US-GAN[^10] | ✅ | | | | | `2021 - AAAI` | +| Neural Net | CRLI[^6] | | | | ✅ | | `2021 - AAAI` | +| Probabilistic | BTTF[^8] | | ✅ | | | | `2021 - TPAMI` | +| Neural Net | StemGNN🧑‍🔧[^33] | ✅ | | | | | `2020 - NeurIPS` | +| Neural Net | Reformer🧑‍🔧[^32] | ✅ | | | | | `2020 - ICLR` | +| Neural Net | GP-VAE[^11] | ✅ | | | | | `2020 - AISTATS` | +| Neural Net | VaDER[^7] | | | | ✅ | | `2019 - GigaSci.` | +| Neural Net | M-RNN[^9] | ✅ | | | | | `2019 - TBME` | +| Neural Net | BRITS[^3] | ✅ | | ✅ | | | `2018 - NeurIPS` | +| Neural Net | GRU-D[^4] | ✅ | | ✅ | | | `2018 - Sci. Rep.` | +| Neural Net | TCN🧑‍🔧[^35] | ✅ | | | | | `2018 - arXiv` | +| Neural Net | Transformer🧑‍🔧[^2] | ✅ | | | | | `2017 - NeurIPS` | +| Naive | Lerp | ✅ | | | | | | +| Naive | LOCF/NOCB | ✅ | | | | | | +| Naive | Mean | ✅ | | | | | | +| Naive | Median | ✅ | | | | | | 💯 现在贡献你的模型来增加你的研究影响力!PyPOTS的下载量正在迅速增长(**[目前PyPI上总共超过30万次且每日超1000的下载](https://www.pepy.tech/projects/pypots)**), 你的工作将被社区广泛使用和引用。请参阅[贡献指南](https://github.com/WenjieDu/PyPOTS/blob/main/README_zh.md#-%E8%B4%A1%E7%8C%AE%E5%A3%B0%E6%98%8E),了解如何将模型包含在PyPOTS中。 @@ -292,7 +293,7 @@ year={2023}, [pypots/imputation/template](https://github.com/WenjieDu/PyPOTS/tree/main/pypots/imputation/template))快速启动你的开发; 2. 成为[PyPOTS贡献者](https://github.com/WenjieDu/PyPOTS/graphs/contributors)之一, 并在[PyPOTS网站](https://pypots.com/about/#volunteer-developers)上被列为志愿开发者; -3. 
+3. 在PyPOTS发布新版本的[更新日志](https://github.com/WenjieDu/PyPOTS/releases)中被提及;
 
 你也可以通过为该项目设置星标🌟,帮助更多人关注它。你的星标🌟既是对PyPOTS的认可,也是对PyPOTS发展所做出的重要贡献!
@@ -363,3 +364,5 @@ PyPOTS社区是一个开放、透明、友好的社区,让我们共同努力
 [^34]: Nie, T., Qin, G., Mei, Y., & Sun, J. (2024). [ImputeFormer: Low Rankness-Induced Transformers for Generalizable Spatiotemporal Imputation](https://arxiv.org/abs/2312.01728). *KDD 2024*.
 [^35]: Bai, S., Kolter, J. Z., & Koltun, V. (2018). [An empirical evaluation of generic convolutional and recurrent networks for sequence modeling](https://arxiv.org/abs/1803.01271). *arXiv 2018*.
 [^36]: Gungnir项目,世界上第一个时间序列多任务大模型,将很快与大家见面。🚀 数据集存在缺少值且样本长短不一?多任务建模场景困难?都不再是问题,让我们的大模型来帮你解决。我们将在近期开放公测申请 ;-) 关注我们,敬请期待!
+ Time-Series.AI
+[^37]: Wang, S., Wu, H., Shi, X., Hu, T., Luo, H., Ma, L., ... & ZHOU, J. (2024). [TimeMixer: Decomposable Multiscale Mixing for Time Series Forecasting](https://openreview.net/forum?id=7oLshfEIC2). *ICLR 2024*
diff --git a/docs/examples.rst b/docs/examples.rst
index d7d6a1e2..5101eba8 100644
--- a/docs/examples.rst
+++ b/docs/examples.rst
@@ -29,15 +29,13 @@ You can also find a simple and quick-start tutorial notebook on Google Colab
 
     # Data preprocessing. Tedious, but PyPOTS can help. 🤓
     data = load_specific_dataset('physionet_2012')  # PyPOTS will automatically download and extract it.
-    X = data['X']
-    num_samples = len(X['RecordID'].unique())
-    X = X.drop(['RecordID', 'Time'], axis = 1)
-    X = StandardScaler().fit_transform(X.to_numpy())
-    X = X.reshape(num_samples, 48, -1)
+    X = data['train_X']
+    num_samples = len(X)
+    X = StandardScaler().fit_transform(X.reshape(-1, X.shape[-1])).reshape(X.shape)
     X_ori = X  # keep X_ori for validation
     X = mcar(X, 0.1)  # randomly hold out 10% observed values as ground truth
     dataset = {"X": X}  # X for model input
-    print(X.shape)  # (11988, 48, 37), 11988 samples, 48 time steps, 37 features
+    print(X.shape)  # (7671, 48, 37), 7671 samples, 48 time steps, 37 features
 
     # initialize the model
     saits = SAITS(
@@ -55,7 +53,7 @@ You can also find a simple and quick-start tutorial notebook on Google Colab
         model_saving_strategy="best",  # only save the model with the best validation performance
     )
 
-    # train the model. Here I use the whole dataset as the training set, because ground truth is not visible to the model.
+    # train the model. Here I train on the train set only and also evaluate on it, because the ground truth is not visible to the model.
     saits.fit(dataset)
     # impute the originally-missing values and artificially-missing values
     imputation = saits.impute(dataset)
@@ -64,6 +62,6 @@ You can also find a simple and quick-start tutorial notebook on Google Colab
     mae = calc_mae(imputation, np.nan_to_num(X_ori), indicating_mask)  # calculate mean absolute error on the ground truth (artificially-missing values)
 
-    # the best model has been already saved, but you can still manually save it with function save_model() as below
-    saits.save_model(saving_dir="examples/saits",file_name="manually_saved_saits_model")
+    # the best model has already been saved, but you can still manually save it with the save() function as below
+    saits.save(saving_path="examples/saits/manually_saved_saits_model")
     # you can load the saved model into a new initialized model
-    saits.load_model("examples/saits/manually_saved_saits_model")
+    saits.load("examples/saits/manually_saved_saits_model.pypots")
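A note on the rewritten preprocessing above: scikit-learn scalers only accept 2-D input, so the one-liner flattens the 3-D array to (num_samples * num_steps, num_features), standardizes each feature, then restores the original shape. A minimal self-contained sketch of the same idiom (shapes are illustrative, borrowed from the printout above):

    import numpy as np
    from sklearn.preprocessing import StandardScaler

    X = np.random.randn(7671, 48, 37)                 # (samples, time steps, features)
    flat = X.reshape(-1, X.shape[-1])                 # (7671 * 48, 37): one row per time step
    X_scaled = StandardScaler().fit_transform(flat).reshape(X.shape)
    assert X_scaled.shape == X.shape                  # round-trips back to the 3-D layout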
diff --git a/docs/index.rst b/docs/index.rst
index d6c76d57..c0204b42 100644
--- a/docs/index.rst
+++ b/docs/index.rst
@@ -133,10 +133,12 @@ The paper references are all listed at the bottom of this readme file.
 +----------------+-----------------------------------------------------------+------+------+------+------+------+-----------------------+
 | Type           | Algorithm                                                 | IMPU | FORE | CLAS | CLUS | ANOD | Year - Venue          |
 +================+===========================================================+======+======+======+======+======+=======================+
-| Neural Net     | ImputeFormer :cite:`nie2024imputeformer`                  | ✅   |      |      |      |      | ``2024 - KDD``        |
+| Neural Net     | TimeMixer :cite:`wang2024timemixer`                       | ✅   |      |      |      |      | ``2024 - ICLR``       |
 +----------------+-----------------------------------------------------------+------+------+------+------+------+-----------------------+
 | Neural Net     | iTransformer🧑‍🔧 :cite:`liu2024itransformer`               | ✅   |      |      |      |      | ``2024 - ICLR``       |
 +----------------+-----------------------------------------------------------+------+------+------+------+------+-----------------------+
+| Neural Net     | ImputeFormer :cite:`nie2024imputeformer`                  | ✅   |      |      |      |      | ``2024 - KDD``        |
++----------------+-----------------------------------------------------------+------+------+------+------+------+-----------------------+
 | Neural Net     | SAITS :cite:`du2023SAITS`                                 | ✅   |      |      |      |      | ``2023 - ESWA``       |
 +----------------+-----------------------------------------------------------+------+------+------+------+------+-----------------------+
 | Neural Net     | FreTS🧑‍🔧 :cite:`yi2023frets`                              | ✅   |      |      |      |      | ``2023 - NeurIPS``    |
@@ -333,7 +335,7 @@ By committing your code, you'll
    `pypots/imputation/template <https://github.com/WenjieDu/PyPOTS/tree/main/pypots/imputation/template>`_) to quickly start;
 2. become one of `PyPOTS contributors <https://github.com/WenjieDu/PyPOTS/graphs/contributors>`_ and be listed as a volunteer developer `on the PyPOTS website <https://pypots.com/about/#volunteer-developers>`_;
-3. get mentioned in our `release notes <https://github.com/WenjieDu/PyPOTS/releases>`_;
+3. get mentioned in PyPOTS `release notes <https://github.com/WenjieDu/PyPOTS/releases>`_;
 
 You can also contribute to PyPOTS by simply starring🌟 this repo to help more people notice it. Your star is your recognition of PyPOTS, and it matters!
diff --git a/docs/pypots.imputation.rst b/docs/pypots.imputation.rst
index d994f96a..8af2d63f 100644
--- a/docs/pypots.imputation.rst
+++ b/docs/pypots.imputation.rst
@@ -19,6 +19,15 @@ pypots.imputation.transformer
    :show-inheritance:
    :inherited-members:
 
+pypots.imputation.timemixer
+------------------------------------
+
+.. automodule:: pypots.imputation.timemixer
+   :members:
+   :undoc-members:
+   :show-inheritance:
+   :inherited-members:
+
 pypots.imputation.imputeformer
 ------------------------------------
 
diff --git a/pypots/data/load_specific_datasets.py b/pypots/data/load_specific_datasets.py
index 69e79615..50c6c297 100644
--- a/pypots/data/load_specific_datasets.py
+++ b/pypots/data/load_specific_datasets.py
@@ -35,7 +35,7 @@ def list_supported_datasets() -> list:
 def load_specific_dataset(dataset_name: str, use_cache: bool = True) -> dict:
     """Load specific datasets supported by PyPOTS.
 
-    Different from tsdb.load_dataset(), which only produces merely raw data,
+    Different from tsdb.load(), which only produces raw data,
     load_specific_dataset here does some preprocessing operations,
     like truncating time series to generate samples with the same length.
 
@@ -45,7 +45,7 @@ def load_specific_dataset(dataset_name: str, use_cache: bool = True) -> dict:
         The name of the dataset to be loaded, which should be supported, i.e. in SUPPORTED_DATASETS.
 
     use_cache :
-        Whether to use cache. This is an argument of tsdb.load_dataset().
+        Whether to use cache. This is an argument of tsdb.load().
 
     Returns
     -------
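To make the docstring's distinction concrete, here is a minimal sketch of the two loaders side by side; the "physionet_2012" name and the 'train_X' key are assumed from the docs example above, not guaranteed for other datasets:

    import tsdb
    from pypots.data import load_specific_dataset

    raw = tsdb.load("physionet_2012")                 # raw data only, no preprocessing
    data = load_specific_dataset("physionet_2012")    # preprocessed into equal-length samples
    X = data["train_X"]                               # assumed key, per the docs example above
    print(X.shape)                                    # e.g. (7671, 48, 37)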
diff --git a/tests/imputation/timemixer.py b/tests/imputation/timemixer.py
index bf5c72cd..a0735663 100644
--- a/tests/imputation/timemixer.py
+++ b/tests/imputation/timemixer.py
@@ -47,8 +47,8 @@ class TestTimeMixer(unittest.TestCase):
             DATA["n_features"],
             n_layers=2,
             top_k=5,
-            d_model=512,
-            d_ffn=512,
+            d_model=32,
+            d_ffn=32,
             dropout=0.1,
             epochs=EPOCHS,
             saving_path=saving_path,

From 34b4f4e8dfe76cab150fabc383ed451d3b2e4b0f Mon Sep 17 00:00:00 2001
From: Wenjie Du
Date: Wed, 4 Sep 2024 12:20:42 +0800
Subject: [PATCH 6/6] test: add ModernTCN tests;

---
 pypots/imputation/__init__.py |   2 +
 tests/imputation/moderntcn.py | 137 ++++++++++++++++++++++++++++++++++
 2 files changed, 139 insertions(+)
 create mode 100644 tests/imputation/moderntcn.py

diff --git a/pypots/imputation/__init__.py b/pypots/imputation/__init__.py
index 4272f36e..8136b0f8 100644
--- a/pypots/imputation/__init__.py
+++ b/pypots/imputation/__init__.py
@@ -37,6 +37,7 @@
 from .stemgnn import StemGNN
 from .imputeformer import ImputeFormer
 from .timemixer import TimeMixer
+from .moderntcn import ModernTCN
 
 # naive imputation methods
 from .locf import LOCF
@@ -77,6 +78,7 @@
     "StemGNN",
     "ImputeFormer",
     "TimeMixer",
+    "ModernTCN",
     # naive imputation methods
     "LOCF",
     "Mean",
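With the two registration lines above in place, the new model is importable from the package root like any other imputer; a trivial smoke check (names taken from the __all__ entries shown in the hunk):

    # both models registered above are now part of the public API
    from pypots.imputation import ModernTCN, TimeMixer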
diff --git a/tests/imputation/moderntcn.py b/tests/imputation/moderntcn.py
new file mode 100644
index 00000000..33b41269
--- /dev/null
+++ b/tests/imputation/moderntcn.py
@@ -0,0 +1,137 @@
+"""
+Test cases for the ModernTCN imputation model.
+"""
+
+# Created by Wenjie Du
+# License: BSD-3-Clause
+
+
+import os.path
+import unittest
+
+import numpy as np
+import pytest
+
+from pypots.imputation import ModernTCN
+from pypots.optim import Adam
+from pypots.utils.logging import logger
+from pypots.utils.metrics import calc_mse
+from tests.global_test_config import (
+    DATA,
+    EPOCHS,
+    DEVICE,
+    TRAIN_SET,
+    VAL_SET,
+    TEST_SET,
+    GENERAL_H5_TRAIN_SET_PATH,
+    GENERAL_H5_VAL_SET_PATH,
+    GENERAL_H5_TEST_SET_PATH,
+    RESULT_SAVING_DIR_FOR_IMPUTATION,
+    check_tb_and_model_checkpoints_existence,
+)
+
+
+class TestModernTCN(unittest.TestCase):
+    logger.info("Running tests for an imputation model ModernTCN...")
+
+    # set the log and model saving path
+    saving_path = os.path.join(RESULT_SAVING_DIR_FOR_IMPUTATION, "ModernTCN")
+    model_save_name = "saved_moderntcn_model.pypots"
+
+    # initialize an Adam optimizer
+    optimizer = Adam(lr=0.001, weight_decay=1e-5)
+
+    # initialize a ModernTCN model
+    moderntcn = ModernTCN(
+        DATA["n_steps"],
+        DATA["n_features"],
+        patch_size=3,
+        patch_stride=2,
+        downsampling_ratio=2,
+        ffn_ratio=1,
+        num_blocks=[1],
+        large_size=[5],
+        small_size=[3],
+        dims=[32],
+        small_kernel_merged=False,
+        backbone_dropout=0.1,
+        head_dropout=0.1,
+        use_multi_scale=False,
+        individual=False,
+        apply_nonstationary_norm=False,
+        epochs=EPOCHS,
+        saving_path=saving_path,
+        optimizer=optimizer,
+        device=DEVICE,
+    )
+
+    @pytest.mark.xdist_group(name="imputation-moderntcn")
+    def test_0_fit(self):
+        self.moderntcn.fit(TRAIN_SET, VAL_SET)
+
+    @pytest.mark.xdist_group(name="imputation-moderntcn")
+    def test_1_impute(self):
+        imputation_results = self.moderntcn.predict(TEST_SET)
+        assert not np.isnan(
+            imputation_results["imputation"]
+        ).any(), "Output still has missing values after running impute()."
+
+        test_MSE = calc_mse(
+            imputation_results["imputation"],
+            DATA["test_X_ori"],
+            DATA["test_X_indicating_mask"],
+        )
+        logger.info(f"ModernTCN test_MSE: {test_MSE}")
+
+    @pytest.mark.xdist_group(name="imputation-moderntcn")
+    def test_2_parameters(self):
+        assert hasattr(self.moderntcn, "model") and self.moderntcn.model is not None
+
+        assert (
+            hasattr(self.moderntcn, "optimizer")
+            and self.moderntcn.optimizer is not None
+        )
+
+        assert hasattr(self.moderntcn, "best_loss")
+        self.assertNotEqual(self.moderntcn.best_loss, float("inf"))
+
+        assert (
+            hasattr(self.moderntcn, "best_model_dict")
+            and self.moderntcn.best_model_dict is not None
+        )
+
+    @pytest.mark.xdist_group(name="imputation-moderntcn")
+    def test_3_saving_path(self):
+        # check whether the root saving dir exists; it should be created by save_log_into_tb_file
+        assert os.path.exists(
+            self.saving_path
+        ), f"file {self.saving_path} does not exist"
+
+        # check if the tensorboard file and model checkpoints exist
+        check_tb_and_model_checkpoints_existence(self.moderntcn)
+
+        # save the trained model into file, and check if the path exists
+        saved_model_path = os.path.join(self.saving_path, self.model_save_name)
+        self.moderntcn.save(saved_model_path)
+
+        # load the saved model back; not strictly necessary here, but worth testing
+        self.moderntcn.load(saved_model_path)
+
+    @pytest.mark.xdist_group(name="imputation-moderntcn")
+    def test_4_lazy_loading(self):
+        self.moderntcn.fit(GENERAL_H5_TRAIN_SET_PATH, GENERAL_H5_VAL_SET_PATH)
+        imputation_results = self.moderntcn.predict(GENERAL_H5_TEST_SET_PATH)
+        assert not np.isnan(
+            imputation_results["imputation"]
+        ).any(), "Output still has missing values after running impute()."
+
+        test_MSE = calc_mse(
+            imputation_results["imputation"],
+            DATA["test_X_ori"],
+            DATA["test_X_indicating_mask"],
+        )
+        logger.info(f"Lazy-loading ModernTCN test_MSE: {test_MSE}")
+
+
+if __name__ == "__main__":
+    unittest.main()
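For readers who want to drive the new model outside the test harness, here is a minimal usage sketch. The hyperparameters mirror the test above; the toy data, shapes, and epoch count are purely illustrative assumptions:

    import numpy as np
    from pypots.imputation import ModernTCN

    # toy POTS data: 64 samples, 48 steps, 7 features, with ~10% values missing
    X = np.random.randn(64, 48, 7)
    X[np.random.rand(*X.shape) < 0.1] = np.nan
    train_set = {"X": X}  # the dict convention used throughout PyPOTS

    model = ModernTCN(
        n_steps=48,
        n_features=7,
        patch_size=3,
        patch_stride=2,
        downsampling_ratio=2,
        ffn_ratio=1,
        num_blocks=[1],
        large_size=[5],
        small_size=[3],
        dims=[32],
        use_multi_scale=False,
        epochs=5,
    )
    model.fit(train_set)
    imputation = model.predict(train_set)["imputation"]  # NaNs filled, same shape as X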