diff --git a/README.md b/README.md
index 781404b9..a7c7160d 100644
--- a/README.md
+++ b/README.md
@@ -162,6 +162,7 @@ PyPOTS supports imputation, classification, clustering, and forecasting tasks on
| **Type** | **Abbr.** | **Full name of the algorithm/model/paper** | **Year** |
| Neural Net | SAITS | Self-Attention-based Imputation for Time Series [^1] | 2023 |
| Neural Net | Transformer | Attention is All you Need [^2]; Self-Attention-based Imputation for Time Series [^1]; Note: proposed in [^2], and re-implemented as an imputation model in [^1]. | 2017 |
+| Neural Net | CSDI | Conditional Score-based Diffusion Models for Probabilistic Time Series Imputation [^12] | 2021 |
| Neural Net | US-GAN | Generative Semi-supervised Learning for Multivariate Time Series Imputation [^10] | 2021 |
| Neural Net | GP-VAE | GP-VAE: Deep Probabilistic Time Series Imputation [^11] | 2020 |
| Neural Net | BRITS | Bidirectional Recurrent Imputation for Time Series [^3] | 2018 |
@@ -284,7 +285,9 @@ PyPOTS community is open, transparent, and surely friendly. Let's work together
[^8]: Chen, X., & Sun, L. (2021). [Bayesian Temporal Factorization for Multidimensional Time Series Prediction](https://arxiv.org/abs/1910.06366). *IEEE transactions on pattern analysis and machine intelligence*.
[^9]: Yoon, J., Zame, W. R., & van der Schaar, M. (2019). [Estimating Missing Data in Temporal Data Streams Using Multi-Directional Recurrent Neural Networks](https://ieeexplore.ieee.org/document/8485748). *IEEE Transactions on Biomedical Engineering*.
[^10]: Miao, X., Wu, Y., Wang, J., Gao, Y., Mao, X., & Yin, J. (2021). [Generative Semi-supervised Learning for Multivariate Time Series Imputation](https://ojs.aaai.org/index.php/AAAI/article/view/17086). *AAAI 2021*.
-[^11]: Fortuin, V., Baranchuk, D., Raetsch, G. & Mandt, S.. (2020). [GP-VAE: Deep Probabilistic Time Series Imputation](https://proceedings.mlr.press/v108/fortuin20a.html). *AISTATS 2020*.
+[^11]: Fortuin, V., Baranchuk, D., Raetsch, G. & Mandt, S. (2020). [GP-VAE: Deep Probabilistic Time Series Imputation](https://proceedings.mlr.press/v108/fortuin20a.html). *AISTATS 2020*.
+[^12]: Tashiro, Y., Song, J., Song, Y., & Ermon, S. (2021). [CSDI: Conditional Score-based Diffusion Models for Probabilistic Time Series Imputation](https://proceedings.neurips.cc/paper/2021/hash/cfe8504bda37b575c70ee1a8276f3486-Abstract.html). *NeurIPS 2021*.
+
🏠 Visits
diff --git a/docs/pypots.imputation.rst b/docs/pypots.imputation.rst
index 616fc97f..9d475370 100644
--- a/docs/pypots.imputation.rst
+++ b/docs/pypots.imputation.rst
@@ -19,6 +19,15 @@ pypots.imputation.transformer
:show-inheritance:
:inherited-members:
+pypots.imputation.csdi
+------------------------------
+
+.. automodule:: pypots.imputation.csdi
+ :members:
+ :undoc-members:
+ :show-inheritance:
+ :inherited-members:
+
pypots.imputation.usgan
------------------------------
diff --git a/docs/pypots.modules.rst b/docs/pypots.modules.rst
new file mode 100644
index 00000000..374ad449
--- /dev/null
+++ b/docs/pypots.modules.rst
@@ -0,0 +1,14 @@
+pypots.modules package
+======================
+
+pypots.modules.rnn
+------------------
+
+.. automodule:: pypots.modules.rnn
+ :members:
+
+pypots.modules.self_attention
+-----------------------------
+
+.. automodule:: pypots.modules.self_attention
+ :members:
diff --git a/docs/pypots.rst b/docs/pypots.rst
index 13434e4b..67ae27eb 100644
--- a/docs/pypots.rst
+++ b/docs/pypots.rst
@@ -11,6 +11,7 @@ Subpackages
pypots.classification
pypots.clustering
pypots.forecasting
+ pypots.modules
pypots.optim
pypots.data
pypots.utils
diff --git a/docs/references.bib b/docs/references.bib
index 8aa37c62..687a9885 100644
--- a/docs/references.bib
+++ b/docs/references.bib
@@ -336,20 +336,6 @@ @article{tang2019JointModeling
keywords = {Computer Science - Machine Learning,Statistics - Machine Learning}
}
-@article{tashiro2021CSDI,
-title = {{{CSDI}}: {{Conditional Score-based Diffusion Models}} for {{Probabilistic Time Series Imputation}}},
-author = {Tashiro, Yusuke and Song, Jiaming and Song, Yang and Ermon, Stefano},
-year = {2021},
-month = oct,
-journal = {arXiv:2107.03502 [cs, stat]},
-eprint = {2107.03502},
-eprinttype = {arxiv},
-primaryclass = {cs, stat},
-url = {http://arxiv.org/abs/2107.03502},
-archiveprefix = {arXiv},
-keywords = {Computer Science - Machine Learning,Statistics - Machine Learning}
-}
-
@inproceedings{vaswani2017Transformer,
author = {Vaswani, Ashish and Shazeer, Noam and Parmar, Niki and Uszkoreit, Jakob and Jones, Llion and Gomez, Aidan N and Kaiser, \L ukasz and Polosukhin, Illia},
booktitle = {Advances in Neural Information Processing Systems},
@@ -449,4 +435,13 @@ @article{calinski1974
pages={1--27},
year={1974},
publisher={Taylor \& Francis}
-}
\ No newline at end of file
+}
+
+@inproceedings{tashiro2021csdi,
+title={{CSDI}: Conditional Score-based Diffusion Models for Probabilistic Time Series Imputation},
+author={Yusuke Tashiro and Jiaming Song and Yang Song and Stefano Ermon},
+booktitle={Advances in Neural Information Processing Systems},
+editor={A. Beygelzimer and Y. Dauphin and P. Liang and J. Wortman Vaughan},
+year={2021},
+url={https://openreview.net/forum?id=VzuIzbRDrum}
+}
diff --git a/pypots/classification/grud/modules/__init__.py b/pypots/classification/grud/modules/__init__.py
index 22cb7b77..49e53174 100644
--- a/pypots/classification/grud/modules/__init__.py
+++ b/pypots/classification/grud/modules/__init__.py
@@ -6,7 +6,7 @@
# License: GLP-v3
from .core import _GRUD
-from .submodules import TemporalDecay
+from pypots.modules.rnn import TemporalDecay
__all__ = [
"_GRUD",
diff --git a/pypots/classification/grud/modules/core.py b/pypots/classification/grud/modules/core.py
index 92326719..6e9aed08 100644
--- a/pypots/classification/grud/modules/core.py
+++ b/pypots/classification/grud/modules/core.py
@@ -16,7 +16,7 @@
import torch.nn as nn
import torch.nn.functional as F
-from .submodules import TemporalDecay
+from pypots.modules.rnn import TemporalDecay
class _GRUD(nn.Module):
diff --git a/pypots/imputation/__init__.py b/pypots/imputation/__init__.py
index a6c4dcd8..d39bff1b 100644
--- a/pypots/imputation/__init__.py
+++ b/pypots/imputation/__init__.py
@@ -12,6 +12,7 @@
from .saits import SAITS
from .transformer import Transformer
from .usgan import USGAN
+from .csdi import CSDI
__all__ = [
"SAITS",
@@ -21,4 +22,5 @@
"LOCF",
"GPVAE",
"USGAN",
+ "CSDI",
]
diff --git a/pypots/imputation/brits/modules/core.py b/pypots/imputation/brits/modules/core.py
index e5c29698..facf3129 100644
--- a/pypots/imputation/brits/modules/core.py
+++ b/pypots/imputation/brits/modules/core.py
@@ -20,7 +20,7 @@
import torch.nn as nn
from .submodules import FeatureRegression
-from ....classification.grud.modules import TemporalDecay
+from ....modules.rnn import TemporalDecay
from ....utils.metrics import cal_mae
diff --git a/pypots/imputation/csdi/__init__.py b/pypots/imputation/csdi/__init__.py
new file mode 100644
index 00000000..70381d34
--- /dev/null
+++ b/pypots/imputation/csdi/__init__.py
@@ -0,0 +1,12 @@
+"""
+
+"""
+
+# Created by Wenjie Du
+# License: GPL-v3
+
+from .model import CSDI
+
+__all__ = [
+ "CSDI",
+]
diff --git a/pypots/imputation/csdi/data.py b/pypots/imputation/csdi/data.py
new file mode 100644
index 00000000..309c97a0
--- /dev/null
+++ b/pypots/imputation/csdi/data.py
@@ -0,0 +1,152 @@
+"""
+
+"""
+
+# Created by Wenjie Du
+# License: GPL-v3
+
+from typing import Union, Iterable
+
+import torch
+from pycorruptor import mcar
+
+from ...data.base import BaseDataset
+
+
+class DatasetForCSDI(BaseDataset):
+ """Dataset for CSDI model."""
+
+ def __init__(
+ self,
+ data: Union[dict, str],
+ return_labels: bool = True,
+ file_type: str = "h5py",
+ rate: float = 0.1,
+ ):
+ super().__init__(data, return_labels, file_type)
+        if isinstance(data, str):
+            # lazy loading from a file: these optional fields default to None,
+            # and sample fetching falls back accordingly
+            self.time_points = None
+            self.for_pattern_mask = None
+            self.cut_length = None
+        else:
+            self.time_points = data.get("time_points", None)
+            self.for_pattern_mask = data.get("for_pattern_mask", None)
+            self.cut_length = data.get("cut_length", None)
+        self.rate = rate
+
+ def _fetch_data_from_array(self, idx: int) -> Iterable:
+ """Fetch data according to index.
+
+ Parameters
+ ----------
+ idx : int,
+ The index to fetch the specified sample.
+
+ Returns
+ -------
+        sample : list,
+            A list containing
+
+            index : int tensor,
+                The index of the sample.
+
+            observed_data : tensor,
+                Time-series data with the observed values, used as the ground truth.
+
+            observed_mask : tensor,
+                The mask recording all originally-observed values in observed_data.
+
+            observed_tp : tensor,
+                The time points (timestamps) of the observations.
+
+            gt_mask : tensor,
+                The mask indicating the artificially-masked values, i.e. the imputation targets.
+
+            for_pattern_mask : tensor,
+                The mask providing historical missing patterns for the "mix" target strategy.
+
+            cut_length : int tensor,
+                The cut length of the sample.
+        """
+ X = self.X[idx].to(torch.float32)
+ X_intact, X, missing_mask, indicating_mask = mcar(X, rate=self.rate)
+
+ observed_data = X_intact
+ observed_mask = missing_mask + indicating_mask
+ observed_tp = (
+ torch.arange(0, self.n_steps, dtype=torch.float32)
+ if self.time_points is None
+ else self.time_points[idx].to(torch.float32)
+ )
+ gt_mask = indicating_mask
+ for_pattern_mask = (
+ gt_mask if self.for_pattern_mask is None else self.for_pattern_mask[idx]
+ )
+ cut_length = (
+ torch.zeros(len(observed_data)).long()
+ if self.cut_length is None
+ else self.cut_length[idx]
+ )
+
+ sample = [
+ torch.tensor(idx),
+ observed_data,
+ observed_mask,
+ observed_tp,
+ gt_mask,
+ for_pattern_mask,
+ cut_length,
+ ]
+
+ if self.y is not None and self.return_labels:
+ sample.append(self.y[idx].to(torch.long))
+
+ return sample
+
+ def _fetch_data_from_file(self, idx: int) -> Iterable:
+ """Fetch data with the lazy-loading strategy, i.e. only loading data from the file while requesting for samples.
+ Here the opened file handle doesn't load the entire dataset into RAM but only load the currently accessed slice.
+
+ Parameters
+ ----------
+ idx : int,
+            The index of the sample to be returned.
+
+ Returns
+ -------
+ sample : list,
+ The collated data sample, a list including all necessary sample info.
+ """
+
+ if self.file_handle is None:
+ self.file_handle = self._open_file_handle()
+
+ X = torch.from_numpy(self.file_handle["X"][idx]).to(torch.float32)
+ X_intact, X, missing_mask, indicating_mask = mcar(X, rate=self.rate)
+
+ observed_data = X_intact
+ observed_mask = missing_mask + indicating_mask
+        if "time_points" in self.file_handle.keys():
+            observed_tp = torch.from_numpy(self.file_handle["time_points"][idx]).to(
+                torch.float32
+            )
+        else:
+            observed_tp = torch.arange(0, self.n_steps, dtype=torch.float32)
+ gt_mask = indicating_mask
+ for_pattern_mask = (
+ gt_mask if self.for_pattern_mask is None else self.for_pattern_mask[idx]
+ )
+ cut_length = (
+ torch.zeros(len(observed_data)).long()
+ if self.cut_length is None
+ else self.cut_length[idx]
+ )
+
+ sample = [
+ torch.tensor(idx),
+ observed_data,
+ observed_mask,
+ observed_tp,
+ gt_mask,
+ for_pattern_mask,
+ cut_length,
+ ]
+
+ # if the dataset has labels and is for training, then fetch it from the file
+ if "y" in self.file_handle.keys() and self.return_labels:
+ sample.append(torch.tensor(self.file_handle["y"][idx], dtype=torch.long))
+
+ return sample
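A minimal sketch of what a sample from `DatasetForCSDI` above looks like when built from a toy in-memory dict; the shapes and the 10% masking rate are illustrative assumptions, not values required by the class:

```python
import numpy as np
from pypots.imputation.csdi.data import DatasetForCSDI

n_samples, n_steps, n_features = 4, 24, 5
X = np.random.randn(n_samples, n_steps, n_features)
X[X < -1.5] = np.nan  # simulate originally-missing values

dataset = DatasetForCSDI({"X": X}, return_labels=False, rate=0.1)
(
    index,
    observed_data,   # ground-truth series (X_intact)
    observed_mask,   # 1 where a value was originally observed
    observed_tp,     # 0..n_steps-1 here, since no "time_points" key was given
    gt_mask,         # 1 where a value was artificially masked out
    for_pattern_mask,
    cut_length,
) = dataset[0]
print(observed_data.shape)  # torch.Size([24, 5])
```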
diff --git a/pypots/imputation/csdi/model.py b/pypots/imputation/csdi/model.py
new file mode 100644
index 00000000..3e39e5dd
--- /dev/null
+++ b/pypots/imputation/csdi/model.py
@@ -0,0 +1,313 @@
+"""
+The implementation of CSDI for the partially-observed time-series imputation task.
+
+Refer to the paper Tashiro, Y., Song, J., Song, Y., & Ermon, S. (2021).
+CSDI: Conditional Score-based Diffusion Models for Probabilistic Time Series Imputation. NeurIPS 2021.
+
+Notes
+-----
+Parts of this implementation are adapted from the official implementation https://github.com/ermongroup/CSDI.
+
+"""
+
+# Created by Wenjie Du
+# License: GPL-v3
+
+from typing import Union, Optional
+
+import numpy as np
+import torch
+from torch.utils.data import DataLoader
+
+from .data import DatasetForCSDI
+from .modules import _CSDI
+from ..base import BaseNNImputer
+from ...optim.adam import Adam
+from ...optim.base import Optimizer
+from ...utils.logging import logger
+
+
+class CSDI(BaseNNImputer):
+ """The PyTorch implementation of the CSDI model :cite:`tashiro2021csdi`.
+
+ Parameters
+ ----------
+ n_features :
+ The number of features in the time-series data sample.
+
+ n_layers :
+        The number of layers in the CSDI model, i.e. the number of residual blocks.
+
+ n_heads :
+ The number of heads in the multi-head attention mechanism.
+
+ n_channels :
+ The number of residual channels.
+
+ d_time_embedding :
+ The dimension number of the time (temporal) embedding.
+
+ d_feature_embedding :
+ The dimension number of the feature embedding.
+
+ d_diffusion_embedding :
+ The dimension number of the diffusion embedding.
+
+ is_unconditional :
+ Whether the model is unconditional or conditional.
+
+ target_strategy :
+ The strategy for selecting the target for the diffusion process. It has to be one of ["mix", "random"].
+
+ n_diffusion_steps :
+        The number of diffusion steps T in the original paper.
+
+ schedule:
+ The schedule for other noise levels. It has to be one of ["quad", "linear"].
+
+ beta_start:
+ The minimum noise level.
+
+ beta_end:
+ The maximum noise level.
+
+ batch_size :
+ The batch size for training and evaluating the model.
+
+ epochs :
+ The number of epochs for training the model.
+
+ patience :
+ The patience for the early-stopping mechanism. Given a positive integer, the training process will be
+ stopped when the model does not perform better after that number of epochs.
+        Leaving it as the default None disables the early-stopping mechanism.
+
+ optimizer :
+ The optimizer for model training.
+ If not given, will use a default Adam optimizer.
+
+ num_workers :
+ The number of subprocesses to use for data loading.
+ `0` means data loading will be in the main process, i.e. there won't be subprocesses.
+
+ device :
+ The device for the model to run on. It can be a string, a :class:`torch.device` object, or a list of them.
+ If not given, will try to use CUDA devices first (will use the default CUDA device if there are multiple),
+ then CPUs, considering CUDA and CPU are so far the main devices for people to train ML models.
+        If given a list of devices, e.g. ['cuda:0', 'cuda:1'], or [torch.device('cuda:0'), torch.device('cuda:1')], the
+        model will be trained in parallel on the multiple devices (so far parallel training is only supported on CUDA devices).
+ Other devices like Google TPU and Apple Silicon accelerator MPS may be added in the future.
+
+ saving_path :
+ The path for automatically saving model checkpoints and tensorboard files (i.e. loss values recorded during
+ training into a tensorboard file). Will not save if not given.
+
+ model_saving_strategy :
+ The strategy to save model checkpoints. It has to be one of [None, "best", "better"].
+ No model will be saved when it is set as None.
+ The "best" strategy will only automatically save the best model after the training finished.
+ The "better" strategy will automatically save the model during training whenever the model performs
+ better than in previous epochs.
+
+ References
+ ----------
+    .. [1] `Yusuke Tashiro, Jiaming Song, Yang Song, Stefano Ermon.
+        "CSDI: Conditional Score-based Diffusion Models for Probabilistic Time Series Imputation".
+        NeurIPS 2021.
+        <https://proceedings.neurips.cc/paper/2021/hash/cfe8504bda37b575c70ee1a8276f3486-Abstract.html>`_
+
+ """
+
+ def __init__(
+ self,
+ n_features: int,
+ n_layers: int,
+ n_heads: int,
+ n_channels: int,
+ d_time_embedding: int,
+ d_feature_embedding: int,
+ d_diffusion_embedding: int,
+ is_unconditional: bool = False,
+ target_strategy: str = "random",
+ n_diffusion_steps: int = 50,
+ schedule: str = "quad",
+ beta_start: float = 0.0001,
+ beta_end: float = 0.5,
+ batch_size: int = 32,
+ epochs: int = 100,
+ patience: Optional[int] = None,
+ optimizer: Optional[Optimizer] = Adam(),
+ num_workers: int = 0,
+ device: Optional[Union[str, torch.device, list]] = None,
+ saving_path: Optional[str] = None,
+ model_saving_strategy: Optional[str] = "best",
+ ):
+ super().__init__(
+ batch_size,
+ epochs,
+ patience,
+ num_workers,
+ device,
+ saving_path,
+ model_saving_strategy,
+ )
+ assert target_strategy in ["mix", "random"]
+ assert schedule in ["quad", "linear"]
+
+ # set up the model
+ self.model = _CSDI(
+ n_layers,
+ n_heads,
+ n_channels,
+ n_features,
+ d_time_embedding,
+ d_feature_embedding,
+ d_diffusion_embedding,
+ is_unconditional,
+ target_strategy,
+ n_diffusion_steps,
+ schedule,
+ beta_start,
+ beta_end,
+ )
+ self._print_model_size()
+ self._send_model_to_given_device()
+
+ # set up the optimizer
+ self.optimizer = optimizer
+ self.optimizer.init_optimizer(self.model.parameters())
+
+ def _assemble_input_for_training(self, data: list) -> dict:
+ (
+ indices,
+ observed_data,
+ observed_mask,
+ observed_tp,
+ gt_mask,
+ for_pattern_mask,
+ cut_length,
+ ) = self._send_data_to_given_device(data)
+
+ inputs = {
+ "observed_data": observed_data.permute(0, 2, 1),
+ "observed_mask": observed_mask.permute(0, 2, 1),
+ "observed_tp": observed_tp,
+ "gt_mask": gt_mask.permute(0, 2, 1),
+ "for_pattern_mask": for_pattern_mask,
+ "cut_length": cut_length,
+ }
+ return inputs
+
+ def _assemble_input_for_validating(self, data) -> dict:
+ return self._assemble_input_for_training(data)
+
+ def _assemble_input_for_testing(self, data) -> dict:
+ return self._assemble_input_for_validating(data)
+
+ def fit(
+ self,
+ train_set: Union[dict, str],
+ val_set: Optional[Union[dict, str]] = None,
+ file_type: str = "h5py",
+ n_sampling_times: int = 1,
+ ) -> None:
+ # Step 1: wrap the input data with classes Dataset and DataLoader
+ training_set = DatasetForCSDI(
+ train_set, return_labels=False, file_type=file_type
+ )
+ training_loader = DataLoader(
+ training_set,
+ batch_size=self.batch_size,
+ shuffle=True,
+ num_workers=self.num_workers,
+ )
+ val_loader = None
+ if val_set is not None:
+ val_set = DatasetForCSDI(val_set, return_labels=False, file_type=file_type)
+ val_loader = DataLoader(
+ val_set,
+ batch_size=self.batch_size,
+ shuffle=False,
+ num_workers=self.num_workers,
+ )
+
+ # Step 2: train the model and freeze it
+ self._train_model(training_loader, val_loader)
+ self.model.load_state_dict(self.best_model_dict)
+ self.model.eval() # set the model as eval status to freeze it.
+
+ # Step 3: save the model if necessary
+ self._auto_save_model_if_necessary(training_finished=True)
+
+ def predict(
+ self,
+ test_set: Union[dict, str],
+ file_type: str = "h5py",
+ n_sampling_times: int = 1,
+ ) -> dict:
+ """
+
+ Parameters
+ ----------
+        test_set : dict or str
+            The dataset for model testing, should be a dictionary including the key 'X',
+            or a path string locating a data file.
+            If it is a dict, X should be array-like of shape [n_samples, sequence length (time steps), n_features],
+            which is the time-series data for testing and can contain missing values.
+            If it is a path string, the path should point to a data file, e.g. a h5 file, which contains
+            key-value pairs like a dict, and it has to include the key 'X'.
+
+ file_type : str
+ The type of the given file if test_set is a path string.
+
+ n_sampling_times:
+ The number of sampling times for the model to sample from the diffusion process.
+
+ Returns
+ -------
+ result_dict: dict
+ Prediction results in a Python Dictionary for the given samples.
+ It should be a dictionary including a key named 'imputation'.
+
+ """
+ # Step 1: wrap the input data with classes Dataset and DataLoader
+ self.model.eval() # set the model as eval status to freeze it.
+ test_set = DatasetForCSDI(test_set, return_labels=False, file_type=file_type)
+ test_loader = DataLoader(
+ test_set,
+ batch_size=self.batch_size,
+ shuffle=False,
+ num_workers=self.num_workers,
+ )
+ imputation_collector = []
+
+ # Step 2: process the data with the model
+ with torch.no_grad():
+ for idx, data in enumerate(test_loader):
+ inputs = self._assemble_input_for_testing(data)
+ results = self.model(
+ inputs,
+ training=False,
+ n_sampling_times=n_sampling_times,
+ )
+ imputed_data = results["imputed_data"]
+ imputation_collector.append(imputed_data)
+
+ # Step 3: output collection and return
+ imputation = torch.cat(imputation_collector).cpu().detach().numpy()
+ result_dict = {
+ "imputation": imputation,
+ }
+ return result_dict
+
+ def impute(
+ self,
+ X: Union[dict, str],
+ file_type="h5py",
+ ) -> np.ndarray:
+ logger.warning(
+ "🚨DeprecationWarning: The method impute is deprecated. Please use `predict` instead."
+ )
+ results_dict = self.predict(X, file_type=file_type)
+ return results_dict["imputation"]
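A hedged usage sketch of the `CSDI` class above on random toy data; the hyperparameter values are illustrative, not recommended settings:

```python
import numpy as np
from pypots.imputation import CSDI

n_samples, n_steps, n_features = 64, 24, 5
X = np.random.randn(n_samples, n_steps, n_features)
X[X < -1.5] = np.nan  # simulate missingness

csdi = CSDI(
    n_features=n_features,
    n_layers=2,
    n_heads=2,
    n_channels=16,
    d_time_embedding=32,
    d_feature_embedding=4,
    d_diffusion_embedding=32,
    epochs=2,
)
csdi.fit(train_set={"X": X})
# average over 5 diffusion samples per time series
imputation = csdi.predict(test_set={"X": X}, n_sampling_times=5)["imputation"]
print(imputation.shape)  # (64, 24, 5)
```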
diff --git a/pypots/imputation/csdi/modules/__init__.py b/pypots/imputation/csdi/modules/__init__.py
new file mode 100644
index 00000000..fb164182
--- /dev/null
+++ b/pypots/imputation/csdi/modules/__init__.py
@@ -0,0 +1,12 @@
+"""
+
+"""
+
+# Created by Wenjie Du
+# License: GPL-v3
+
+from .core import _CSDI
+
+__all__ = [
+ "_CSDI",
+]
diff --git a/pypots/imputation/csdi/modules/core.py b/pypots/imputation/csdi/modules/core.py
new file mode 100644
index 00000000..54071943
--- /dev/null
+++ b/pypots/imputation/csdi/modules/core.py
@@ -0,0 +1,263 @@
+# Created by Wenjie Du
+# License: GPL-v3
+
+import numpy as np
+import torch
+import torch.nn as nn
+
+from .submodules import DiffusionModel
+
+
+class _CSDI(nn.Module):
+ def __init__(
+ self,
+ n_layers,
+ n_heads,
+ n_channels,
+ d_target,
+ d_time_embedding,
+ d_feature_embedding,
+ d_diffusion_embedding,
+ is_unconditional,
+ target_strategy,
+ n_diffusion_steps,
+ schedule,
+ beta_start,
+ beta_end,
+ ):
+ super().__init__()
+
+ self.d_target = d_target
+ self.d_time_embedding = d_time_embedding
+ self.d_feature_embedding = d_feature_embedding
+ self.is_unconditional = is_unconditional
+ self.target_strategy = target_strategy
+ self.n_channels = n_channels
+ self.n_diffusion_steps = n_diffusion_steps
+
+ d_side = d_time_embedding + d_feature_embedding
+ if self.is_unconditional:
+ d_input = 1
+ else:
+ d_side += 1 # for conditional mask
+ d_input = 2
+
+ self.embed_layer = nn.Embedding(
+ num_embeddings=self.d_target,
+ embedding_dim=self.d_feature_embedding,
+ )
+
+ self.diff_model = DiffusionModel(
+ n_diffusion_steps,
+ d_diffusion_embedding,
+ d_input,
+ d_side,
+ n_channels,
+ n_heads,
+ n_layers,
+ )
+
+ # parameters for diffusion models
+ if schedule == "quad":
+ self.beta = (
+ np.linspace(beta_start**0.5, beta_end**0.5, self.n_diffusion_steps)
+ ** 2
+ )
+ elif schedule == "linear":
+ self.beta = np.linspace(beta_start, beta_end, self.n_diffusion_steps)
+
+ self.alpha_hat = 1 - self.beta
+ self.alpha = np.cumprod(self.alpha_hat)
+ self.register_buffer(
+ "alpha_torch", torch.tensor(self.alpha).float().unsqueeze(1).unsqueeze(1)
+ )
+
+ def time_embedding(self, pos, d_model=128):
+ pe = torch.zeros(pos.shape[0], pos.shape[1], d_model).to(pos.device)
+ position = pos.unsqueeze(2)
+ div_term = 1 / torch.pow(
+ 10000.0, torch.arange(0, d_model, 2, device=pos.device) / d_model
+ )
+ pe[:, :, 0::2] = torch.sin(position * div_term)
+ pe[:, :, 1::2] = torch.cos(position * div_term)
+ return pe
+
+ def get_randmask(self, observed_mask):
+ rand_for_mask = torch.rand_like(observed_mask) * observed_mask
+ rand_for_mask = rand_for_mask.reshape(len(rand_for_mask), -1)
+ for i in range(len(observed_mask)):
+ sample_ratio = np.random.rand() # missing ratio
+ num_observed = observed_mask[i].sum().item()
+ num_masked = round(num_observed * sample_ratio)
+ rand_for_mask[i][rand_for_mask[i].topk(num_masked).indices] = -1
+ cond_mask = (rand_for_mask > 0).reshape(observed_mask.shape).float()
+ return cond_mask
+
+ def get_hist_mask(self, observed_mask, for_pattern_mask=None):
+ if for_pattern_mask is None:
+ for_pattern_mask = observed_mask
+ if self.target_strategy == "mix":
+ rand_mask = self.get_randmask(observed_mask)
+
+ cond_mask = observed_mask.clone()
+ for i in range(len(cond_mask)):
+ mask_choice = np.random.rand()
+ if self.target_strategy == "mix" and mask_choice > 0.5:
+ cond_mask[i] = rand_mask[i]
+ else: # draw another sample for histmask (i-1 corresponds to another sample)
+ cond_mask[i] = cond_mask[i] * for_pattern_mask[i - 1]
+ return cond_mask
+
+ def get_side_info(self, observed_tp, cond_mask):
+ B, K, L = cond_mask.shape
+ device = observed_tp.device
+ time_embed = self.time_embedding(
+ observed_tp, self.d_time_embedding
+ ) # (B,L,emb)
+ time_embed = time_embed.to(device)
+ time_embed = time_embed.unsqueeze(2).expand(-1, -1, K, -1)
+ feature_embed = self.embed_layer(
+ torch.arange(self.d_target).to(device)
+ ) # (K,emb)
+ feature_embed = feature_embed.unsqueeze(0).unsqueeze(0).expand(B, L, -1, -1)
+
+ side_info = torch.cat([time_embed, feature_embed], dim=-1) # (B,L,K,*)
+ side_info = side_info.permute(0, 3, 2, 1) # (B,*,K,L)
+
+ if not self.is_unconditional:
+ side_mask = cond_mask.unsqueeze(1) # (B,1,K,L)
+ side_info = torch.cat([side_info, side_mask], dim=1)
+
+ return side_info
+
+ def calc_loss_valid(
+ self, observed_data, cond_mask, observed_mask, side_info, is_train
+ ):
+ loss_sum = 0
+ for t in range(self.n_diffusion_steps): # calculate loss for all t
+ loss = self.calc_loss(
+ observed_data, cond_mask, observed_mask, side_info, is_train, set_t=t
+ )
+ loss_sum += loss.detach()
+ return loss_sum / self.n_diffusion_steps
+
+ def calc_loss(
+ self, observed_data, cond_mask, observed_mask, side_info, is_train, set_t=-1
+ ):
+ B, K, L = observed_data.shape
+ device = observed_data.device
+ if is_train != 1: # for validation
+ t = (torch.ones(B) * set_t).long().to(device)
+ else:
+ t = torch.randint(0, self.n_diffusion_steps, [B]).to(device)
+
+ current_alpha = self.alpha_torch[t] # (B,1,1)
+ noise = torch.randn_like(observed_data)
+ noisy_data = (current_alpha**0.5) * observed_data + (
+ 1.0 - current_alpha
+ ) ** 0.5 * noise
+
+ total_input = self.set_input_to_diffmodel(noisy_data, observed_data, cond_mask)
+
+ predicted = self.diff_model(total_input, side_info, t) # (B,K,L)
+
+ target_mask = observed_mask - cond_mask
+ residual = (noise - predicted) * target_mask
+ num_eval = target_mask.sum()
+ loss = (residual**2).sum() / (num_eval if num_eval > 0 else 1)
+ return loss
+
+ def set_input_to_diffmodel(self, noisy_data, observed_data, cond_mask):
+ if self.is_unconditional:
+ total_input = noisy_data.unsqueeze(1) # (B,1,K,L)
+ else:
+ cond_obs = (cond_mask * observed_data).unsqueeze(1)
+ noisy_target = ((1 - cond_mask) * noisy_data).unsqueeze(1)
+ total_input = torch.cat([cond_obs, noisy_target], dim=1) # (B,2,K,L)
+
+ return total_input
+
+ def impute(self, observed_data, cond_mask, side_info, n_sampling_times):
+ B, K, L = observed_data.shape
+ device = observed_data.device
+ imputed_samples = torch.zeros(B, n_sampling_times, K, L).to(device)
+
+ for i in range(n_sampling_times):
+ # generate noisy observation for unconditional model
+ if self.is_unconditional:
+ noisy_obs = observed_data
+ noisy_cond_history = []
+ for t in range(self.n_diffusion_steps):
+ noise = torch.randn_like(noisy_obs)
+ noisy_obs = (self.alpha_hat[t] ** 0.5) * noisy_obs + self.beta[
+ t
+ ] ** 0.5 * noise
+ noisy_cond_history.append(noisy_obs * cond_mask)
+
+ current_sample = torch.randn_like(observed_data)
+
+ for t in range(self.n_diffusion_steps - 1, -1, -1):
+ if self.is_unconditional:
+ diff_input = (
+ cond_mask * noisy_cond_history[t]
+ + (1.0 - cond_mask) * current_sample
+ )
+ diff_input = diff_input.unsqueeze(1) # (B,1,K,L)
+ else:
+ cond_obs = (cond_mask * observed_data).unsqueeze(1)
+ noisy_target = ((1 - cond_mask) * current_sample).unsqueeze(1)
+ diff_input = torch.cat([cond_obs, noisy_target], dim=1) # (B,2,K,L)
+ predicted = self.diff_model(
+ diff_input, side_info, torch.tensor([t]).to(device)
+ )
+
+ coeff1 = 1 / self.alpha_hat[t] ** 0.5
+ coeff2 = (1 - self.alpha_hat[t]) / (1 - self.alpha[t]) ** 0.5
+ current_sample = coeff1 * (current_sample - coeff2 * predicted)
+
+ if t > 0:
+ noise = torch.randn_like(current_sample)
+ sigma = (
+ (1.0 - self.alpha[t - 1]) / (1.0 - self.alpha[t]) * self.beta[t]
+ ) ** 0.5
+ current_sample += sigma * noise
+
+ imputed_samples[:, i] = current_sample.detach()
+ return imputed_samples
+
+ def forward(self, inputs, training=True, n_sampling_times=1):
+ (observed_data, observed_mask, observed_tp, gt_mask, for_pattern_mask,) = (
+ inputs["observed_data"],
+ inputs["observed_mask"],
+ inputs["observed_tp"],
+ inputs["gt_mask"],
+ inputs["for_pattern_mask"],
+ )
+
+ if not training:
+ cond_mask = gt_mask
+ elif self.target_strategy != "random":
+ cond_mask = self.get_hist_mask(
+ observed_mask, for_pattern_mask=for_pattern_mask
+ )
+ else:
+ cond_mask = self.get_randmask(observed_mask)
+
+ side_info = self.get_side_info(observed_tp, cond_mask)
+
+ loss_func = self.calc_loss if training == 1 else self.calc_loss_valid
+
+ # `loss` is always the item for backward propagating to update the model
+ loss = loss_func(observed_data, cond_mask, observed_mask, side_info, training)
+
+ results = {
+ "loss": loss, # will be used for backward propagating to update the model
+ }
+ if not training:
+ samples = self.impute(
+ observed_data, cond_mask, side_info, n_sampling_times
+ ) # (B,nsample,K,L)
+ imputation = samples.mean(dim=1) # (B,K,L)
+ imputed_data = observed_data + imputation * (1 - gt_mask)
+ results["imputed_data"] = imputed_data.permute(0, 2, 1) # (B,L,K)
+ return results
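To make the schedule bookkeeping in `_CSDI.__init__` and the reverse step in `impute` concrete, this NumPy-only sketch recomputes them with the model's default settings:

```python
import numpy as np

n_diffusion_steps, beta_start, beta_end = 50, 0.0001, 0.5

# "quad" schedule: linear in sqrt(beta), then squared, exactly as above
beta = np.linspace(beta_start**0.5, beta_end**0.5, n_diffusion_steps) ** 2
alpha_hat = 1 - beta           # per-step alpha_t
alpha = np.cumprod(alpha_hat)  # cumulative product, i.e. alpha-bar_t

# reverse-step coefficients used in _CSDI.impute for step t
t = n_diffusion_steps - 1
coeff1 = 1 / alpha_hat[t] ** 0.5
coeff2 = (1 - alpha_hat[t]) / (1 - alpha[t]) ** 0.5
print(beta[0], beta[-1])  # 0.0001, 0.5
print(coeff1, coeff2)
```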
diff --git a/pypots/imputation/csdi/modules/submodules.py b/pypots/imputation/csdi/modules/submodules.py
new file mode 100644
index 00000000..e9739bdc
--- /dev/null
+++ b/pypots/imputation/csdi/modules/submodules.py
@@ -0,0 +1,177 @@
+"""
+
+"""
+
+# Created by Wenjie Du
+# License: GPL-v3
+
+import math
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+def get_torch_trans(heads=8, layers=1, channels=64):
+ encoder_layer = nn.TransformerEncoderLayer(
+ d_model=channels, nhead=heads, dim_feedforward=64, activation="gelu"
+ )
+ return nn.TransformerEncoder(encoder_layer, num_layers=layers)
+
+
+def Conv1d_with_init(in_channels, out_channels, kernel_size):
+ layer = nn.Conv1d(in_channels, out_channels, kernel_size)
+ nn.init.kaiming_normal_(layer.weight)
+ return layer
+
+
+class DiffusionEmbedding(nn.Module):
+ def __init__(self, n_diffusion_steps, d_embedding=128, d_projection=None):
+ super().__init__()
+ if d_projection is None:
+ d_projection = d_embedding
+ self.register_buffer(
+ "embedding",
+ self._build_embedding(n_diffusion_steps, d_embedding // 2),
+ persistent=False,
+ )
+ self.projection1 = nn.Linear(d_embedding, d_projection)
+ self.projection2 = nn.Linear(d_projection, d_projection)
+
+ def forward(self, diffusion_step):
+ x = self.embedding[diffusion_step]
+ x = self.projection1(x)
+ x = F.silu(x)
+ x = self.projection2(x)
+ x = F.silu(x)
+ return x
+
+ def _build_embedding(self, n_steps, d_embedding=64):
+ steps = torch.arange(n_steps).unsqueeze(1) # (T,1)
+ frequencies = 10.0 ** (
+ torch.arange(d_embedding) / (d_embedding - 1) * 4.0
+ ).unsqueeze(
+ 0
+ ) # (1,dim)
+ table = steps * frequencies # (T,dim)
+ table = torch.cat([torch.sin(table), torch.cos(table)], dim=1) # (T,dim*2)
+ return table
+
+
+class ResidualBlock(nn.Module):
+ def __init__(self, d_side, n_channels, diffusion_embedding_dim, nheads):
+ super().__init__()
+ self.diffusion_projection = nn.Linear(diffusion_embedding_dim, n_channels)
+ self.cond_projection = Conv1d_with_init(d_side, 2 * n_channels, 1)
+ self.mid_projection = Conv1d_with_init(n_channels, 2 * n_channels, 1)
+ self.output_projection = Conv1d_with_init(n_channels, 2 * n_channels, 1)
+
+ self.time_layer = get_torch_trans(heads=nheads, layers=1, channels=n_channels)
+ self.feature_layer = get_torch_trans(
+ heads=nheads, layers=1, channels=n_channels
+ )
+
+ def forward_time(self, y, base_shape):
+ B, channel, K, L = base_shape
+ if L == 1:
+ return y
+ y = y.reshape(B, channel, K, L).permute(0, 2, 1, 3).reshape(B * K, channel, L)
+ y = self.time_layer(y.permute(2, 0, 1)).permute(1, 2, 0)
+ y = y.reshape(B, K, channel, L).permute(0, 2, 1, 3).reshape(B, channel, K * L)
+ return y
+
+ def forward_feature(self, y, base_shape):
+ B, channel, K, L = base_shape
+ if K == 1:
+ return y
+ y = y.reshape(B, channel, K, L).permute(0, 3, 1, 2).reshape(B * L, channel, K)
+ y = self.feature_layer(y.permute(2, 0, 1)).permute(1, 2, 0)
+ y = y.reshape(B, L, channel, K).permute(0, 2, 3, 1).reshape(B, channel, K * L)
+ return y
+
+ def forward(self, x, cond_info, diffusion_emb):
+ B, channel, K, L = x.shape
+ base_shape = x.shape
+ x = x.reshape(B, channel, K * L)
+
+ diffusion_emb = self.diffusion_projection(diffusion_emb).unsqueeze(
+ -1
+ ) # (B,channel,1)
+ y = x + diffusion_emb
+
+ y = self.forward_time(y, base_shape)
+ y = self.forward_feature(y, base_shape) # (B,channel,K*L)
+ y = self.mid_projection(y) # (B,2*channel,K*L)
+
+ _, cond_dim, _, _ = cond_info.shape
+ cond_info = cond_info.reshape(B, cond_dim, K * L)
+ cond_info = self.cond_projection(cond_info) # (B,2*channel,K*L)
+ y = y + cond_info
+
+        gate, filter_ = torch.chunk(y, 2, dim=1)
+        y = torch.sigmoid(gate) * torch.tanh(filter_)  # (B,channel,K*L)
+ y = self.output_projection(y)
+
+ residual, skip = torch.chunk(y, 2, dim=1)
+ x = x.reshape(base_shape)
+ residual = residual.reshape(base_shape)
+ skip = skip.reshape(base_shape)
+ return (x + residual) / math.sqrt(2.0), skip
+
+
+class DiffusionModel(nn.Module):
+ def __init__(
+ self,
+ n_diffusion_steps,
+ d_diffusion_embedding,
+ d_input,
+ d_side,
+ n_channels,
+ n_heads,
+ n_layers,
+ ):
+ super().__init__()
+ self.diffusion_embedding = DiffusionEmbedding(
+ n_diffusion_steps=n_diffusion_steps,
+ d_embedding=d_diffusion_embedding,
+ )
+ self.input_projection = Conv1d_with_init(d_input, n_channels, 1)
+ self.output_projection1 = Conv1d_with_init(n_channels, n_channels, 1)
+ self.output_projection2 = Conv1d_with_init(n_channels, 1, 1)
+ nn.init.zeros_(self.output_projection2.weight)
+
+ self.residual_layers = nn.ModuleList(
+ [
+ ResidualBlock(
+ d_side=d_side,
+ n_channels=n_channels,
+ diffusion_embedding_dim=d_diffusion_embedding,
+ nheads=n_heads,
+ )
+ for _ in range(n_layers)
+ ]
+ )
+ self.n_channels = n_channels
+
+ def forward(self, x, cond_info, diffusion_step):
+ B, input_dim, K, L = x.shape
+
+ x = x.reshape(B, input_dim, K * L)
+ x = self.input_projection(x)
+ x = F.relu(x)
+ x = x.reshape(B, self.n_channels, K, L)
+
+ diffusion_emb = self.diffusion_embedding(diffusion_step)
+
+ skip = []
+ for layer in self.residual_layers:
+ x, skip_connection = layer(x, cond_info, diffusion_emb)
+ skip.append(skip_connection)
+
+ x = torch.sum(torch.stack(skip), dim=0) / math.sqrt(len(self.residual_layers))
+ x = x.reshape(B, self.n_channels, K * L)
+ x = self.output_projection1(x) # (B,channel,K*L)
+ x = F.relu(x)
+ x = self.output_projection2(x) # (B,1,K*L)
+ x = x.reshape(B, K, L)
+ return x
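A quick shape check for `DiffusionEmbedding` above, assuming the module path introduced in this diff; the sin/cos table has width `d_embedding`, and the two SiLU-activated projections preserve it:

```python
import torch
from pypots.imputation.csdi.modules.submodules import DiffusionEmbedding

emb = DiffusionEmbedding(n_diffusion_steps=50, d_embedding=128)
t = torch.tensor([0, 7, 49])  # a batch of diffusion step indices
out = emb(t)
print(out.shape)  # torch.Size([3, 128])
```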
diff --git a/pypots/classification/grud/modules/submodules.py b/pypots/modules/rnn.py
similarity index 73%
rename from pypots/classification/grud/modules/submodules.py
rename to pypots/modules/rnn.py
index 1c909e28..37011fc3 100644
--- a/pypots/classification/grud/modules/submodules.py
+++ b/pypots/modules/rnn.py
@@ -1,5 +1,5 @@
"""
-
+The implementation of some common-use modules related to RNN.
"""
# Created by Wenjie Du
@@ -15,7 +15,8 @@
class TemporalDecay(nn.Module):
- """The module used to generate the temporal decay factor gamma in the original paper.
+ """The module used to generate the temporal decay factor gamma in the GRUD model.
+    Please refer to the original paper :cite:`che2018GRUD` for more details.
Attributes
----------
@@ -34,6 +35,14 @@ class TemporalDecay(nn.Module):
diag : bool,
whether to product the weight with an identity matrix before forward processing
+
+ References
+ ----------
+ .. [1] `Che, Zhengping, Sanjay Purushotham, Kyunghyun Cho, David Sontag, and Yan Liu.
+ "Recurrent neural networks for multivariate time series with missing values."
+ Scientific reports 8, no. 1 (2018): 6085.
+ `_
+
"""
def __init__(self, input_size: int, output_size: int, diag: bool = False):
@@ -56,16 +65,16 @@ def _reset_parameters(self) -> None:
self.b.data.uniform_(-std_dev, std_dev)
def forward(self, delta: torch.Tensor) -> torch.Tensor:
- """Forward processing of the NN module.
+ """Forward processing of this NN module.
Parameters
----------
- delta : tensor, shape [batch size, sequence length, feature number]
+ delta : tensor, shape [n_samples, n_steps, n_features]
The time gaps.
Returns
-------
- gamma : array-like, same shape with parameter `delta`, values in (0,1]
+ gamma : tensor, of the same shape with parameter `delta`, values in (0,1]
The temporal decay factor.
"""
if self.diag:
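With `TemporalDecay` now shared under `pypots.modules.rnn`, a short sketch of how it is typically called; per the GRU-D paper, gamma = exp(-max(0, W*delta + b)), so larger time gaps yield stronger decay:

```python
import torch
from pypots.modules.rnn import TemporalDecay

# diag=True requires input_size == output_size
temp_decay = TemporalDecay(input_size=5, output_size=5, diag=True)
delta = torch.rand(8, 24, 5) * 10  # time gaps, [n_samples, n_steps, n_features]
gamma = temp_decay(delta)
print(gamma.shape)                 # torch.Size([8, 24, 5])
print(float(gamma.min()) > 0, float(gamma.max()) <= 1)  # values in (0, 1]
```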
diff --git a/pypots/modules/self_attention.py b/pypots/modules/self_attention.py
index 8f42f799..6e5fe235 100644
--- a/pypots/modules/self_attention.py
+++ b/pypots/modules/self_attention.py
@@ -20,20 +20,56 @@
class ScaledDotProductAttention(nn.Module):
- """Scaled dot-product attention"""
+ """Scaled dot-product attention.
+
+ Parameters
+ ----------
+ temperature:
+ The temperature for scaling.
+
+ attn_dropout:
+ The dropout rate for the attention map.
+
+ """
def __init__(self, temperature: float, attn_dropout: float = 0.1):
super().__init__()
+ assert temperature > 0, "temperature should be positive"
+ assert attn_dropout >= 0, "dropout rate should be non-negative"
self.temperature = temperature
- self.dropout = nn.Dropout(attn_dropout)
+ self.dropout = nn.Dropout(attn_dropout) if attn_dropout > 0 else None
def forward(
self,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
- attn_mask: torch.Tensor = None,
+ attn_mask: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
+ """Forward processing of the scaled dot-product attention.
+
+ Parameters
+ ----------
+ q:
+ Query tensor.
+ k:
+ Key tensor.
+ v:
+ Value tensor.
+
+ attn_mask:
+ Masking tensor for the attention map. The shape should be [batch_size, n_heads, n_steps, n_steps].
+            0 in attn_mask means the values at the corresponding positions in the attention map will be masked out.
+
+ Returns
+ -------
+ output:
+ The result of Value multiplied with the scaled dot-product attention map.
+
+ attn:
+ The scaled dot-product attention map.
+
+ """
# q, k, v all have 4 dimensions [batch_size, n_heads, n_steps, d_tensor]
# d_tensor could be d_q, d_k, d_v
@@ -45,7 +81,9 @@ def forward(
attn = attn.masked_fill(attn_mask == 0, -1e9)
# compute attention score [0, 1], then apply dropout
- attn = self.dropout(F.softmax(attn, dim=-1))
+ attn = F.softmax(attn, dim=-1)
+ if self.dropout is not None:
+ attn = self.dropout(attn)
# multiply the score with v
output = torch.matmul(attn, v)
@@ -53,7 +91,29 @@ def forward(
class MultiHeadAttention(nn.Module):
- """original Transformer multi-head attention"""
+ """Transformer multi-head attention module.
+
+ Parameters
+ ----------
+ n_heads:
+ The number of heads in multi-head attention.
+
+ d_model:
+ The dimension of the input tensor.
+
+ d_k:
+ The dimension of the key and query tensor.
+
+ d_v:
+ The dimension of the value tensor.
+
+ dropout:
+ The dropout rate.
+
+ attn_dropout:
+ The dropout rate for the attention map.
+
+ """
def __init__(
self,
@@ -66,7 +126,7 @@ def __init__(
):
super().__init__()
- self.n_head = n_heads
+ self.n_heads = n_heads
self.d_k = d_k
self.d_v = d_v
@@ -87,6 +147,32 @@ def forward(
v: torch.Tensor,
attn_mask: Optional[torch.Tensor],
) -> Tuple[torch.Tensor, torch.Tensor]:
+ """Forward processing of the multi-head attention module.
+
+ Parameters
+ ----------
+ q:
+ Query tensor.
+
+ k:
+ Key tensor.
+
+ v:
+ Value tensor.
+
+ attn_mask:
+ Masking tensor for the attention map. The shape should be [batch_size, n_heads, n_steps, n_steps].
+            0 in attn_mask means the values at the corresponding positions in the attention map will be masked out.
+
+ Returns
+ -------
+ v:
+ The output of the multi-head attention layer.
+
+ attn_weights:
+ The attention map.
+
+ """
# the input q, k, v currently have 3 dimensions [batch_size, n_steps, d_tensor]
# d_tensor could be n_heads*d_k, n_heads*d_v
@@ -95,9 +181,9 @@ def forward(
residual = q
# now separate the last dimension of q, k, v into different heads -> [batch_size, n_steps, n_heads, d_k or d_v]
- q = self.w_qs(q).view(batch_size, n_steps, self.n_head, self.d_k)
- k = self.w_ks(k).view(batch_size, n_steps, self.n_head, self.d_k)
- v = self.w_vs(v).view(batch_size, n_steps, self.n_head, self.d_v)
+ q = self.w_qs(q).view(batch_size, n_steps, self.n_heads, self.d_k)
+ k = self.w_ks(k).view(batch_size, n_steps, self.n_heads, self.d_k)
+ v = self.w_vs(v).view(batch_size, n_steps, self.n_heads, self.d_v)
# transpose for self-attention calculation -> [batch_size, n_steps, d_k or d_v, n_heads]
q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)
@@ -124,6 +210,21 @@ def forward(
class PositionWiseFeedForward(nn.Module):
+ """Position-wise feed forward network (FFN) in Transformer.
+
+ Parameters
+ ----------
+ d_in:
+ The dimension of the input tensor.
+
+ d_hid:
+ The dimension of the hidden layer.
+
+ dropout:
+ The dropout rate.
+
+ """
+
def __init__(self, d_in: int, d_hid: int, dropout: float = 0.1):
super().__init__()
self.linear_1 = nn.Linear(d_in, d_hid)
@@ -132,6 +233,18 @@ def __init__(self, d_in: int, d_hid: int, dropout: float = 0.1):
self.dropout = nn.Dropout(dropout)
def forward(self, x: torch.Tensor) -> torch.Tensor:
+ """Forward processing of the position-wise feed forward network.
+
+ Parameters
+ ----------
+ x:
+ Input tensor.
+
+ Returns
+ -------
+ x:
+ Output tensor.
+ """
# save the original input for the later residual connection
residual = x
# the 1st linear processing and ReLU non-linear projection
@@ -148,11 +261,37 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
class EncoderLayer(nn.Module):
+ """Transformer encoder layer.
+
+ Parameters
+ ----------
+ d_model:
+ The dimension of the input tensor.
+
+ d_inner:
+ The dimension of the hidden layer.
+
+ n_heads:
+ The number of heads in multi-head attention.
+
+ d_k:
+ The dimension of the key and query tensor.
+
+ d_v:
+ The dimension of the value tensor.
+
+ dropout:
+ The dropout rate.
+
+ attn_dropout:
+ The dropout rate for the attention map.
+ """
+
def __init__(
self,
d_model: int,
d_inner: int,
- n_head: int,
+ n_heads: int,
d_k: int,
d_v: int,
dropout: float = 0.1,
@@ -160,7 +299,7 @@ def __init__(
):
super().__init__()
self.slf_attn = MultiHeadAttention(
- n_head, d_model, d_k, d_v, dropout, attn_dropout
+ n_heads, d_model, d_k, d_v, dropout, attn_dropout
)
self.pos_ffn = PositionWiseFeedForward(d_model, d_inner, dropout)
@@ -169,6 +308,25 @@ def forward(
enc_input: torch.Tensor,
src_mask: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
+ """Forward processing of the encoder layer.
+
+ Parameters
+ ----------
+ enc_input:
+ Input tensor.
+
+ src_mask:
+ Masking tensor for the attention map. The shape should be [batch_size, n_heads, n_steps, n_steps].
+
+ Returns
+ -------
+ enc_output:
+ Output tensor.
+
+ attn_weights:
+ The attention map.
+
+ """
enc_output, attn_weights = self.slf_attn(
enc_input,
enc_input,
@@ -180,11 +338,38 @@ def forward(
class DecoderLayer(nn.Module):
+ """Transformer decoder layer.
+
+ Parameters
+ ----------
+ d_model:
+ The dimension of the input tensor.
+
+ d_inner:
+ The dimension of the hidden layer.
+
+ n_heads:
+ The number of heads in multi-head attention.
+
+ d_k:
+ The dimension of the key and query tensor.
+
+ d_v:
+ The dimension of the value tensor.
+
+ dropout:
+ The dropout rate.
+
+ attn_dropout:
+ The dropout rate for the attention map.
+
+ """
+
def __init__(
self,
d_model: int,
d_inner: int,
- n_head: int,
+ n_heads: int,
d_k: int,
d_v: int,
dropout: float = 0.1,
@@ -192,10 +377,10 @@ def __init__(
):
super().__init__()
self.slf_attn = MultiHeadAttention(
- n_head, d_model, d_k, d_v, dropout, attn_dropout
+ n_heads, d_model, d_k, d_v, dropout, attn_dropout
)
self.enc_attn = MultiHeadAttention(
- n_head, d_model, d_k, d_v, dropout, attn_dropout
+ n_heads, d_model, d_k, d_v, dropout, attn_dropout
)
self.pos_ffn = PositionWiseFeedForward(d_model, d_inner, dropout)
@@ -206,6 +391,36 @@ def forward(
slf_attn_mask: Optional[torch.Tensor] = None,
dec_enc_attn_mask: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+ """Forward processing of the decoder layer.
+
+ Parameters
+ ----------
+ dec_input:
+ Input tensor.
+
+ enc_output:
+ Output tensor from the encoder.
+
+ slf_attn_mask:
+ Masking tensor for the self-attention module.
+ The shape should be [batch_size, n_heads, n_steps, n_steps].
+
+ dec_enc_attn_mask:
+ Masking tensor for the encoding attention module.
+ The shape should be [batch_size, n_heads, n_steps, n_steps].
+
+ Returns
+ -------
+ dec_output:
+ Output tensor.
+
+ dec_slf_attn:
+ The self-attention map.
+
+ dec_enc_attn:
+ The encoding attention map.
+
+ """
dec_output, dec_slf_attn = self.slf_attn(
dec_input, dec_input, dec_input, attn_mask=slf_attn_mask
)
@@ -217,6 +432,43 @@ def forward(
class Encoder(nn.Module):
+ """Transformer encoder.
+
+ Parameters
+ ----------
+ n_layers:
+ The number of layers in the encoder.
+
+ n_steps:
+ The number of time steps in the input tensor.
+
+ n_features:
+ The number of features in the input tensor.
+
+ d_model:
+ The dimension of the module manipulation space.
+ The input tensor will be projected to a space with d_model dimensions.
+
+ d_inner:
+ The dimension of the hidden layer in the feed-forward network.
+
+ n_heads:
+ The number of heads in multi-head attention.
+
+ d_k:
+ The dimension of the key and query tensor.
+
+ d_v:
+ The dimension of the value tensor.
+
+ dropout:
+ The dropout rate.
+
+ attn_dropout:
+ The dropout rate for the attention map.
+
+ """
+
def __init__(
self,
n_layers: int,
@@ -256,6 +508,28 @@ def forward(
src_mask: Optional[torch.Tensor] = None,
return_attn_weights: bool = False,
) -> Union[torch.Tensor, Tuple[torch.Tensor, list]]:
+ """Forward processing of the encoder.
+
+ Parameters
+ ----------
+ x:
+ Input tensor.
+
+ src_mask:
+ Masking tensor for the attention map. The shape should be [batch_size, n_heads, n_steps, n_steps].
+
+ return_attn_weights:
+ Whether to return the attention map.
+
+ Returns
+ -------
+ enc_output:
+ Output tensor.
+
+ attn_weights_collector:
+ A list containing the attention map from each encoder layer.
+
+ """
x = self.embedding(x)
enc_output = self.dropout(self.position_enc(x))
attn_weights_collector = []
@@ -271,6 +545,43 @@ def forward(
class Decoder(nn.Module):
+ """Transformer decoder.
+
+ Parameters
+ ----------
+ n_layers:
+ The number of layers in the decoder.
+
+ n_steps:
+ The number of time steps in the input tensor.
+
+ n_features:
+ The number of features in the input tensor.
+
+ d_model:
+ The dimension of the module manipulation space.
+ The input tensor will be projected to a space with d_model dimensions.
+
+ d_inner:
+ The dimension of the hidden layer in the feed-forward network.
+
+ n_heads:
+ The number of heads in multi-head attention.
+
+ d_k:
+ The dimension of the key and query tensor.
+
+ d_v:
+ The dimension of the value tensor.
+
+ dropout:
+ The dropout rate.
+
+ attn_dropout:
+ The dropout rate for the attention map.
+
+ """
+
def __init__(
self,
n_layers: int,
@@ -311,6 +622,37 @@ def forward(
src_mask: Optional[torch.Tensor] = None,
return_attn_weights: bool = False,
) -> Union[torch.Tensor, Tuple[torch.Tensor, list, list]]:
+ """Forward processing of the decoder.
+
+ Parameters
+ ----------
+ trg_seq:
+ Input tensor.
+
+ enc_output:
+ Output tensor from the encoder.
+
+ trg_mask:
+ Masking tensor for the self-attention module.
+
+ src_mask:
+ Masking tensor for the encoding attention module.
+
+ return_attn_weights:
+ Whether to return the attention map.
+
+ Returns
+ -------
+ dec_output:
+ Output tensor.
+
+ dec_slf_attn_collector:
+ A list containing the self-attention map from each decoder layer.
+
+ dec_enc_attn_collector:
+ A list containing the encoding attention map from each decoder layer.
+
+ """
trg_seq = self.embedding(trg_seq)
dec_output = self.dropout(self.position_enc(trg_seq))
@@ -334,6 +676,18 @@ def forward(
class PositionalEncoding(nn.Module):
+ """Positional-encoding module for Transformer.
+
+ Parameters
+ ----------
+ d_hid:
+ The dimension of the hidden layer.
+
+ n_position:
+ The number of positions.
+
+ """
+
def __init__(self, d_hid: int, n_position: int = 200):
super().__init__()
# Not a parameter
@@ -359,4 +713,17 @@ def get_position_angle_vec(position):
return torch.FloatTensor(sinusoid_table).unsqueeze(0)
def forward(self, x: torch.Tensor) -> torch.Tensor:
+ """Forward processing of the positional encoding module.
+
+ Parameters
+ ----------
+ x:
+ Input tensor.
+
+ Returns
+ -------
+ x:
+ Output tensor, the input tensor with the positional encoding added.
+
+ """
return x + self.pos_table[:, : x.size(1)].clone().detach()
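The new docstrings fix the q/k/v shape conventions; a small self-contained sketch (sizes are arbitrary assumptions) exercising `MultiHeadAttention` end to end:

```python
import torch
from pypots.modules.self_attention import MultiHeadAttention

n_heads, d_model, d_k, d_v = 2, 16, 8, 8
attn = MultiHeadAttention(n_heads, d_model, d_k, d_v, 0.1, 0.1)

x = torch.randn(4, 24, d_model)  # [batch_size, n_steps, d_model]
output, attn_weights = attn(x, x, x, attn_mask=None)
print(output.shape)        # torch.Size([4, 24, 16])
print(attn_weights.shape)  # torch.Size([4, 2, 24, 24])
```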
diff --git a/tests/imputation/csdi.py b/tests/imputation/csdi.py
new file mode 100644
index 00000000..f0c5b4b4
--- /dev/null
+++ b/tests/imputation/csdi.py
@@ -0,0 +1,108 @@
+"""
+Test cases for CSDI imputation model.
+"""
+
+# Created by Wenjie Du
+# License: GPL-v3
+
+
+import os.path
+import unittest
+
+import numpy as np
+import pytest
+
+from pypots.imputation import CSDI
+from pypots.optim import Adam
+from pypots.utils.logging import logger
+from pypots.utils.metrics import cal_mae
+from tests.global_test_config import (
+ DATA,
+ DEVICE,
+ check_tb_and_model_checkpoints_existence,
+)
+from tests.imputation.config import (
+ TRAIN_SET,
+ VAL_SET,
+ TEST_SET,
+ RESULT_SAVING_DIR_FOR_IMPUTATION,
+ EPOCHS,
+)
+
+
+class TestCSDI(unittest.TestCase):
+ logger.info("Running tests for an imputation model CSDI...")
+
+ # set the log and model saving path
+ saving_path = os.path.join(RESULT_SAVING_DIR_FOR_IMPUTATION, "CSDI")
+ model_save_name = "saved_csdi_model.pypots"
+
+ # initialize an Adam optimizer
+ optimizer = Adam(lr=0.001, weight_decay=1e-5)
+
+ # initialize a CSDI model
+ csdi = CSDI(
+ n_features=DATA["n_features"],
+ n_layers=1,
+ n_channels=8,
+ d_time_embedding=32,
+ d_feature_embedding=3,
+ d_diffusion_embedding=32,
+ n_heads=1,
+ epochs=EPOCHS,
+ saving_path=saving_path,
+ optimizer=optimizer,
+ device=DEVICE,
+ )
+
+ @pytest.mark.xdist_group(name="imputation-csdi")
+ def test_0_fit(self):
+ self.csdi.fit(TRAIN_SET, VAL_SET)
+
+ @pytest.mark.xdist_group(name="imputation-csdi")
+ def test_1_impute(self):
+ imputed_X = self.csdi.predict(TEST_SET)["imputation"]
+ assert not np.isnan(
+ imputed_X
+        ).any(), "Output still has missing values after running predict()."
+ test_MAE = cal_mae(
+ imputed_X, DATA["test_X_intact"], DATA["test_X_indicating_mask"]
+ )
+ logger.info(f"CSDI test_MAE: {test_MAE}")
+
+ @pytest.mark.xdist_group(name="imputation-csdi")
+ def test_2_parameters(self):
+ assert hasattr(self.csdi, "model") and self.csdi.model is not None
+
+ assert hasattr(self.csdi, "optimizer") and self.csdi.optimizer is not None
+
+ assert hasattr(self.csdi, "best_loss")
+ self.assertNotEqual(self.csdi.best_loss, float("inf"))
+
+ assert (
+ hasattr(self.csdi, "best_model_dict")
+ and self.csdi.best_model_dict is not None
+ )
+
+ @pytest.mark.xdist_group(name="imputation-csdi")
+ def test_3_saving_path(self):
+ # whether the root saving dir exists, which should be created by save_log_into_tb_file
+ assert os.path.exists(
+ self.saving_path
+ ), f"file {self.saving_path} does not exist"
+
+ # check if the tensorboard file and model checkpoints exist
+ check_tb_and_model_checkpoints_existence(self.csdi)
+
+ # save the trained model into file, and check if the path exists
+ self.csdi.save_model(
+ saving_dir=self.saving_path, file_name=self.model_save_name
+ )
+
+        # test loading the saved model; not strictly necessary, but kept for coverage
+ saved_model_path = os.path.join(self.saving_path, self.model_save_name)
+ self.csdi.load_model(saved_model_path)
+
+
+if __name__ == "__main__":
+ unittest.main()