feat: implement TimeMixer as an imputation model;
WenjieDu committed Aug 8, 2024
1 parent 0c68635 commit 449e0fa
Showing 5 changed files with 476 additions and 0 deletions.
2 changes: 2 additions & 0 deletions pypots/imputation/__init__.py
@@ -36,6 +36,7 @@
from .grud import GRUD
from .stemgnn import StemGNN
from .imputeformer import ImputeFormer
from .timemixer import TimeMixer

# naive imputation methods
from .locf import LOCF
@@ -75,6 +76,7 @@
"GRUD",
"StemGNN",
"ImputeFormer",
"TimeMixer",
# naive imputation methods
"LOCF",
"Mean",
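Since this commit exports `TimeMixer` from `pypots.imputation`, a quick usage sketch may be helpful. It is a minimal sketch only, assuming the public wrapper in `model.py` (the fifth changed file, not rendered below) follows the standard PyPOTS imputation API (`fit` on a dict keyed by `"X"`, then `predict`) and exposes hyperparameters mirroring the `_TimeMixer` constructor in `core.py`; the exact keyword names and defaults may differ.

# A minimal usage sketch, NOT taken from this commit: it assumes TimeMixer
# follows the common PyPOTS imputation interface (fit / predict) and that the
# wrapper accepts hyperparameters mirroring _TimeMixer in core.py below.
import numpy as np
from pypots.imputation import TimeMixer

n_samples, n_steps, n_features = 64, 24, 7
X = np.random.randn(n_samples, n_steps, n_features)
X[X < -1.5] = np.nan  # introduce some missing values to be imputed

model = TimeMixer(
    n_steps=n_steps,
    n_features=n_features,
    n_layers=2,
    d_model=32,
    d_ffn=64,
    dropout=0.1,
    top_k=5,
    epochs=10,
)
model.fit(train_set={"X": X})
imputation = model.predict({"X": X})["imputation"]  # (n_samples, n_steps, n_features)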
24 changes: 24 additions & 0 deletions pypots/imputation/timemixer/__init__.py
@@ -0,0 +1,24 @@
"""
The package including the modules of TimeMixer.
Refer to the paper
`Shiyu Wang, Haixu Wu, Xiaoming Shi, Tengge Hu, Huakun Luo, Lintao Ma, James Y. Zhang, and Jun Zhou.
"TimeMixer: Decomposable Multiscale Mixing for Time Series Forecasting".
In ICLR 2024.
<https://openreview.net/pdf?id=7oLshfEIC2>`_
Notes
-----
This implementation is inspired by the official one https://github.com/kwuking/TimeMixer
"""

# Created by Wenjie Du <[email protected]>
# License: BSD-3-Clause


from .model import TimeMixer

__all__ = [
"TimeMixer",
]
83 changes: 83 additions & 0 deletions pypots/imputation/timemixer/core.py
@@ -0,0 +1,83 @@
"""
"""

# Created by Wenjie Du <[email protected]>
# License: BSD-3-Clause

import torch.nn as nn

from ...nn.functional import (
nonstationary_norm,
nonstationary_denorm,
)
from ...nn.modules.timemixer import BackboneTimeMixer
from ...utils.metrics import calc_mse


class _TimeMixer(nn.Module):
def __init__(
self,
n_layers,
n_steps,
n_features,
d_model,
d_ffn,
dropout,
top_k,
channel_independence,
decomp_method,
moving_avg,
downsampling_layers,
downsampling_window,
apply_nonstationary_norm: bool = False,
):
super().__init__()

self.apply_nonstationary_norm = apply_nonstationary_norm

self.model = BackboneTimeMixer(
task_name="imputation",
n_steps=n_steps,
n_features=n_features,
n_pred_steps=None,
n_pred_features=n_features,
n_layers=n_layers,
d_model=d_model,
d_ffn=d_ffn,
dropout=dropout,
channel_independence=channel_independence,
decomp_method=decomp_method,
top_k=top_k,
moving_avg=moving_avg,
downsampling_layers=downsampling_layers,
downsampling_window=downsampling_window,
downsampling_method="avg",
use_future_temporal_feature=False,
)

def forward(self, inputs: dict, training: bool = True) -> dict:
X, missing_mask = inputs["X"], inputs["missing_mask"]

if self.apply_nonstationary_norm:
# Normalization from Non-stationary Transformer
X, means, stdev = nonstationary_norm(X, missing_mask)

        # TimeMixer processing
dec_out = self.model.imputation(X, None)

if self.apply_nonstationary_norm:
# De-Normalization from Non-stationary Transformer
dec_out = nonstationary_denorm(dec_out, means, stdev)

imputed_data = missing_mask * X + (1 - missing_mask) * dec_out
results = {
"imputed_data": imputed_data,
}

if training:
            # `loss` is always the item backpropagated to update the model
loss = calc_mse(dec_out, inputs["X_ori"], inputs["indicating_mask"])
results["loss"] = loss

return results
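Two pieces of the forward pass above may be worth spelling out: the optional non-stationary normalization computes per-series mean/std over observed positions only (via `missing_mask`), and the training loss is a masked MSE that scores `dec_out` against `X_ori` only where `indicating_mask` is 1 (the artificially held-out values). The following is a rough sketch of both ideas, not the actual `pypots.nn.functional` / `pypots.utils.metrics` implementations:

# Rough sketches only; names and shapes assume X of shape (batch, n_steps, n_features).
import torch

def masked_standardize(X: torch.Tensor, missing_mask: torch.Tensor, eps: float = 1e-5):
    # Per-sample, per-feature mean/std over observed time steps only,
    # in the spirit of nonstationary_norm(X, missing_mask) used above.
    n_observed = missing_mask.sum(dim=1, keepdim=True) + eps
    means = (X * missing_mask).sum(dim=1, keepdim=True) / n_observed
    X_centered = (X - means) * missing_mask
    stdev = torch.sqrt((X_centered**2).sum(dim=1, keepdim=True) / n_observed + eps)
    return X_centered / stdev, means, stdev

def masked_mse(predictions: torch.Tensor, targets: torch.Tensor, masks: torch.Tensor) -> torch.Tensor:
    # In the spirit of calc_mse(dec_out, X_ori, indicating_mask): only positions
    # where mask == 1 contribute to the squared error.
    squared_error = (predictions - targets) ** 2 * masks
    return squared_error.sum() / (masks.sum() + 1e-12)  # epsilon guards an all-zero mask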
24 changes: 24 additions & 0 deletions pypots/imputation/timemixer/data.py
@@ -0,0 +1,24 @@
"""
Dataset class for the imputation model TimeMixer.
"""

# Created by Wenjie Du <[email protected]>
# License: BSD-3-Clause

from typing import Union

from ..saits.data import DatasetForSAITS


class DatasetForTimeMixer(DatasetForSAITS):
"""Actually TimeMixer uses the same data strategy as SAITS, needs MIT for training."""

def __init__(
self,
data: Union[dict, str],
return_X_ori: bool,
return_y: bool,
file_type: str = "hdf5",
rate: float = 0.2,
):
super().__init__(data, return_X_ori, return_y, file_type, rate)
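For reference, MIT (masked imputation task) training means the dataset randomly holds out a fraction (`rate`, 0.2 by default here) of the observed values: the corrupted series goes in as `X`/`missing_mask`, while the held-out ground truth supervises the loss through `X_ori`/`indicating_mask`, matching the keys read in `core.py` above. A rough, self-contained sketch of that masking step (not the actual `DatasetForSAITS` code) could look like:

# A rough sketch of MCAR-style masking for MIT training; the dict keys follow
# what _TimeMixer.forward() expects. Not the actual DatasetForSAITS code.
import torch

def mask_for_mit(X_ori: torch.Tensor, rate: float = 0.2) -> dict:
    observed = ~torch.isnan(X_ori)                     # positions with real observations
    hide = observed & (torch.rand_like(X_ori) < rate)  # subset to hold out artificially
    X = X_ori.clone()
    X[hide] = float("nan")                             # corrupted input fed to the model
    return {
        "X": torch.nan_to_num(X),                      # model input, NaNs zero-filled
        "missing_mask": (~torch.isnan(X)).float(),     # 1 = observed in the corrupted input
        "X_ori": torch.nan_to_num(X_ori),              # reconstruction target
        "indicating_mask": hide.float(),               # 1 = artificially masked, scored by the loss
    }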
343 changes: 343 additions & 0 deletions pypots/imputation/timemixer/model.py
Large diffs are not rendered by default.
