Skip to content

Commit

Permalink
Merge branch 'dev' into (docs)update
Browse files Browse the repository at this point in the history
# Conflicts:
#	README.md
#	README_zh.md
#	docs/index.rst
#	docs/pypots.imputation.rst
  • Loading branch information
WenjieDu committed Sep 4, 2024
2 parents 518d9d3 + 9fb0719 commit 9248576
Show file tree
Hide file tree
Showing 10 changed files with 1,206 additions and 3 deletions.
2 changes: 2 additions & 0 deletions pypots/imputation/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
from .stemgnn import StemGNN
from .imputeformer import ImputeFormer
from .timemixer import TimeMixer
from .moderntcn import ModernTCN

# naive imputation methods
from .locf import LOCF
Expand Down Expand Up @@ -77,6 +78,7 @@
"StemGNN",
"ImputeFormer",
"TimeMixer",
"ModernTCN",
# naive imputation methods
"LOCF",
"Mean",
Expand Down
24 changes: 24 additions & 0 deletions pypots/imputation/moderntcn/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
"""
The package of the partially-observed time-series imputation model ModernTCN.
Refer to the paper
`Donghao Luo, and Xue Wang.
ModernTCN: A Modern Pure Convolution Structure for General Time Series Analysis.
In The Twelfth International Conference on Learning Representations. 2024.
<https://openreview.net/pdf?id=vpJMJerXHU>`_
Notes
-----
This implementation is inspired by the official one https://github.com/luodhhh/ModernTCN
"""

# Created by Wenjie Du <[email protected]>
# License: BSD-3-Clause


from .model import ModernTCN

__all__ = [
"ModernTCN",
]
95 changes: 95 additions & 0 deletions pypots/imputation/moderntcn/core.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
"""
The core wrapper assembles the submodules of ModernTCN imputation model
and takes over the forward progress of the algorithm.
"""

# Created by Wenjie Du <[email protected]>
# License: BSD-3-Clause

import torch.nn as nn

from ...nn.functional import nonstationary_norm, nonstationary_denorm
from ...nn.modules.moderntcn import BackboneModernTCN
from ...nn.modules.patchtst.layers import FlattenHead
from ...utils.metrics import calc_mse


class _ModernTCN(nn.Module):
def __init__(
self,
n_steps,
n_features,
patch_size,
patch_stride,
downsampling_ratio,
ffn_ratio,
num_blocks: list,
large_size: list,
small_size: list,
dims: list,
small_kernel_merged: bool = False,
backbone_dropout: float = 0.1,
head_dropout: float = 0.1,
use_multi_scale: bool = True,
individual: bool = False,
apply_nonstationary_norm: bool = False,
):
super().__init__()

self.apply_nonstationary_norm = apply_nonstationary_norm

self.backbone = BackboneModernTCN(
n_steps,
n_features,
n_features,
patch_size,
patch_stride,
downsampling_ratio,
ffn_ratio,
num_blocks,
large_size,
small_size,
dims,
small_kernel_merged,
backbone_dropout,
head_dropout,
use_multi_scale,
individual,
)

# for the imputation task, the output dim is the same as input dim
self.projection = FlattenHead(
self.backbone.head_nf,
n_steps,
n_features,
head_dropout,
individual,
)

def forward(self, inputs: dict, training: bool = True) -> dict:
X, missing_mask = inputs["X"], inputs["missing_mask"]

if self.apply_nonstationary_norm:
# Normalization from Non-stationary Transformer
X, means, stdev = nonstationary_norm(X, missing_mask)

in_X = X.permute(0, 2, 1)
in_X = self.backbone(in_X)
reconstruction = self.projection(in_X)
reconstruction = reconstruction.permute(0, 2, 1)

if self.apply_nonstationary_norm:
# De-Normalization from Non-stationary Transformer
reconstruction = nonstationary_denorm(reconstruction, means, stdev)

imputed_data = missing_mask * X + (1 - missing_mask) * reconstruction
results = {
"imputed_data": imputed_data,
}

# if in training mode, return results with losses
if training:
loss = calc_mse(reconstruction, inputs["X_ori"], inputs["indicating_mask"])
results["loss"] = loss

return results
24 changes: 24 additions & 0 deletions pypots/imputation/moderntcn/data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
"""
Dataset class for ModernTCN.
"""

# Created by Wenjie Du <[email protected]>
# License: BSD-3-Clause

from typing import Union

from ..saits.data import DatasetForSAITS


class DatasetForModernTCN(DatasetForSAITS):
"""Actually ModernTCN uses the same data strategy as SAITS, needs MIT for training."""

def __init__(
self,
data: Union[dict, str],
return_X_ori: bool,
return_y: bool,
file_type: str = "hdf5",
rate: float = 0.2,
):
super().__init__(data, return_X_ori, return_y, file_type, rate)
Loading

0 comments on commit 9248576

Please sign in to comment.