diff --git a/README.md b/README.md index bc9535b5..781404b9 100644 --- a/README.md +++ b/README.md @@ -191,9 +191,12 @@ and we are pursuing to publish it in prestigious academic venues, e.g. JMLR (tra [Machine Learning Open Source Software](https://www.jmlr.org/mloss/)). If you use PyPOTS in your work, please cite it as below and 🌟star this repository to make others notice this library. 🤗 +There are scientific research projects using PyPOTS and referencing it in their papers. +Here is [an incomplete list of them](https://scholar.google.com/scholar?as_ylo=2022&q=%E2%80%9CPyPOTS%E2%80%9D&hl=en). + ``` bibtex @article{du2023PyPOTS, -title={{PyPOTS: A Python Toolbox for Data Mining on Partially-Observed Time Series}}, +title={{PyPOTS: a Python toolbox for data mining on Partially-Observed Time Series}}, author={Wenjie Du}, year={2023}, eprint={2305.18811}, @@ -204,11 +207,25 @@ doi={10.48550/arXiv.2305.18811}, } ``` +> Wenjie Du. (2023). +> PyPOTS: a Python toolbox for data mining on Partially-Observed Time Series. +> arXiv, abs/2305.18811. https://arxiv.org/abs/2305.18811 + or +``` bibtex +@inproceedings{du2023PyPOTS, +title={{PyPOTS: a Python toolbox for data mining on Partially-Observed Time Series}}, +booktitle={9th SIGKDD workshop on Mining and Learning from Time Series (MiLeTS'23)}, +author={Wenjie Du}, +year={2023}, +url={https://arxiv.org/abs/2305.18811}, +} +``` + > Wenjie Du. (2023). -> PyPOTS: A Python Toolbox for Data Mining on Partially-Observed Time Series. -> arXiv, abs/2305.18811. https://doi.org/10.48550/arXiv.2305.18811 +> PyPOTS: a Python toolbox for data mining on Partially-Observed Time Series. +> In *9th SIGKDD workshop on Mining and Learning from Time Series (MiLeTS'23)*. 
https://arxiv.org/abs/2305.18811 ## ❖ Contribution diff --git a/docs/index.rst b/docs/index.rst index 87d00fb2..6fd84157 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -133,6 +133,8 @@ Task Type Algorithm ============================== ================ ========================================================================= ====== ========= Imputation Neural Network SAITS (Self-Attention-based Imputation for Time Series) 2022 :cite:`du2023SAITS` Imputation Neural Network Transformer 2017 :cite:`vaswani2017Transformer`, :cite:`du2023SAITS` +Imputation Neural Network US-GAN 2021 :cite:`miao2021SSGAN` +Imputation Neural Network GP-VAE 2020 :cite:`fortuin2020GPVAEDeep` Imputation, Classification Neural Network BRITS (Bidirectional Recurrent Imputation for Time Series) 2018 :cite:`cao2018BRITS` Imputation Neural Network M-RNN (Multi-directional Recurrent Neural Network) 2019 :cite:`yoon2019MRNN` Imputation Naive LOCF (Last Observation Carried Forward) / / @@ -159,7 +161,7 @@ please cite it as below and 🌟star `PyPOTS repository dict: + raise NotImplementedError + @abstractmethod def classify( self, diff --git a/pypots/classification/brits/model.py b/pypots/classification/brits/model.py index af1faec8..2453cbe6 100644 --- a/pypots/classification/brits/model.py +++ b/pypots/classification/brits/model.py @@ -18,102 +18,16 @@ import numpy as np import torch -import torch.nn as nn -import torch.nn.functional as F from torch.utils.data import DataLoader from .data import DatasetForBRITS -from .modules import RITS +from .modules import _BRITS from ..base import BaseNNClassifier -from ...imputation.brits.model import ( - _BRITS as imputation_BRITS, -) from ...optim.adam import Adam from ...optim.base import Optimizer from ...utils.logging import logger -class _BRITS(imputation_BRITS, nn.Module): - def __init__( - self, - n_steps: int, - n_features: int, - rnn_hidden_size: int, - n_classes: int, - classification_weight: float, - reconstruction_weight: float, - device: 
Union[str, torch.device], - ): - super().__init__(n_steps, n_features, rnn_hidden_size, device) - self.n_steps = n_steps - self.n_features = n_features - self.rnn_hidden_size = rnn_hidden_size - self.n_classes = n_classes - - # create models - self.rits_f = RITS(n_steps, n_features, rnn_hidden_size, n_classes, device) - self.rits_b = RITS(n_steps, n_features, rnn_hidden_size, n_classes, device) - self.classification_weight = classification_weight - self.reconstruction_weight = reconstruction_weight - - def impute(self, inputs: dict) -> torch.Tensor: - return super().impute(inputs) - - def forward(self, inputs: dict, training: bool = True) -> dict: - """Forward processing of BRITS. - - Parameters - ---------- - inputs : - The input data. - - training : - Whether in training mode. - - Returns - ------- - dict, A dictionary includes all results. - """ - ret_f = self.rits_f(inputs, "forward") - ret_b = self._reverse(self.rits_b(inputs, "backward")) - - classification_pred = (ret_f["prediction"] + ret_b["prediction"]) / 2 - if not training: - # if not in training mode, return the classification result only - return {"classification_pred": classification_pred} - - ret_f["classification_loss"] = F.nll_loss( - torch.log(ret_f["prediction"]), inputs["label"] - ) - ret_b["classification_loss"] = F.nll_loss( - torch.log(ret_b["prediction"]), inputs["label"] - ) - consistency_loss = self._get_consistency_loss( - ret_f["imputed_data"], ret_b["imputed_data"] - ) - classification_loss = ( - ret_f["classification_loss"] + ret_b["classification_loss"] - ) / 2 - reconstruction_loss = ( - ret_f["reconstruction_loss"] + ret_b["reconstruction_loss"] - ) / 2 - - loss = ( - consistency_loss - + reconstruction_loss * self.reconstruction_weight - + classification_loss * self.classification_weight - ) - - results = { - "classification_pred": classification_pred, - "consistency_loss": consistency_loss, - "classification_loss": classification_loss, - "reconstruction_loss": 
reconstruction_loss, - "loss": loss, - } - return results - - class BRITS(BaseNNClassifier): """The PyTorch implementation of the BRITS model :cite:`cao2018BRITS`. @@ -362,7 +276,7 @@ def classify( file_type: str = "h5py", ) -> np.ndarray: logger.warning( - "🚨DeprecationWarning: The method impute is deprecated. Please use `predict` instead." + "🚨DeprecationWarning: The method classify is deprecated. Please use `predict` instead." ) result_dict = self.predict(X, file_type=file_type) return result_dict["classification"] diff --git a/pypots/classification/brits/modules.py b/pypots/classification/brits/modules.py deleted file mode 100644 index 525c58e5..00000000 --- a/pypots/classification/brits/modules.py +++ /dev/null @@ -1,44 +0,0 @@ -""" -The implementation of the modules for BRITS. - -Refer to the paper "Cao, W., Wang, D., Li, J., Zhou, H., Li, L., & Li, Y. (2018). -BRITS: Bidirectional Recurrent Imputation for Time Series. NeurIPS 2018." - -Notes ------ -Partial implementation uses code from https://github.com/caow13/BRITS. The bugs in the original implementation -are fixed here. 
- -""" - -# Created by Wenjie Du -# License: GPL-v3 - -from typing import Union - -import torch -import torch.nn as nn - -from ...imputation.brits.model import ( - RITS as imputation_RITS, -) - - -class RITS(imputation_RITS): - def __init__( - self, - n_steps: int, - n_features: int, - rnn_hidden_size: int, - n_classes: int, - device: Union[str, torch.device], - ): - super().__init__(n_steps, n_features, rnn_hidden_size, device) - self.dropout = nn.Dropout(p=0.25) - self.classifier = nn.Linear(self.rnn_hidden_size, n_classes) - - def forward(self, inputs: dict, direction: str = "forward") -> dict: - ret_dict = super().forward(inputs, direction) - logits = self.classifier(ret_dict["final_hidden_state"]) - ret_dict["prediction"] = torch.softmax(logits, dim=1) - return ret_dict diff --git a/pypots/classification/brits/modules/__init__.py b/pypots/classification/brits/modules/__init__.py new file mode 100644 index 00000000..14bf83db --- /dev/null +++ b/pypots/classification/brits/modules/__init__.py @@ -0,0 +1,13 @@ +""" + +""" + +# Created by Wenjie Du +# License: GPL-v3 + +from .core import _BRITS + + +__all__ = [ + "_BRITS", +] diff --git a/pypots/classification/brits/modules/core.py b/pypots/classification/brits/modules/core.py new file mode 100644 index 00000000..6b6d7826 --- /dev/null +++ b/pypots/classification/brits/modules/core.py @@ -0,0 +1,125 @@ +""" +The implementation of BRITS for the partially-observed time-series classification task. + +Refer to the paper "Cao, W., Wang, D., Li, J., Zhou, H., Li, L., & Li, Y. (2018). +BRITS: Bidirectional Recurrent Imputation for Time Series. NeurIPS 2018." + +Notes +----- +Partial implementation uses code from https://github.com/caow13/BRITS. The bugs in the original implementation +are fixed here. 
+ +""" + +# Created by Wenjie Du +# License: GPL-v3 + +from typing import Union + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from ....imputation.brits.modules.core import RITS as imputation_RITS +from ....imputation.brits.modules.core import _BRITS as imputation_BRITS + + +class RITS(imputation_RITS): + def __init__( + self, + n_steps: int, + n_features: int, + rnn_hidden_size: int, + n_classes: int, + device: Union[str, torch.device], + ): + super().__init__(n_steps, n_features, rnn_hidden_size, device) + self.dropout = nn.Dropout(p=0.25) + self.classifier = nn.Linear(self.rnn_hidden_size, n_classes) + + def forward(self, inputs: dict, direction: str = "forward") -> dict: + ret_dict = super().forward(inputs, direction) + logits = self.classifier(ret_dict["final_hidden_state"]) + ret_dict["prediction"] = torch.softmax(logits, dim=1) + return ret_dict + + +class _BRITS(imputation_BRITS, nn.Module): + def __init__( + self, + n_steps: int, + n_features: int, + rnn_hidden_size: int, + n_classes: int, + classification_weight: float, + reconstruction_weight: float, + device: Union[str, torch.device], + ): + super().__init__(n_steps, n_features, rnn_hidden_size, device) + self.n_steps = n_steps + self.n_features = n_features + self.rnn_hidden_size = rnn_hidden_size + self.n_classes = n_classes + + # create models + self.rits_f = RITS(n_steps, n_features, rnn_hidden_size, n_classes, device) + self.rits_b = RITS(n_steps, n_features, rnn_hidden_size, n_classes, device) + self.classification_weight = classification_weight + self.reconstruction_weight = reconstruction_weight + + def impute(self, inputs: dict) -> torch.Tensor: + return super().impute(inputs) + + def forward(self, inputs: dict, training: bool = True) -> dict: + """Forward processing of BRITS. + + Parameters + ---------- + inputs : + The input data. + + training : + Whether in training mode. + + Returns + ------- + dict, A dictionary includes all results. 
+ """ + ret_f = self.rits_f(inputs, "forward") + ret_b = self._reverse(self.rits_b(inputs, "backward")) + + classification_pred = (ret_f["prediction"] + ret_b["prediction"]) / 2 + if not training: + # if not in training mode, return the classification result only + return {"classification_pred": classification_pred} + + ret_f["classification_loss"] = F.nll_loss( + torch.log(ret_f["prediction"]), inputs["label"] + ) + ret_b["classification_loss"] = F.nll_loss( + torch.log(ret_b["prediction"]), inputs["label"] + ) + consistency_loss = self._get_consistency_loss( + ret_f["imputed_data"], ret_b["imputed_data"] + ) + classification_loss = ( + ret_f["classification_loss"] + ret_b["classification_loss"] + ) / 2 + reconstruction_loss = ( + ret_f["reconstruction_loss"] + ret_b["reconstruction_loss"] + ) / 2 + + loss = ( + consistency_loss + + reconstruction_loss * self.reconstruction_weight + + classification_loss * self.classification_weight + ) + + results = { + "classification_pred": classification_pred, + "consistency_loss": consistency_loss, + "classification_loss": classification_loss, + "reconstruction_loss": reconstruction_loss, + "loss": loss, + } + return results diff --git a/pypots/classification/grud/model.py b/pypots/classification/grud/model.py index 60fd6482..4deb80c4 100644 --- a/pypots/classification/grud/model.py +++ b/pypots/classification/grud/model.py @@ -14,107 +14,16 @@ import numpy as np import torch -import torch.nn as nn -import torch.nn.functional as F from torch.utils.data import DataLoader from .data import DatasetForGRUD +from .modules import _GRUD from ..base import BaseNNClassifier -from ...imputation.brits.modules import TemporalDecay from ...optim.adam import Adam from ...optim.base import Optimizer from ...utils.logging import logger -class _GRUD(nn.Module): - def __init__( - self, - n_steps: int, - n_features: int, - rnn_hidden_size: int, - n_classes: int, - device: Union[str, torch.device], - ): - super().__init__() - self.n_steps = 
n_steps - self.n_features = n_features - self.rnn_hidden_size = rnn_hidden_size - self.n_classes = n_classes - self.device = device - - # create models - self.rnn_cell = nn.GRUCell( - self.n_features * 2 + self.rnn_hidden_size, self.rnn_hidden_size - ) - self.temp_decay_h = TemporalDecay( - input_size=self.n_features, output_size=self.rnn_hidden_size, diag=False - ) - self.temp_decay_x = TemporalDecay( - input_size=self.n_features, output_size=self.n_features, diag=True - ) - self.classifier = nn.Linear(self.rnn_hidden_size, self.n_classes) - - def forward(self, inputs: dict, training: bool = True) -> dict: - """Forward processing of GRU-D. - - Parameters - ---------- - inputs : - The input data. - - training : - Whether in training mode. - - Returns - ------- - dict, - A dictionary includes all results. - """ - values = inputs["X"] - masks = inputs["missing_mask"] - deltas = inputs["deltas"] - empirical_mean = inputs["empirical_mean"] - X_filledLOCF = inputs["X_filledLOCF"] - - hidden_state = torch.zeros( - (values.size()[0], self.rnn_hidden_size), device=values.device - ) - - for t in range(self.n_steps): - # for data, [batch, time, features] - x = values[:, t, :] # values - m = masks[:, t, :] # mask - d = deltas[:, t, :] # delta, time gap - x_filledLOCF = X_filledLOCF[:, t, :] - - gamma_h = self.temp_decay_h(d) - gamma_x = self.temp_decay_x(d) - hidden_state = hidden_state * gamma_h - - x_h = gamma_x * x_filledLOCF + (1 - gamma_x) * empirical_mean - x_replaced = m * x + (1 - m) * x_h - data_input = torch.cat([x_replaced, hidden_state, m], dim=1) - hidden_state = self.rnn_cell(data_input, hidden_state) - - logits = self.classifier(hidden_state) - classification_pred = torch.softmax(logits, dim=1) - - if not training: - # if not in training mode, return the classification result only - return {"classification_pred": classification_pred} - - torch.log(classification_pred) - classification_loss = F.nll_loss( - torch.log(classification_pred), inputs["label"] - ) - - 
results = { - "classification_pred": classification_pred, - "loss": classification_loss, - } - return results - - class GRUD(BaseNNClassifier): """The PyTorch implementation of the GRU-D model :cite:`che2018GRUD`. @@ -338,7 +247,7 @@ def classify( file_type: str = "h5py", ) -> np.ndarray: logger.warning( - "🚨DeprecationWarning: The method impute is deprecated. Please use `predict` instead." + "🚨DeprecationWarning: The method classify is deprecated. Please use `predict` instead." ) result_dict = self.predict(X, file_type=file_type) return result_dict["classification"] diff --git a/pypots/classification/grud/modules/__init__.py b/pypots/classification/grud/modules/__init__.py new file mode 100644 index 00000000..22cb7b77 --- /dev/null +++ b/pypots/classification/grud/modules/__init__.py @@ -0,0 +1,14 @@ +""" + +""" + +# Created by Wenjie Du +# License: GPL-v3 + +from .core import _GRUD +from .submodules import TemporalDecay + +__all__ = [ + "_GRUD", + "TemporalDecay", +] diff --git a/pypots/classification/grud/modules/core.py b/pypots/classification/grud/modules/core.py new file mode 100644 index 00000000..92326719 --- /dev/null +++ b/pypots/classification/grud/modules/core.py @@ -0,0 +1,108 @@ +""" +The implementation of GRU-D for the partially-observed time-series classification task. + +Refer to the paper "Che, Z., Purushotham, S., Cho, K., Sontag, D.A., & Liu, Y. (2018). +Recurrent Neural Networks for Multivariate Time Series with Missing Values. Scientific Reports." 
+ +""" + +# Created by Wenjie Du +# License: GPL-v3 + + +from typing import Union + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from .submodules import TemporalDecay + + +class _GRUD(nn.Module): + def __init__( + self, + n_steps: int, + n_features: int, + rnn_hidden_size: int, + n_classes: int, + device: Union[str, torch.device], + ): + super().__init__() + self.n_steps = n_steps + self.n_features = n_features + self.rnn_hidden_size = rnn_hidden_size + self.n_classes = n_classes + self.device = device + + # create models + self.rnn_cell = nn.GRUCell( + self.n_features * 2 + self.rnn_hidden_size, self.rnn_hidden_size + ) + self.temp_decay_h = TemporalDecay( + input_size=self.n_features, output_size=self.rnn_hidden_size, diag=False + ) + self.temp_decay_x = TemporalDecay( + input_size=self.n_features, output_size=self.n_features, diag=True + ) + self.classifier = nn.Linear(self.rnn_hidden_size, self.n_classes) + + def forward(self, inputs: dict, training: bool = True) -> dict: + """Forward processing of GRU-D. + + Parameters + ---------- + inputs : + The input data. + + training : + Whether in training mode. + + Returns + ------- + dict, + A dictionary includes all results. 
+ """ + values = inputs["X"] + masks = inputs["missing_mask"] + deltas = inputs["deltas"] + empirical_mean = inputs["empirical_mean"] + X_filledLOCF = inputs["X_filledLOCF"] + + hidden_state = torch.zeros( + (values.size()[0], self.rnn_hidden_size), device=values.device + ) + + for t in range(self.n_steps): + # for data, [batch, time, features] + x = values[:, t, :] # values + m = masks[:, t, :] # mask + d = deltas[:, t, :] # delta, time gap + x_filledLOCF = X_filledLOCF[:, t, :] + + gamma_h = self.temp_decay_h(d) + gamma_x = self.temp_decay_x(d) + hidden_state = hidden_state * gamma_h + + x_h = gamma_x * x_filledLOCF + (1 - gamma_x) * empirical_mean + x_replaced = m * x + (1 - m) * x_h + data_input = torch.cat([x_replaced, hidden_state, m], dim=1) + hidden_state = self.rnn_cell(data_input, hidden_state) + + logits = self.classifier(hidden_state) + classification_pred = torch.softmax(logits, dim=1) + + if not training: + # if not in training mode, return the classification result only + return {"classification_pred": classification_pred} + + torch.log(classification_pred) + classification_loss = F.nll_loss( + torch.log(classification_pred), inputs["label"] + ) + + results = { + "classification_pred": classification_pred, + "loss": classification_loss, + } + return results diff --git a/pypots/classification/grud/modules/submodules.py b/pypots/classification/grud/modules/submodules.py new file mode 100644 index 00000000..1c909e28 --- /dev/null +++ b/pypots/classification/grud/modules/submodules.py @@ -0,0 +1,76 @@ +""" + +""" + +# Created by Wenjie Du +# License: GPL-v3 + +import math + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.autograd import Variable +from torch.nn.parameter import Parameter + + +class TemporalDecay(nn.Module): + """The module used to generate the temporal decay factor gamma in the original paper. + + Attributes + ---------- + W: tensor, + The weights (parameters) of the module. 
+ b: tensor, + The bias of the module. + + Parameters + ---------- + input_size : int, + the feature dimension of the input + + output_size : int, + the feature dimension of the output + + diag : bool, + whether to product the weight with an identity matrix before forward processing + """ + + def __init__(self, input_size: int, output_size: int, diag: bool = False): + super().__init__() + self.diag = diag + self.W = Parameter(torch.Tensor(output_size, input_size)) + self.b = Parameter(torch.Tensor(output_size)) + + if self.diag: + assert input_size == output_size + m = torch.eye(input_size, input_size) + self.register_buffer("m", m) + + self._reset_parameters() + + def _reset_parameters(self) -> None: + std_dev = 1.0 / math.sqrt(self.W.size(0)) + self.W.data.uniform_(-std_dev, std_dev) + if self.b is not None: + self.b.data.uniform_(-std_dev, std_dev) + + def forward(self, delta: torch.Tensor) -> torch.Tensor: + """Forward processing of the NN module. + + Parameters + ---------- + delta : tensor, shape [batch size, sequence length, feature number] + The time gaps. + + Returns + ------- + gamma : array-like, same shape with parameter `delta`, values in (0,1] + The temporal decay factor. 
+ """ + if self.diag: + gamma = F.relu(F.linear(delta, self.W * Variable(self.m), self.b)) + else: + gamma = F.relu(F.linear(delta, self.W, self.b)) + gamma = torch.exp(-gamma) + return gamma diff --git a/pypots/classification/raindrop/model.py b/pypots/classification/raindrop/model.py index 8b9d1670..393efb64 100644 --- a/pypots/classification/raindrop/model.py +++ b/pypots/classification/raindrop/model.py @@ -15,277 +15,15 @@ import numpy as np import torch -import torch.nn as nn -import torch.nn.functional as F -from torch.nn import TransformerEncoderLayer, TransformerEncoder -from torch.nn.parameter import Parameter from torch.utils.data import DataLoader +from .modules import _Raindrop from ...classification.base import BaseNNClassifier from ...classification.grud.data import DatasetForGRUD from ...optim.adam import Adam from ...optim.base import Optimizer from ...utils.logging import logger -try: - from .modules import PositionalEncoding, ObservationPropagation - from torch_geometric.nn.inits import glorot -except ImportError as e: - logger.error( - f"{e}\n" - "Note torch_geometric is missing, " - "please install it with 'pip install torch_geometric' or 'conda install -c pyg pyg'" - ) -except NameError as e: - logger.error( - f"{e}\n" - "Note torch_geometric is missing, " - "please install it with 'pip install torch_geometric' or 'conda install -c pyg pyg'" - ) - - -class _Raindrop(nn.Module): - def __init__( - self, - n_features, - n_layers, - d_model, - d_inner, - n_heads, - n_classes, - dropout=0.3, - max_len=215, - d_static=9, - aggregation="mean", - sensor_wise_mask=False, - static=False, - device=None, - ): - super().__init__() - self.n_layers = n_layers - self.n_features = n_features - self.d_model = d_model - self.d_inner = d_inner - self.n_heads = n_heads - self.n_classes = n_classes - self.dropout = dropout - self.max_len = max_len - self.d_static = d_static - self.aggregation = aggregation - self.sensor_wise_mask = sensor_wise_mask - self.static = 
static - self.device = device - - # create modules - if self.static: - self.static_emb = nn.Linear(d_static, n_features) - else: - self.static_emb = None - assert d_model % n_features == 0, "d_model must be divisible by n_features" - self.d_ob = int(d_model / n_features) - self.encoder = nn.Linear(n_features * self.d_ob, n_features * self.d_ob) - d_pe = 16 - self.pos_encoder = PositionalEncoding(d_pe, max_len) - if self.sensor_wise_mask: - dim_check = n_features * (self.d_ob + d_pe) - assert dim_check % n_heads == 0, "dim_check must be divisible by n_heads" - encoder_layers = TransformerEncoderLayer( - n_features * (self.d_ob + d_pe), n_heads, d_inner, dropout - ) - else: - dim_check = d_model + d_pe - assert dim_check % n_heads == 0, "dim_check must be divisible by n_heads" - encoder_layers = TransformerEncoderLayer( - d_model + d_pe, n_heads, d_inner, dropout - ) - self.transformer_encoder = TransformerEncoder(encoder_layers, n_layers) - - self.R_u = Parameter(torch.Tensor(1, self.n_features * self.d_ob)) - - self.ob_propagation = ObservationPropagation( - in_channels=max_len * self.d_ob, - out_channels=max_len * self.d_ob, - heads=1, - n_nodes=n_features, - ob_dim=self.d_ob, - ) - self.ob_propagation_layer2 = ObservationPropagation( - in_channels=max_len * self.d_ob, - out_channels=max_len * self.d_ob, - heads=1, - n_nodes=n_features, - ob_dim=self.d_ob, - ) - if static: - d_final = d_model + d_pe + n_features - else: - d_final = d_model + d_pe - - self.mlp_static = nn.Sequential( - nn.Linear(d_final, d_final), - nn.ReLU(), - nn.Linear(d_final, n_classes), - ) - - self.dropout = nn.Dropout(dropout) - self.init_weights() - - def init_weights(self): - init_range = 1e-10 - self.encoder.weight.data.uniform_(-init_range, init_range) - if self.static: - self.static_emb.weight.data.uniform_(-init_range, init_range) - glorot(self.R_u) - - def classify(self, inputs: dict) -> torch.Tensor: - """Forward processing of BRITS. 
- - Parameters - ---------- - inputs : dict, - The input data. - - Returns - ------- - prediction : torch.Tensor - """ - src = inputs["X"].permute(1, 0, 2) - static = inputs["static"] - times = inputs["timestamps"].permute(1, 0) - lengths = inputs["lengths"] - missing_mask = inputs["missing_mask"].permute(1, 0, 2) - - max_len, batch_size = src.shape[0], src.shape[1] - - src = torch.repeat_interleave(src, self.d_ob, dim=-1) - h = F.relu(src * self.R_u) - pe = self.pos_encoder(times).to(src.device) - if static is not None: - emb = self.static_emb(static) - - h = self.dropout(h) - - mask = torch.arange(max_len)[None, :] >= (lengths.cpu()[:, None]) - mask = mask.squeeze(1).to(src.device) - - x = h - - adj = torch.ones(self.n_features, self.n_features, device=src.device) - adj[torch.eye(self.n_features, dtype=torch.bool)] = 1 - - edge_index = torch.nonzero(adj).T - edge_weights = adj[edge_index[0], edge_index[1]] - - batch_size = src.shape[1] - n_step = src.shape[0] - output = torch.zeros( - [n_step, batch_size, self.n_features * self.d_ob], device=src.device - ) - - alpha_all = torch.zeros([edge_index.shape[1], batch_size], device=src.device) - - # iterate on each sample - for unit in range(0, batch_size): - step_data = x[:, unit, :] - p_t = pe[:, unit, :] - - step_data = step_data.reshape([n_step, self.n_features, self.d_ob]).permute( - 1, 0, 2 - ) - step_data = step_data.reshape(self.n_features, n_step * self.d_ob) - - step_data, attention_weights = self.ob_propagation( - step_data, - p_t=p_t, - edge_index=edge_index, - edge_weights=edge_weights, - use_beta=False, - edge_attr=None, - return_attention_weights=True, - ) - - edge_index_layer2 = attention_weights[0] - edge_weights_layer2 = attention_weights[1].squeeze(-1) - - step_data, attention_weights = self.ob_propagation_layer2( - step_data, - p_t=p_t, - edge_index=edge_index_layer2, - edge_weights=edge_weights_layer2, - use_beta=False, - edge_attr=None, - return_attention_weights=True, - ) - - step_data = 
step_data.view([self.n_features, n_step, self.d_ob]) - step_data = step_data.permute([1, 0, 2]) # [n_step, n_features, d_ob] - step_data = step_data.reshape([-1, self.n_features * self.d_ob]) - - output[:, unit, :] = step_data - alpha_all[:, unit] = attention_weights[1].squeeze(-1) - - # distance = torch.cdist(alpha_all.T, alpha_all.T, p=2) - # distance = torch.mean(distance) - - if self.sensor_wise_mask: - extend_output = output.view(-1, batch_size, self.n_features, self.d_ob) - extended_pe = pe.unsqueeze(2).repeat([1, 1, self.n_features, 1]) - output = torch.cat([extend_output, extended_pe], dim=-1) - output = output.view(-1, batch_size, self.n_features * (self.d_ob + 16)) - else: - output = torch.cat([output, pe], dim=2) - - r_out = self.transformer_encoder(output, src_key_padding_mask=mask) - - lengths2 = lengths.unsqueeze(1).to(src.device) - mask2 = mask.permute(1, 0).unsqueeze(2).long() - if self.sensor_wise_mask: - output = torch.zeros( - [batch_size, self.n_features, self.d_ob + 16], device=src.device - ) - extended_missing_mask = missing_mask.view(-1, batch_size, self.n_features) - for se in range(self.n_features): - r_out = r_out.view(-1, batch_size, self.n_features, (self.d_ob + 16)) - out = r_out[:, :, se, :] - l_ = torch.sum(extended_missing_mask[:, :, se], dim=0).unsqueeze( - 1 - ) # length - out_sensor = torch.sum( - out * (1 - extended_missing_mask[:, :, se].unsqueeze(-1)), dim=0 - ) / (l_ + 1) - output[:, se, :] = out_sensor - output = output.view([-1, self.n_features * (self.d_ob + 16)]) - elif self.aggregation == "mean": - output = torch.sum(r_out * (1 - mask2), dim=0) / (lengths2 + 1) - else: - raise RuntimeError - - if static is not None: - output = torch.cat([output, emb], dim=1) - - logits = self.mlp_static(output) - prediction = torch.softmax(logits, dim=1) - - return prediction - - def forward(self, inputs, training=True): - classification_pred = self.classify(inputs) - if not training: - # if not in training mode, return the classification 
result only - return {"classification_pred": classification_pred} - - classification_loss = F.nll_loss( - torch.log(classification_pred), inputs["label"] - ) - - results = { - "prediction": classification_pred, - "loss": classification_loss - # 'distance': distance, - } - - return results - class Raindrop(BaseNNClassifier): """The PyTorch implementation of the Raindrop model :cite:`zhang2022Raindrop`. @@ -554,7 +292,7 @@ def classify( file_type: str = "h5py", ) -> np.ndarray: logger.warning( - "🚨DeprecationWarning: The method impute is deprecated. Please use `predict` instead." + "🚨DeprecationWarning: The method classify is deprecated. Please use `predict` instead." ) result_dict = self.predict(X, file_type=file_type) return result_dict["classification"] diff --git a/pypots/classification/raindrop/modules/__init__.py b/pypots/classification/raindrop/modules/__init__.py new file mode 100644 index 00000000..db470980 --- /dev/null +++ b/pypots/classification/raindrop/modules/__init__.py @@ -0,0 +1,12 @@ +""" + +""" + +# Created by Wenjie Du +# License: GPL-v3 + +from .core import _Raindrop + +__all__ = [ + "_Raindrop", +] diff --git a/pypots/classification/raindrop/modules/core.py b/pypots/classification/raindrop/modules/core.py new file mode 100644 index 00000000..1091dd27 --- /dev/null +++ b/pypots/classification/raindrop/modules/core.py @@ -0,0 +1,279 @@ +""" +The implementation of Raindrop for the partially-observed time-series classification task. + +Refer to the paper "Zhang, X., Zeman, M., Tsiligkaridis, T., & Zitnik, M. (2022). +Graph-Guided Network for Irregularly Sampled Multivariate Time Series. ICLR 2022." 
+ +""" + + +# Created by Wenjie Du +# License: GPL-v3 + + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.nn import TransformerEncoderLayer, TransformerEncoder +from torch.nn.parameter import Parameter + +from ....utils.logging import logger + +try: + from .submodules import PositionalEncoding, ObservationPropagation + from torch_geometric.nn.inits import glorot +except ImportError as e: + logger.error( + f"{e}\n" + "Note torch_geometric is missing, " + "please install it with 'pip install torch_geometric' or 'conda install -c pyg pyg'" + ) +except NameError as e: + logger.error( + f"{e}\n" + "Note torch_geometric is missing, " + "please install it with 'pip install torch_geometric' or 'conda install -c pyg pyg'" + ) + + +class _Raindrop(nn.Module): + def __init__( + self, + n_features, + n_layers, + d_model, + d_inner, + n_heads, + n_classes, + dropout=0.3, + max_len=215, + d_static=9, + aggregation="mean", + sensor_wise_mask=False, + static=False, + device=None, + ): + super().__init__() + self.n_layers = n_layers + self.n_features = n_features + self.d_model = d_model + self.d_inner = d_inner + self.n_heads = n_heads + self.n_classes = n_classes + self.dropout = dropout + self.max_len = max_len + self.d_static = d_static + self.aggregation = aggregation + self.sensor_wise_mask = sensor_wise_mask + self.static = static + self.device = device + + # create modules + if self.static: + self.static_emb = nn.Linear(d_static, n_features) + else: + self.static_emb = None + assert d_model % n_features == 0, "d_model must be divisible by n_features" + self.d_ob = int(d_model / n_features) + self.encoder = nn.Linear(n_features * self.d_ob, n_features * self.d_ob) + d_pe = 16 + self.pos_encoder = PositionalEncoding(d_pe, max_len) + if self.sensor_wise_mask: + dim_check = n_features * (self.d_ob + d_pe) + assert dim_check % n_heads == 0, "dim_check must be divisible by n_heads" + encoder_layers = TransformerEncoderLayer( + n_features * 
(self.d_ob + d_pe), n_heads, d_inner, dropout + ) + else: + dim_check = d_model + d_pe + assert dim_check % n_heads == 0, "dim_check must be divisible by n_heads" + encoder_layers = TransformerEncoderLayer( + d_model + d_pe, n_heads, d_inner, dropout + ) + self.transformer_encoder = TransformerEncoder(encoder_layers, n_layers) + + self.R_u = Parameter(torch.Tensor(1, self.n_features * self.d_ob)) + + self.ob_propagation = ObservationPropagation( + in_channels=max_len * self.d_ob, + out_channels=max_len * self.d_ob, + heads=1, + n_nodes=n_features, + ob_dim=self.d_ob, + ) + self.ob_propagation_layer2 = ObservationPropagation( + in_channels=max_len * self.d_ob, + out_channels=max_len * self.d_ob, + heads=1, + n_nodes=n_features, + ob_dim=self.d_ob, + ) + if static: + d_final = d_model + d_pe + n_features + else: + d_final = d_model + d_pe + + self.mlp_static = nn.Sequential( + nn.Linear(d_final, d_final), + nn.ReLU(), + nn.Linear(d_final, n_classes), + ) + + self.dropout = nn.Dropout(dropout) + self.init_weights() + + def init_weights(self): + init_range = 1e-10 + self.encoder.weight.data.uniform_(-init_range, init_range) + if self.static: + self.static_emb.weight.data.uniform_(-init_range, init_range) + glorot(self.R_u) + + def classify(self, inputs: dict) -> torch.Tensor: + """Forward processing of Raindrop. + + Parameters + ---------- + inputs : dict, + The input data. 
+ + Returns + ------- + prediction : torch.Tensor + """ + src = inputs["X"].permute(1, 0, 2) + static = inputs["static"] + times = inputs["timestamps"].permute(1, 0) + lengths = inputs["lengths"] + missing_mask = inputs["missing_mask"].permute(1, 0, 2) + + max_len, batch_size = src.shape[0], src.shape[1] + + src = torch.repeat_interleave(src, self.d_ob, dim=-1) + h = F.relu(src * self.R_u) + pe = self.pos_encoder(times).to(src.device) + if static is not None: + emb = self.static_emb(static) + + h = self.dropout(h) + + mask = torch.arange(max_len)[None, :] >= (lengths.cpu()[:, None]) + mask = mask.squeeze(1).to(src.device) + + x = h + + adj = torch.ones(self.n_features, self.n_features, device=src.device) + adj[torch.eye(self.n_features, dtype=torch.bool)] = 1 + + edge_index = torch.nonzero(adj).T + edge_weights = adj[edge_index[0], edge_index[1]] + + batch_size = src.shape[1] + n_step = src.shape[0] + output = torch.zeros( + [n_step, batch_size, self.n_features * self.d_ob], device=src.device + ) + + alpha_all = torch.zeros([edge_index.shape[1], batch_size], device=src.device) + + # iterate on each sample + for unit in range(0, batch_size): + step_data = x[:, unit, :] + p_t = pe[:, unit, :] + + step_data = step_data.reshape([n_step, self.n_features, self.d_ob]).permute( + 1, 0, 2 + ) + step_data = step_data.reshape(self.n_features, n_step * self.d_ob) + + step_data, attention_weights = self.ob_propagation( + step_data, + p_t=p_t, + edge_index=edge_index, + edge_weights=edge_weights, + use_beta=False, + edge_attr=None, + return_attention_weights=True, + ) + + edge_index_layer2 = attention_weights[0] + edge_weights_layer2 = attention_weights[1].squeeze(-1) + + step_data, attention_weights = self.ob_propagation_layer2( + step_data, + p_t=p_t, + edge_index=edge_index_layer2, + edge_weights=edge_weights_layer2, + use_beta=False, + edge_attr=None, + return_attention_weights=True, + ) + + step_data = step_data.view([self.n_features, n_step, self.d_ob]) + step_data = 
step_data.permute([1, 0, 2]) # [n_step, n_features, d_ob] + step_data = step_data.reshape([-1, self.n_features * self.d_ob]) + + output[:, unit, :] = step_data + alpha_all[:, unit] = attention_weights[1].squeeze(-1) + + # distance = torch.cdist(alpha_all.T, alpha_all.T, p=2) + # distance = torch.mean(distance) + + if self.sensor_wise_mask: + extend_output = output.view(-1, batch_size, self.n_features, self.d_ob) + extended_pe = pe.unsqueeze(2).repeat([1, 1, self.n_features, 1]) + output = torch.cat([extend_output, extended_pe], dim=-1) + output = output.view(-1, batch_size, self.n_features * (self.d_ob + 16)) + else: + output = torch.cat([output, pe], dim=2) + + r_out = self.transformer_encoder(output, src_key_padding_mask=mask) + + lengths2 = lengths.unsqueeze(1).to(src.device) + mask2 = mask.permute(1, 0).unsqueeze(2).long() + if self.sensor_wise_mask: + output = torch.zeros( + [batch_size, self.n_features, self.d_ob + 16], device=src.device + ) + extended_missing_mask = missing_mask.view(-1, batch_size, self.n_features) + for se in range(self.n_features): + r_out = r_out.view(-1, batch_size, self.n_features, (self.d_ob + 16)) + out = r_out[:, :, se, :] + l_ = torch.sum(extended_missing_mask[:, :, se], dim=0).unsqueeze( + 1 + ) # length + out_sensor = torch.sum( + out * (1 - extended_missing_mask[:, :, se].unsqueeze(-1)), dim=0 + ) / (l_ + 1) + output[:, se, :] = out_sensor + output = output.view([-1, self.n_features * (self.d_ob + 16)]) + elif self.aggregation == "mean": + output = torch.sum(r_out * (1 - mask2), dim=0) / (lengths2 + 1) + else: + raise RuntimeError + + if static is not None: + output = torch.cat([output, emb], dim=1) + + logits = self.mlp_static(output) + prediction = torch.softmax(logits, dim=1) + + return prediction + + def forward(self, inputs, training=True): + classification_pred = self.classify(inputs) + if not training: + # if not in training mode, return the classification result only + return {"classification_pred": classification_pred} 
+ + classification_loss = F.nll_loss( + torch.log(classification_pred), inputs["label"] + ) + + results = { + "prediction": classification_pred, + "loss": classification_loss + # 'distance': distance, + } + + return results diff --git a/pypots/classification/raindrop/modules.py b/pypots/classification/raindrop/modules/submodules.py similarity index 99% rename from pypots/classification/raindrop/modules.py rename to pypots/classification/raindrop/modules/submodules.py index 191ff9c7..7bd31d32 100644 --- a/pypots/classification/raindrop/modules.py +++ b/pypots/classification/raindrop/modules/submodules.py @@ -22,7 +22,7 @@ from torch.nn import init from torch.nn.parameter import Parameter -from ...utils.logging import logger +from ....utils.logging import logger try: from torch_geometric.nn.conv import MessagePassing diff --git a/pypots/clustering/base.py b/pypots/clustering/base.py index 43d2e82e..72ec90bd 100644 --- a/pypots/clustering/base.py +++ b/pypots/clustering/base.py @@ -383,6 +383,14 @@ def fit( """ raise NotImplementedError + @abstractmethod + def predict( + self, + test_set: Union[dict, str], + file_type: str = "h5py", + ) -> dict: + raise NotImplementedError + @abstractmethod def cluster( self, diff --git a/pypots/clustering/crli/model.py b/pypots/clustering/crli/model.py index 46805c38..ab13ce7a 100644 --- a/pypots/clustering/crli/model.py +++ b/pypots/clustering/crli/model.py @@ -14,108 +14,14 @@ import numpy as np import torch -import torch.nn as nn -import torch.nn.functional as F -from sklearn.cluster import KMeans from torch.utils.data import DataLoader from .data import DatasetForCRLI -from .modules import Generator, Decoder, Discriminator +from .modules import _CRLI from ..base import BaseNNClusterer from ...optim.adam import Adam from ...optim.base import Optimizer from ...utils.logging import logger -from ...utils.metrics import cal_mse - - -class _CRLI(nn.Module): - def __init__( - self, - n_steps: int, - n_features: int, - n_clusters: int, - 
n_generator_layers: int, - rnn_hidden_size: int, - decoder_fcn_output_dims: Optional[list], - lambda_kmeans: float, - rnn_cell_type: str = "GRU", - device: Union[str, torch.device] = "cpu", - ): - super().__init__() - self.generator = Generator( - n_generator_layers, n_features, rnn_hidden_size, rnn_cell_type, device - ) - self.discriminator = Discriminator(rnn_cell_type, n_features, device) - self.decoder = Decoder( - n_steps, rnn_hidden_size * 2, n_features, decoder_fcn_output_dims, device - ) # fully connected network is included in Decoder - self.kmeans = KMeans( - n_clusters=n_clusters, - n_init=10, # FutureWarning: The default value of `n_init` will change from 10 to 'auto' in 1.4. Set the - # value of `n_init` explicitly to suppress the warning. - ) # TODO: implement KMean with torch for gpu acceleration - - self.n_clusters = n_clusters - self.lambda_kmeans = lambda_kmeans - self.device = device - - def cluster(self, inputs: dict, training_object: str = "generator") -> dict: - # concat final states from generator and input it as the initial state of decoder - imputation, imputed_X, generator_fb_hidden_states = self.generator(inputs) - inputs["imputation"] = imputation - inputs["imputed_X"] = imputed_X - inputs["generator_fb_hidden_states"] = generator_fb_hidden_states - if training_object == "discriminator": - discrimination = self.discriminator(inputs) - inputs["discrimination"] = discrimination - return inputs # if only train discriminator, then no need to run decoder - - reconstruction, fcn_latent = self.decoder(inputs) - inputs["reconstruction"] = reconstruction - inputs["fcn_latent"] = fcn_latent - return inputs - - def forward( - self, - inputs: dict, - training_object: str = "generator", - training: bool = True, - ) -> dict: - assert training_object in [ - "generator", - "discriminator", - ], 'training_object should be "generator" or "discriminator"' - - X = inputs["X"] - missing_mask = inputs["missing_mask"] - batch_size, n_steps, n_features = 
X.shape - losses = {} - inputs = self.cluster(inputs, training_object) - if not training: - # if only run clustering, then no need to calculate loss - return inputs - - if training_object == "discriminator": - l_D = F.binary_cross_entropy_with_logits( - inputs["discrimination"], missing_mask - ) - losses["discrimination_loss"] = l_D - else: - inputs["discrimination"] = inputs["discrimination"].detach() - l_G = F.binary_cross_entropy_with_logits( - inputs["discrimination"], 1 - missing_mask, weight=1 - missing_mask - ) - l_pre = cal_mse(inputs["imputation"], X, missing_mask) - l_rec = cal_mse(inputs["reconstruction"], X, missing_mask) - HTH = torch.matmul(inputs["fcn_latent"], inputs["fcn_latent"].permute(1, 0)) - term_F = torch.nn.init.orthogonal_( - torch.randn(batch_size, self.n_clusters, device=self.device), gain=1 - ) - FTHTHF = torch.matmul(torch.matmul(term_F.permute(1, 0), HTH), term_F) - l_kmeans = torch.trace(HTH) - torch.trace(FTHTHF) # k-means loss - loss_gene = l_G + l_pre + l_rec + l_kmeans * self.lambda_kmeans - losses["generation_loss"] = loss_gene - return losses class CRLI(BaseNNClusterer): @@ -394,6 +300,7 @@ def _train_model( def fit( self, train_set: Union[dict, str], + val_set: Optional[Union[dict, str]] = None, file_type: str = "h5py", ) -> None: # Step 1: wrap the input data with classes Dataset and DataLoader @@ -465,7 +372,7 @@ def cluster( return_latent: bool = False, ) -> Union[np.ndarray, Tuple[np.ndarray, dict]]: logger.warning( - "🚨DeprecationWarning: The method impute is deprecated. Please use `predict` instead." + "🚨DeprecationWarning: The method cluster is deprecated. Please use `predict` instead." 
) result_dict = self.predict(X, file_type, return_latent) diff --git a/pypots/clustering/crli/modules/__init__.py b/pypots/clustering/crli/modules/__init__.py new file mode 100644 index 00000000..857206ef --- /dev/null +++ b/pypots/clustering/crli/modules/__init__.py @@ -0,0 +1,12 @@ +""" + +""" + +# Created by Wenjie Du +# License: GLP-v3 + +from .core import _CRLI + +__all__ = [ + "_CRLI", +] diff --git a/pypots/clustering/crli/modules/core.py b/pypots/clustering/crli/modules/core.py new file mode 100644 index 00000000..da653cde --- /dev/null +++ b/pypots/clustering/crli/modules/core.py @@ -0,0 +1,111 @@ +""" +The implementation of CRLI (Clustering Representation Learning on Incomplete time-series data) for +the partially-observed time-series clustering task. + +Refer to the paper "Ma, Q., Chen, C., Li, S., & Cottrell, G. W. (2021). +Learning Representations for Incomplete Time Series Clustering. AAAI 2021." + +""" + +# Created by Wenjie Du +# License: GLP-v3 + +from typing import Union, Optional + +import torch +import torch.nn as nn +import torch.nn.functional as F +from sklearn.cluster import KMeans + +from .submodules import Generator, Decoder, Discriminator +from ....utils.metrics import cal_mse + + +class _CRLI(nn.Module): + def __init__( + self, + n_steps: int, + n_features: int, + n_clusters: int, + n_generator_layers: int, + rnn_hidden_size: int, + decoder_fcn_output_dims: Optional[list], + lambda_kmeans: float, + rnn_cell_type: str = "GRU", + device: Union[str, torch.device] = "cpu", + ): + super().__init__() + self.generator = Generator( + n_generator_layers, n_features, rnn_hidden_size, rnn_cell_type, device + ) + self.discriminator = Discriminator(rnn_cell_type, n_features, device) + self.decoder = Decoder( + n_steps, rnn_hidden_size * 2, n_features, decoder_fcn_output_dims, device + ) # fully connected network is included in Decoder + self.kmeans = KMeans( + n_clusters=n_clusters, + n_init=10, # FutureWarning: The default value of `n_init` will 
change from 10 to 'auto' in 1.4. Set the + # value of `n_init` explicitly to suppress the warning. + ) # TODO: implement KMean with torch for gpu acceleration + + self.n_clusters = n_clusters + self.lambda_kmeans = lambda_kmeans + self.device = device + + def cluster(self, inputs: dict, training_object: str = "generator") -> dict: + # concat final states from generator and input it as the initial state of decoder + imputation, imputed_X, generator_fb_hidden_states = self.generator(inputs) + inputs["imputation"] = imputation + inputs["imputed_X"] = imputed_X + inputs["generator_fb_hidden_states"] = generator_fb_hidden_states + if training_object == "discriminator": + discrimination = self.discriminator(inputs) + inputs["discrimination"] = discrimination + return inputs # if only train discriminator, then no need to run decoder + + reconstruction, fcn_latent = self.decoder(inputs) + inputs["reconstruction"] = reconstruction + inputs["fcn_latent"] = fcn_latent + return inputs + + def forward( + self, + inputs: dict, + training_object: str = "generator", + training: bool = True, + ) -> dict: + assert training_object in [ + "generator", + "discriminator", + ], 'training_object should be "generator" or "discriminator"' + + X = inputs["X"] + missing_mask = inputs["missing_mask"] + batch_size, n_steps, n_features = X.shape + losses = {} + inputs = self.cluster(inputs, training_object) + if not training: + # if only run clustering, then no need to calculate loss + return inputs + + if training_object == "discriminator": + l_D = F.binary_cross_entropy_with_logits( + inputs["discrimination"], missing_mask + ) + losses["discrimination_loss"] = l_D + else: + inputs["discrimination"] = inputs["discrimination"].detach() + l_G = F.binary_cross_entropy_with_logits( + inputs["discrimination"], 1 - missing_mask, weight=1 - missing_mask + ) + l_pre = cal_mse(inputs["imputation"], X, missing_mask) + l_rec = cal_mse(inputs["reconstruction"], X, missing_mask) + HTH = 
torch.matmul(inputs["fcn_latent"], inputs["fcn_latent"].permute(1, 0)) + term_F = torch.nn.init.orthogonal_( + torch.randn(batch_size, self.n_clusters, device=self.device), gain=1 + ) + FTHTHF = torch.matmul(torch.matmul(term_F.permute(1, 0), HTH), term_F) + l_kmeans = torch.trace(HTH) - torch.trace(FTHTHF) # k-means loss + loss_gene = l_G + l_pre + l_rec + l_kmeans * self.lambda_kmeans + losses["generation_loss"] = loss_gene + return losses diff --git a/pypots/clustering/crli/modules.py b/pypots/clustering/crli/modules/submodules.py similarity index 100% rename from pypots/clustering/crli/modules.py rename to pypots/clustering/crli/modules/submodules.py diff --git a/pypots/clustering/vader/model.py b/pypots/clustering/vader/model.py index d6d910c5..4e31b412 100644 --- a/pypots/clustering/vader/model.py +++ b/pypots/clustering/vader/model.py @@ -15,259 +15,16 @@ import numpy as np import torch -import torch.nn as nn from scipy.stats import multivariate_normal from sklearn.mixture import GaussianMixture from torch.utils.data import DataLoader from .data import DatasetForVaDER -from .modules import ( - inverse_softplus, - GMMLayer, - PeepholeLSTMCell, - ImplicitImputation, -) +from .modules import inverse_softplus, _VaDER from ..base import BaseNNClusterer from ...optim.adam import Adam from ...optim.base import Optimizer from ...utils.logging import logger -from ...utils.metrics import cal_mse - - -class _VaDER(nn.Module): - """ - - Parameters - ---------- - n_steps : - d_input : - n_clusters : - d_rnn_hidden : - d_mu_stddev : - eps : - alpha : - Weight of the latent loss. 
- The final loss = `alpha`*latent loss + reconstruction loss - - - Attributes - ---------- - - """ - - def __init__( - self, - n_steps: int, - d_input: int, - n_clusters: int, - d_rnn_hidden: int, - d_mu_stddev: int, - eps: float = 1e-9, - alpha: float = 1.0, - ): - super().__init__() - self.n_steps = n_steps - self.d_input = d_input - self.n_clusters = n_clusters - self.d_rnn_hidden = d_rnn_hidden - self.d_mu_stddev = d_mu_stddev - self.eps = eps - self.alpha = alpha - - # building model components - self.implicit_imputation_layer = ImplicitImputation(d_input) - self.encoder = PeepholeLSTMCell(d_input, d_rnn_hidden) - self.decoder = PeepholeLSTMCell(d_input, d_rnn_hidden) - self.ae_encode_layers = nn.Sequential( - nn.Linear(d_rnn_hidden, d_rnn_hidden), nn.Softplus() - ) - self.ae_decode_layers = nn.Sequential( - nn.Linear(d_mu_stddev, d_rnn_hidden), nn.Softplus() - ) - self.mu_layer = nn.Linear(d_rnn_hidden, d_mu_stddev) # layer for mean - self.stddev_layer = nn.Linear( - d_rnn_hidden, d_mu_stddev - ) # layer for standard variance - self.rnn_transform_layer = nn.Linear(d_rnn_hidden, d_input) - self.gmm_layer = GMMLayer(d_mu_stddev, n_clusters) - - @staticmethod - def z_sampling( - mu_tilde: torch.Tensor, - stddev_tilde: torch.Tensor, - ) -> torch.Tensor: - noise = mu_tilde.data.new(mu_tilde.size()).normal_() - z = torch.add(mu_tilde, torch.exp(0.5 * stddev_tilde) * noise) - return z - - def encode( - self, - X: torch.Tensor, - missing_mask: torch.Tensor, - ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - batch_size = X.size(0) - - X_imputed = self.implicit_imputation_layer(X, missing_mask) - - hidden_state = torch.zeros( - (batch_size, self.d_rnn_hidden), dtype=X.dtype, device=X.device - ) - cell_state = torch.zeros( - (batch_size, self.d_rnn_hidden), dtype=X.dtype, device=X.device - ) - # cell_state_collector = torch.empty((batch_size, self.n_steps, self.d_rnn_hidden), - # dtype=X.dtype, device=X.device) - for i in range(self.n_steps): - x = X_imputed[:, 
i, :] - hidden_state, cell_state = self.encoder(x, (hidden_state, cell_state)) - # cell_state_collector[:, i, :] = cell_state - - cell_state_collector = self.ae_encode_layers(cell_state) - mu_tilde = self.mu_layer(cell_state_collector) - stddev_tilde = self.stddev_layer(cell_state_collector) - z = self.z_sampling(mu_tilde, stddev_tilde) - return z, mu_tilde, stddev_tilde - - def decode(self, z: torch.Tensor) -> torch.Tensor: - hidden_state = z - hidden_state = self.ae_decode_layers(hidden_state) - - cell_state = torch.zeros(hidden_state.size(), dtype=z.dtype, device=z.device) - inputs = torch.zeros( - (z.size(0), self.n_steps, self.d_input), dtype=z.dtype, device=z.device - ) - - hidden_state_collector = torch.empty( - (z.size(0), self.n_steps, self.d_rnn_hidden), dtype=z.dtype, device=z.device - ) - for i in range(self.n_steps): - x = inputs[:, i, :] - hidden_state, cell_state = self.decoder(x, (hidden_state, cell_state)) - hidden_state_collector[:, i, :] = hidden_state - - reconstruction = self.rnn_transform_layer(hidden_state_collector) - return reconstruction - - def get_results( - self, X: torch.Tensor, missing_mask: torch.Tensor - ) -> Tuple[ - torch.Tensor, - torch.Tensor, - torch.Tensor, - torch.Tensor, - torch.Tensor, - torch.Tensor, - torch.Tensor, - ]: - z, mu_tilde, stddev_tilde = self.encode(X, missing_mask) - X_reconstructed = self.decode(z) - mu_c, var_c, phi_c = self.gmm_layer() - return X_reconstructed, mu_c, var_c, phi_c, z, mu_tilde, stddev_tilde - - def forward( - self, - inputs: dict, - pretrain: bool = False, - training: bool = True, - ) -> dict: - X, missing_mask = inputs["X"], inputs["missing_mask"] - device = X.device - - ( - X_reconstructed, - mu_c, - var_c, - phi_c, - z, - mu_tilde, - stddev_tilde, - ) = self.get_results(X, missing_mask) - imputed_X = X_reconstructed * (1 - missing_mask) + X * missing_mask - - if not training and not pretrain: - results = { - "mu_tilde": mu_tilde, - "stddev_tilde": stddev_tilde, - "mu": mu_c, - "var": 
var_c, - "phi": phi_c, - "z": z, - "imputed_X": imputed_X, - } - # if only run clustering, then no need to calculate loss - return results - - # calculate the reconstruction loss - unscaled_reconstruction_loss = cal_mse(X_reconstructed, X, missing_mask) - reconstruction_loss = ( - unscaled_reconstruction_loss - * self.n_steps - * self.d_input - / missing_mask.sum() - ) - if pretrain: - results = {"loss": reconstruction_loss, "z": z} - return results - - # calculate the latent loss - var_tilde = torch.exp(stddev_tilde) - stddev_c = torch.log(var_c + self.eps) - log_2pi = torch.log(torch.tensor([2 * torch.pi], device=device)) - log_phi_c = torch.log(phi_c + self.eps) - - batch_size = z.shape[0] - - ii, jj = torch.meshgrid( - torch.arange(self.n_clusters, dtype=torch.int64, device=device), - torch.arange(batch_size, dtype=torch.int64, device=device), - indexing="ij", - ) - ii = ii.flatten() - jj = jj.flatten() - - lsc_b = stddev_c.index_select(dim=0, index=ii) - mc_b = mu_c.index_select(dim=0, index=ii) - sc_b = var_c.index_select(dim=0, index=ii) - z_b = z.index_select(dim=0, index=jj) - log_pdf_z = -0.5 * (lsc_b + log_2pi + torch.square(z_b - mc_b) / sc_b) - log_pdf_z = log_pdf_z.reshape([batch_size, self.n_clusters, self.d_mu_stddev]) - - log_p = log_phi_c + log_pdf_z.sum(dim=2) - lse_p = log_p.logsumexp(dim=1, keepdim=True) - log_gamma_c = log_p - lse_p - gamma_c = torch.exp(log_gamma_c) - - term1 = torch.log(var_c + self.eps) - st_b = var_tilde.index_select(dim=0, index=jj) - sc_b = var_c.index_select(dim=0, index=ii) - term2 = torch.reshape( - st_b / (sc_b + self.eps), [batch_size, self.n_clusters, self.d_mu_stddev] - ) - mt_b = mu_tilde.index_select(dim=0, index=jj) - mc_b = mu_c.index_select(dim=0, index=ii) - term3 = torch.reshape( - torch.square(mt_b - mc_b) / (sc_b + self.eps), - [batch_size, self.n_clusters, self.d_mu_stddev], - ) - - latent_loss1 = 0.5 * torch.sum( - gamma_c * torch.sum(term1 + term2 + term3, dim=2), dim=1 - ) - latent_loss2 = 
-torch.sum(gamma_c * (log_phi_c - log_gamma_c), dim=1) - latent_loss3 = -0.5 * torch.sum(1 + stddev_tilde, dim=1) - - latent_loss1 = latent_loss1.mean() - latent_loss2 = latent_loss2.mean() - latent_loss3 = latent_loss3.mean() - latent_loss = latent_loss1 + latent_loss2 + latent_loss3 - - results = { - "loss": reconstruction_loss + self.alpha * latent_loss, - "z": z, - "imputed_X": imputed_X, - } - - return results class VaDER(BaseNNClusterer): @@ -624,7 +381,7 @@ def predict( var_collector = [] phi_collector = [] z_collector = [] - imputed_X_collector = [] + imputation_latent_collector = [] clustering_results_collector = [] with torch.no_grad(): @@ -644,8 +401,8 @@ def predict( phi_collector.append(phi) z = results["z"].cpu().numpy() z_collector.append(z) - imputed_X = results["imputed_X"].cpu().numpy() - imputed_X_collector.append(imputed_X) + imputation_latent = results["imputation_latent"].cpu().numpy() + imputation_latent_collector.append(imputation_latent) def func_to_apply( mu_t_: np.ndarray, @@ -676,7 +433,7 @@ def func_to_apply( "var": np.concatenate(var_collector), "phi": np.concatenate(phi_collector), "z": np.concatenate(z_collector), - "imputation": np.concatenate(imputed_X_collector), + "imputation_latent": np.concatenate(imputation_latent_collector), } result_dict = { @@ -695,7 +452,7 @@ def cluster( return_latent: bool = False, ) -> Union[np.ndarray, Tuple[np.ndarray, dict]]: logger.warning( - "🚨DeprecationWarning: The method impute is deprecated. Please use `predict` instead." + "🚨DeprecationWarning: The method cluster is deprecated. Please use `predict` instead." 
) result_dict = self.predict(X, file_type, return_latent) diff --git a/pypots/clustering/vader/modules/__init__.py b/pypots/clustering/vader/modules/__init__.py new file mode 100644 index 00000000..e7a47e0e --- /dev/null +++ b/pypots/clustering/vader/modules/__init__.py @@ -0,0 +1,14 @@ +""" + +""" + +# Created by Wenjie Du +# License: GLP-v3 + +from .core import _VaDER +from .submodules import inverse_softplus + +__all__ = [ + "_VaDER", + "inverse_softplus", +] diff --git a/pypots/clustering/vader/modules/core.py b/pypots/clustering/vader/modules/core.py new file mode 100644 index 00000000..1e33ba54 --- /dev/null +++ b/pypots/clustering/vader/modules/core.py @@ -0,0 +1,260 @@ +""" +The implementation of VaDER for the partially-observed time-series clustering task. + +Refer to the paper "Jong, J.D., Emon, M.A., Wu, P., Karki, R., Sood, M., Godard, P., Ahmad, A., Vrooman, H.A., +Hofmann-Apitius, M., & Fröhlich, H. (2019). +Deep learning for clustering of multivariate clinical patient trajectories with missing values. GigaScience." + +""" + +# Created by Wenjie Du +# License: GLP-v3 + + +from typing import Tuple + +import torch +import torch.nn as nn + +from .submodules import ( + GMMLayer, + PeepholeLSTMCell, + ImplicitImputation, +) +from ....utils.metrics import cal_mse + + +class _VaDER(nn.Module): + """ + + Parameters + ---------- + n_steps : + d_input : + n_clusters : + d_rnn_hidden : + d_mu_stddev : + eps : + alpha : + Weight of the latent loss. 
+ The final loss = `alpha`*latent loss + reconstruction loss + + + Attributes + ---------- + + """ + + def __init__( + self, + n_steps: int, + d_input: int, + n_clusters: int, + d_rnn_hidden: int, + d_mu_stddev: int, + eps: float = 1e-9, + alpha: float = 1.0, + ): + super().__init__() + self.n_steps = n_steps + self.d_input = d_input + self.n_clusters = n_clusters + self.d_rnn_hidden = d_rnn_hidden + self.d_mu_stddev = d_mu_stddev + self.eps = eps + self.alpha = alpha + + # building model components + self.implicit_imputation_layer = ImplicitImputation(d_input) + self.encoder = PeepholeLSTMCell(d_input, d_rnn_hidden) + self.decoder = PeepholeLSTMCell(d_input, d_rnn_hidden) + self.ae_encode_layers = nn.Sequential( + nn.Linear(d_rnn_hidden, d_rnn_hidden), nn.Softplus() + ) + self.ae_decode_layers = nn.Sequential( + nn.Linear(d_mu_stddev, d_rnn_hidden), nn.Softplus() + ) + self.mu_layer = nn.Linear(d_rnn_hidden, d_mu_stddev) # layer for mean + self.stddev_layer = nn.Linear( + d_rnn_hidden, d_mu_stddev + ) # layer for standard variance + self.rnn_transform_layer = nn.Linear(d_rnn_hidden, d_input) + self.gmm_layer = GMMLayer(d_mu_stddev, n_clusters) + + @staticmethod + def z_sampling( + mu_tilde: torch.Tensor, + stddev_tilde: torch.Tensor, + ) -> torch.Tensor: + noise = mu_tilde.data.new(mu_tilde.size()).normal_() + z = torch.add(mu_tilde, torch.exp(0.5 * stddev_tilde) * noise) + return z + + def encode( + self, + X: torch.Tensor, + missing_mask: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + batch_size = X.size(0) + + X_imputed = self.implicit_imputation_layer(X, missing_mask) + + hidden_state = torch.zeros( + (batch_size, self.d_rnn_hidden), dtype=X.dtype, device=X.device + ) + cell_state = torch.zeros( + (batch_size, self.d_rnn_hidden), dtype=X.dtype, device=X.device + ) + # cell_state_collector = torch.empty((batch_size, self.n_steps, self.d_rnn_hidden), + # dtype=X.dtype, device=X.device) + for i in range(self.n_steps): + x = X_imputed[:, 
i, :] + hidden_state, cell_state = self.encoder(x, (hidden_state, cell_state)) + # cell_state_collector[:, i, :] = cell_state + + cell_state_collector = self.ae_encode_layers(cell_state) + mu_tilde = self.mu_layer(cell_state_collector) + stddev_tilde = self.stddev_layer(cell_state_collector) + z = self.z_sampling(mu_tilde, stddev_tilde) + return z, mu_tilde, stddev_tilde + + def decode(self, z: torch.Tensor) -> torch.Tensor: + hidden_state = z + hidden_state = self.ae_decode_layers(hidden_state) + + cell_state = torch.zeros(hidden_state.size(), dtype=z.dtype, device=z.device) + inputs = torch.zeros( + (z.size(0), self.n_steps, self.d_input), dtype=z.dtype, device=z.device + ) + + hidden_state_collector = torch.empty( + (z.size(0), self.n_steps, self.d_rnn_hidden), dtype=z.dtype, device=z.device + ) + for i in range(self.n_steps): + x = inputs[:, i, :] + hidden_state, cell_state = self.decoder(x, (hidden_state, cell_state)) + hidden_state_collector[:, i, :] = hidden_state + + reconstruction = self.rnn_transform_layer(hidden_state_collector) + return reconstruction + + def get_results( + self, X: torch.Tensor, missing_mask: torch.Tensor + ) -> Tuple[ + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + ]: + z, mu_tilde, stddev_tilde = self.encode(X, missing_mask) + X_reconstructed = self.decode(z) + mu_c, var_c, phi_c = self.gmm_layer() + return X_reconstructed, mu_c, var_c, phi_c, z, mu_tilde, stddev_tilde + + def forward( + self, + inputs: dict, + pretrain: bool = False, + training: bool = True, + ) -> dict: + X, missing_mask = inputs["X"], inputs["missing_mask"] + device = X.device + + ( + X_reconstructed, + mu_c, + var_c, + phi_c, + z, + mu_tilde, + stddev_tilde, + ) = self.get_results(X, missing_mask) + imputed_X = X_reconstructed * (1 - missing_mask) + X * missing_mask + + if not training and not pretrain: + results = { + "mu_tilde": mu_tilde, + "stddev_tilde": stddev_tilde, + "mu": mu_c, + "var": 
var_c, + "phi": phi_c, + "z": z, + "imputed_X": imputed_X, + } + # if only run clustering, then no need to calculate loss + return results + + # calculate the reconstruction loss + unscaled_reconstruction_loss = cal_mse(X_reconstructed, X, missing_mask) + reconstruction_loss = ( + unscaled_reconstruction_loss + * self.n_steps + * self.d_input + / missing_mask.sum() + ) + if pretrain: + results = {"loss": reconstruction_loss, "z": z} + return results + + # calculate the latent loss + var_tilde = torch.exp(stddev_tilde) + stddev_c = torch.log(var_c + self.eps) + log_2pi = torch.log(torch.tensor([2 * torch.pi], device=device)) + log_phi_c = torch.log(phi_c + self.eps) + + batch_size = z.shape[0] + + ii, jj = torch.meshgrid( + torch.arange(self.n_clusters, dtype=torch.int64, device=device), + torch.arange(batch_size, dtype=torch.int64, device=device), + indexing="ij", + ) + ii = ii.flatten() + jj = jj.flatten() + + lsc_b = stddev_c.index_select(dim=0, index=ii) + mc_b = mu_c.index_select(dim=0, index=ii) + sc_b = var_c.index_select(dim=0, index=ii) + z_b = z.index_select(dim=0, index=jj) + log_pdf_z = -0.5 * (lsc_b + log_2pi + torch.square(z_b - mc_b) / sc_b) + log_pdf_z = log_pdf_z.reshape([batch_size, self.n_clusters, self.d_mu_stddev]) + + log_p = log_phi_c + log_pdf_z.sum(dim=2) + lse_p = log_p.logsumexp(dim=1, keepdim=True) + log_gamma_c = log_p - lse_p + gamma_c = torch.exp(log_gamma_c) + + term1 = torch.log(var_c + self.eps) + st_b = var_tilde.index_select(dim=0, index=jj) + sc_b = var_c.index_select(dim=0, index=ii) + term2 = torch.reshape( + st_b / (sc_b + self.eps), [batch_size, self.n_clusters, self.d_mu_stddev] + ) + mt_b = mu_tilde.index_select(dim=0, index=jj) + mc_b = mu_c.index_select(dim=0, index=ii) + term3 = torch.reshape( + torch.square(mt_b - mc_b) / (sc_b + self.eps), + [batch_size, self.n_clusters, self.d_mu_stddev], + ) + + latent_loss1 = 0.5 * torch.sum( + gamma_c * torch.sum(term1 + term2 + term3, dim=2), dim=1 + ) + latent_loss2 = 
-torch.sum(gamma_c * (log_phi_c - log_gamma_c), dim=1) + latent_loss3 = -0.5 * torch.sum(1 + stddev_tilde, dim=1) + + latent_loss1 = latent_loss1.mean() + latent_loss2 = latent_loss2.mean() + latent_loss3 = latent_loss3.mean() + latent_loss = latent_loss1 + latent_loss2 + latent_loss3 + + results = { + "loss": reconstruction_loss + self.alpha * latent_loss, + "z": z, + "imputation_latent": X_reconstructed, + } + + return results diff --git a/pypots/clustering/vader/modules.py b/pypots/clustering/vader/modules/submodules.py similarity index 100% rename from pypots/clustering/vader/modules.py rename to pypots/clustering/vader/modules/submodules.py diff --git a/pypots/data/generating.py b/pypots/data/generating.py index f0a20473..4b462c2c 100644 --- a/pypots/data/generating.py +++ b/pypots/data/generating.py @@ -350,6 +350,8 @@ def gene_physionet2012(artificially_missing_rate: float = 0.1): dataset = load_specific_dataset("physionet_2012") X = dataset["X"] y = dataset["y"] + ICUType = dataset["ICUType"] + all_recordID = X["RecordID"].unique() train_set_ids, test_set_ids = train_test_split(all_recordID, test_size=0.2) train_set_ids, val_set_ids = train_test_split(train_set_ids, test_size=0.2) @@ -385,16 +387,28 @@ def gene_physionet2012(artificially_missing_rate: float = 0.1): test_y = y[y.index.isin(test_set_ids)].sort_index() train_y, val_y, test_y = train_y.to_numpy(), val_y.to_numpy(), test_y.to_numpy() + train_ICUType = ICUType[ICUType.index.isin(train_set_ids)].sort_index() + val_ICUType = ICUType[ICUType.index.isin(val_set_ids)].sort_index() + test_ICUType = ICUType[ICUType.index.isin(test_set_ids)].sort_index() + train_ICUType, val_ICUType, test_ICUType = ( + train_ICUType.to_numpy(), + val_ICUType.to_numpy(), + test_ICUType.to_numpy(), + ) + data = { "n_classes": 2, "n_steps": 48, "n_features": train_X.shape[-1], "train_X": train_X, "train_y": train_y.flatten(), + "train_ICUType": train_ICUType.flatten(), "val_X": val_X, "val_y": val_y.flatten(), + 
"val_ICUType": val_ICUType.flatten(), "test_X": test_X, "test_y": test_y.flatten(), + "test_ICUType": test_ICUType.flatten(), "scaler": scaler, } diff --git a/pypots/data/load_preprocessing.py b/pypots/data/load_preprocessing.py index 0968ab5b..789cc0ce 100644 --- a/pypots/data/load_preprocessing.py +++ b/pypots/data/load_preprocessing.py @@ -25,7 +25,8 @@ def preprocess_physionet2012(data: dict) -> dict: y : pandas.Series The 11988 classification labels of all patients, indicating whether they were deceased. """ - # remove the static features, e.g. age, gender + data["static_features"].remove("ICUType") # keep ICUType for now + # remove the other static features, e.g. age, gender X = data["X"].drop(data["static_features"], axis=1) def apply_func(df_temp): # pad and truncate to set the max length of samples as 48 @@ -41,11 +42,13 @@ def apply_func(df_temp): # pad and truncate to set the max length of samples as X = X.groupby("RecordID").apply(apply_func) X = X.drop("RecordID", axis=1) X = X.reset_index() - X = X.drop(["level_1"], axis=1) + ICUType = X[["RecordID", "ICUType"]].set_index("RecordID").dropna() + X = X.drop(["level_1", "ICUType"], axis=1) dataset = { "X": X, "y": data["y"], + "ICUType": ICUType, } return dataset diff --git a/pypots/data/utils.py b/pypots/data/utils.py index 9d859be5..662579e3 100644 --- a/pypots/data/utils.py +++ b/pypots/data/utils.py @@ -194,7 +194,7 @@ def cal_delta_for_single_sample(mask: np.ndarray) -> np.ndarray: return delta -def sliding_window(time_series, n_steps, sliding_len=None): +def sliding_window(time_series, window_len, sliding_len=None): """Generate time series samples with sliding window method, truncating windows from time-series data with a given sequence length. @@ -208,11 +208,11 @@ def sliding_window(time_series, n_steps, sliding_len=None): time_series : np.ndarray, time series data, len(shape)=2, [total_length, feature_num] - n_steps : int, - The number of time steps in the generated data samples. 
+ window_len : int, + The length of the sliding window, i.e. the number of time steps in the generated data samples. sliding_len : int, default = None, - The size of the sliding window. It will be set as the same with n_steps if None. + The sliding length of the window for each moving step. It will be set to window_len if None. Returns ------- @@ -220,17 +220,17 @@ def sliding_window(time_series, n_steps, sliding_len=None): The generated time-series data samples of shape [seq_len//sliding_len, n_steps, n_features]. """ - sliding_len = n_steps if sliding_len is None else sliding_len + sliding_len = window_len if sliding_len is None else sliding_len total_len = time_series.shape[0] start_indices = np.asarray(range(total_len // sliding_len)) * sliding_len # remove the last one if left length is not enough - if total_len - start_indices[-1] * sliding_len < n_steps: + if total_len - start_indices[-1] * sliding_len < window_len: start_indices = start_indices[:-1] sample_collector = [] for idx in start_indices: - sample_collector.append(time_series[idx : idx + n_steps]) + sample_collector.append(time_series[idx : idx + window_len]) samples = np.asarray(sample_collector).astype("float32") diff --git a/pypots/forecasting/base.py b/pypots/forecasting/base.py index 7189dabb..493a3b56 100644 --- a/pypots/forecasting/base.py +++ b/pypots/forecasting/base.py @@ -369,6 +369,14 @@ def fit( """ raise NotImplementedError + @abstractmethod + def predict( + self, + test_set: Union[dict, str], + file_type: str = "h5py", + ) -> dict: + raise NotImplementedError + @abstractmethod def forecast( self, diff --git a/pypots/forecasting/bttf/model.py b/pypots/forecasting/bttf/model.py index 3f5aa9e6..fb25029e 100644 --- a/pypots/forecasting/bttf/model.py +++ b/pypots/forecasting/bttf/model.py @@ -20,255 +20,12 @@ import numpy as np import torch -from numpy.linalg import inv as inv -from numpy.linalg import solve as solve -from scipy.linalg import khatri_rao as kr_prod -from .modules
import ( - mvnrnd_pre, - ten2mat, - sample_factor_u, - sample_factor_v, - sample_factor_x, - sample_var_coefficient, - ar4cast, -) +from .modules import BTTF_forecast from ..base import BaseForecaster from ...utils.logging import logger -def _BTTF( - dense_tensor, - sparse_tensor, - init, - rank, - time_lags, - burn_iter, - gibbs_iter, - multi_step=1, -): - """Bayesian Temporal Tensor Factorization, BTTF.""" - - dim1, dim2, dim3 = sparse_tensor.shape - d = time_lags.shape[0] - U = init["U"] - V = init["V"] - X = init["X"] - if not np.isnan(sparse_tensor).any(): - ind = sparse_tensor != 0 - pos_test = np.where((dense_tensor != 0) & (sparse_tensor == 0)) - elif np.isnan(sparse_tensor).any(): - pos_test = np.where((dense_tensor != 0) & (np.isnan(sparse_tensor))) - ind = ~np.isnan(sparse_tensor) - # pos_obs = np.where(ind) - sparse_tensor[np.isnan(sparse_tensor)] = 0 - # dense_test = dense_tensor[pos_test] - del dense_tensor - U_plus = np.zeros((dim1, rank, gibbs_iter)) - V_plus = np.zeros((dim2, rank, gibbs_iter)) - X_plus = np.zeros((dim3 + multi_step, rank, gibbs_iter)) - A_plus = np.zeros((rank * d, rank, gibbs_iter)) - tau_plus = np.zeros(gibbs_iter) - Sigma_plus = np.zeros((rank, rank, gibbs_iter)) - temp_hat = np.zeros(len(pos_test[0])) - show_iter = 500 - tau = 1 - tensor_hat_plus = np.zeros(sparse_tensor.shape) - tensor_new_plus = np.zeros((dim1, dim2, multi_step)) - for it in range(burn_iter + gibbs_iter): - tau_ind = tau * ind - tau_sparse_tensor = tau * sparse_tensor - U = sample_factor_u(tau_sparse_tensor, tau_ind, U, V, X) - V = sample_factor_v(tau_sparse_tensor, tau_ind, U, V, X) - A, Sigma = sample_var_coefficient(X, time_lags) - X = sample_factor_x( - tau_sparse_tensor, tau_ind, time_lags, U, V, X, A, inv(Sigma) - ) - tensor_hat = np.einsum("is, js, ts -> ijt", U, V, X) - tau = np.random.gamma( - 1e-6 + 0.5 * np.sum(ind), - 1 / (1e-6 + 0.5 * np.sum(((sparse_tensor - tensor_hat) ** 2) * ind)), - ) - temp_hat += tensor_hat[pos_test] - if (it + 1) % 
show_iter == 0 and it < burn_iter: - # temp_hat = temp_hat / show_iter - # logger.info('Iter: {}'.format(it + 1)) - # logger.info('MAPE: {:.6}'.format(compute_mape(dense_test, temp_hat))) - # logger.info('RMSE: {:.6}'.format(compute_rmse(dense_test, temp_hat))) - temp_hat = np.zeros(len(pos_test[0])) - if it + 1 > burn_iter: - U_plus[:, :, it - burn_iter] = U - V_plus[:, :, it - burn_iter] = V - A_plus[:, :, it - burn_iter] = A - Sigma_plus[:, :, it - burn_iter] = Sigma - tau_plus[it - burn_iter] = tau - tensor_hat_plus += tensor_hat - X0 = ar4cast(A, X, Sigma, time_lags, multi_step) - X_plus[:, :, it - burn_iter] = X0 - tensor_new_plus += np.einsum("is, js, ts -> ijt", U, V, X0[-multi_step:, :]) - tensor_hat = tensor_hat_plus / gibbs_iter - # logger.info('Imputation MAPE: {:.6}'.format(compute_mape(dense_test, tensor_hat[:, :, : dim3][pos_test]))) - # logger.info('Imputation RMSE: {:.6}'.format(compute_rmse(dense_test, tensor_hat[:, :, : dim3][pos_test]))) - tensor_hat = np.append(tensor_hat, tensor_new_plus / gibbs_iter, axis=2) - tensor_hat[tensor_hat < 0] = 0 - - return tensor_hat, U_plus, V_plus, X_plus, A_plus, Sigma_plus, tau_plus - - -def sample_factor_x_partial( - tau_sparse_tensor, tau_ind, time_lags, U, V, X, A, Lambda_x, back_step -): - """Sampling T-by-R factor matrix X.""" - - dim3, rank = X.shape - tmax = np.max(time_lags) - tmin = np.min(time_lags) - d = time_lags.shape[0] - A0 = np.dstack([A] * d) - for k in range(d): - A0[k * rank : (k + 1) * rank, :, k] = 0 - mat0 = Lambda_x @ A.T - mat1 = np.einsum("kij, jt -> kit", A.reshape([d, rank, rank]), Lambda_x) - mat2 = np.einsum("kit, kjt -> ij", mat1, A.reshape([d, rank, rank])) - - var1 = kr_prod(V, U).T - var2 = kr_prod(var1, var1) - var3 = (var2 @ ten2mat(tau_ind[:, :, -back_step:], 2).T).reshape( - [rank, rank, back_step] - ) + Lambda_x[:, :, None] - var4 = var1 @ ten2mat(tau_sparse_tensor[:, :, -back_step:], 2).T - for t in range(dim3 - back_step, dim3): - Mt = np.zeros((rank, rank)) - Nt = 
np.zeros(rank) - Qt = mat0 @ X[t - time_lags, :].reshape(rank * d) - index = list(range(0, d)) - if dim3 - tmax <= t < dim3 - tmin: - index = list(np.where(t + time_lags < dim3))[0] - if t < dim3 - tmin: - Mt = mat2.copy() - temp = np.zeros((rank * d, len(index))) - n = 0 - for k in index: - temp[:, n] = X[t + time_lags[k] - time_lags, :].reshape(rank * d) - n += 1 - temp0 = X[t + time_lags[index], :].T - np.einsum( - "ijk, ik -> jk", A0[:, :, index], temp - ) - Nt = np.einsum("kij, jk -> i", mat1[index, :, :], temp0) - var3[:, :, t + back_step - dim3] = var3[:, :, t + back_step - dim3] + Mt - X[t, :] = mvnrnd_pre( - solve( - var3[:, :, t + back_step - dim3], - var4[:, t + back_step - dim3] + Nt + Qt, - ), - var3[:, :, t + back_step - dim3], - ) - return X - - -def _BTTF_partial( - sparse_tensor, init, rank, time_lags, gibbs_iter, multi_step=1, gamma=10 -): - """Bayesian Temporal Tensor Factorization, BTTF.""" - - dim1, dim2, dim3 = sparse_tensor.shape - U_plus = init["U_plus"] - V_plus = init["V_plus"] - X_plus = init["X_plus"] - A_plus = init["A_plus"] - Sigma_plus = init["Sigma_plus"] - tau_plus = init["tau_plus"] - if not np.isnan(sparse_tensor).any(): - ind = sparse_tensor != 0 - elif np.isnan(sparse_tensor).any(): - ind = ~np.isnan(sparse_tensor) - sparse_tensor[np.isnan(sparse_tensor)] = 0 - X_new_plus = np.zeros((dim3 + multi_step, rank, gibbs_iter)) - tensor_new_plus = np.zeros((dim1, dim2, multi_step)) - back_step = gamma * multi_step - for it in range(gibbs_iter): - tau_ind = tau_plus[it] * ind - tau_sparse_tensor = tau_plus[it] * sparse_tensor - X = sample_factor_x_partial( - tau_sparse_tensor, - tau_ind, - time_lags, - U_plus[:, :, it], - V_plus[:, :, it], - X_plus[:, :, it], - A_plus[:, :, it], - inv(Sigma_plus[:, :, it]), - back_step, - ) - X0 = ar4cast(A_plus[:, :, it], X, Sigma_plus[:, :, it], time_lags, multi_step) - X_new_plus[:, :, it] = X0 - tensor_new_plus += np.einsum( - "is, js, ts -> ijt", U_plus[:, :, it], V_plus[:, :, it], 
X0[-multi_step:, :] - ) - tensor_hat = tensor_new_plus / gibbs_iter - tensor_hat[tensor_hat < 0] = 0 - - return tensor_hat, U_plus, V_plus, X_new_plus, A_plus, Sigma_plus, tau_plus - - -def BTTF_forecast( - dense_tensor, - sparse_tensor, - pred_step, - multi_step, - rank, - time_lags, - burn_iter, - gibbs_iter, - gamma=10, -): - dim1, dim2, T = dense_tensor.shape - start_time = T - pred_step - max_count = int(np.ceil(pred_step / multi_step)) - tensor_hat = np.zeros((dim1, dim2, max_count * multi_step)) - - # t==0 - init = { - "U": 0.1 * np.random.randn(dim1, rank), - "V": 0.1 * np.random.randn(dim2, rank), - "X": 0.1 * np.random.randn(start_time, rank), - } - tensor, U, V, X_new, A, Sigma, tau = _BTTF( - dense_tensor[:, :, :start_time], - sparse_tensor[:, :, :start_time], - init, - rank, - time_lags, - burn_iter, - gibbs_iter, - multi_step, - ) - tensor_hat[:, :, 0:multi_step] = tensor[:, :, -multi_step:] - # 1<= t `_ + `_ """ @@ -393,7 +150,7 @@ def forecast( file_type: str = "h5py", ) -> np.ndarray: logger.warning( - "🚨DeprecationWarning: The method impute is deprecated. Please use `predict` instead." + "🚨DeprecationWarning: The method forecast is deprecated. Please use `predict` instead." 
) result_dict = self.predict(X, file_type=file_type) forecasting = result_dict["forecasting"] diff --git a/pypots/forecasting/bttf/modules/__init__.py b/pypots/forecasting/bttf/modules/__init__.py new file mode 100644 index 00000000..b0fd2d2e --- /dev/null +++ b/pypots/forecasting/bttf/modules/__init__.py @@ -0,0 +1,12 @@ +""" + +""" + +# Created by Wenjie Du +# License: GPL-v3 + +from .core import BTTF_forecast + +__all__ = [ + "BTTF_forecast", +] diff --git a/pypots/forecasting/bttf/modules/core.py b/pypots/forecasting/bttf/modules/core.py new file mode 100644 index 00000000..63c41744 --- /dev/null +++ b/pypots/forecasting/bttf/modules/core.py @@ -0,0 +1,263 @@ +""" +The implementation of BTTF (Bayesian Temporal Tensor Factorization) for the partially-observed time-series +forecasting task. + +Refer to the paper "Chen, X., & Sun, L. (2021). +Bayesian Temporal Factorization for Multidimensional Time Series Prediction. +IEEE transactions on pattern analysis and machine intelligence." + +Notes +----- +This numpy implementation is the same as the official one from https://github.com/xinychen/transdim.
+ +""" + +# Created by Wenjie Du +# License: GPL-v3 + +import numpy as np +from numpy.linalg import inv as inv +from numpy.linalg import solve as solve +from scipy.linalg import khatri_rao as kr_prod + +from .submodules import ( + mvnrnd_pre, + ten2mat, + sample_factor_u, + sample_factor_v, + sample_factor_x, + sample_var_coefficient, + ar4cast, +) + + +def _BTTF( + dense_tensor, + sparse_tensor, + init, + rank, + time_lags, + burn_iter, + gibbs_iter, + multi_step=1, +): + """Bayesian Temporal Tensor Factorization, BTTF.""" + + dim1, dim2, dim3 = sparse_tensor.shape + d = time_lags.shape[0] + U = init["U"] + V = init["V"] + X = init["X"] + if not np.isnan(sparse_tensor).any(): + ind = sparse_tensor != 0 + pos_test = np.where((dense_tensor != 0) & (sparse_tensor == 0)) + elif np.isnan(sparse_tensor).any(): + pos_test = np.where((dense_tensor != 0) & (np.isnan(sparse_tensor))) + ind = ~np.isnan(sparse_tensor) + # pos_obs = np.where(ind) + sparse_tensor[np.isnan(sparse_tensor)] = 0 + # dense_test = dense_tensor[pos_test] + del dense_tensor + U_plus = np.zeros((dim1, rank, gibbs_iter)) + V_plus = np.zeros((dim2, rank, gibbs_iter)) + X_plus = np.zeros((dim3 + multi_step, rank, gibbs_iter)) + A_plus = np.zeros((rank * d, rank, gibbs_iter)) + tau_plus = np.zeros(gibbs_iter) + Sigma_plus = np.zeros((rank, rank, gibbs_iter)) + temp_hat = np.zeros(len(pos_test[0])) + show_iter = 500 + tau = 1 + tensor_hat_plus = np.zeros(sparse_tensor.shape) + tensor_new_plus = np.zeros((dim1, dim2, multi_step)) + for it in range(burn_iter + gibbs_iter): + tau_ind = tau * ind + tau_sparse_tensor = tau * sparse_tensor + U = sample_factor_u(tau_sparse_tensor, tau_ind, U, V, X) + V = sample_factor_v(tau_sparse_tensor, tau_ind, U, V, X) + A, Sigma = sample_var_coefficient(X, time_lags) + X = sample_factor_x( + tau_sparse_tensor, tau_ind, time_lags, U, V, X, A, inv(Sigma) + ) + tensor_hat = np.einsum("is, js, ts -> ijt", U, V, X) + tau = np.random.gamma( + 1e-6 + 0.5 * np.sum(ind), + 1 / (1e-6 +
0.5 * np.sum(((sparse_tensor - tensor_hat) ** 2) * ind)), + ) + temp_hat += tensor_hat[pos_test] + if (it + 1) % show_iter == 0 and it < burn_iter: + # temp_hat = temp_hat / show_iter + # logger.info('Iter: {}'.format(it + 1)) + # logger.info('MAPE: {:.6}'.format(compute_mape(dense_test, temp_hat))) + # logger.info('RMSE: {:.6}'.format(compute_rmse(dense_test, temp_hat))) + temp_hat = np.zeros(len(pos_test[0])) + if it + 1 > burn_iter: + U_plus[:, :, it - burn_iter] = U + V_plus[:, :, it - burn_iter] = V + A_plus[:, :, it - burn_iter] = A + Sigma_plus[:, :, it - burn_iter] = Sigma + tau_plus[it - burn_iter] = tau + tensor_hat_plus += tensor_hat + X0 = ar4cast(A, X, Sigma, time_lags, multi_step) + X_plus[:, :, it - burn_iter] = X0 + tensor_new_plus += np.einsum("is, js, ts -> ijt", U, V, X0[-multi_step:, :]) + tensor_hat = tensor_hat_plus / gibbs_iter + # logger.info('Imputation MAPE: {:.6}'.format(compute_mape(dense_test, tensor_hat[:, :, : dim3][pos_test]))) + # logger.info('Imputation RMSE: {:.6}'.format(compute_rmse(dense_test, tensor_hat[:, :, : dim3][pos_test]))) + tensor_hat = np.append(tensor_hat, tensor_new_plus / gibbs_iter, axis=2) + tensor_hat[tensor_hat < 0] = 0 + + return tensor_hat, U_plus, V_plus, X_plus, A_plus, Sigma_plus, tau_plus + + +def sample_factor_x_partial( + tau_sparse_tensor, tau_ind, time_lags, U, V, X, A, Lambda_x, back_step +): + """Sampling T-by-R factor matrix X.""" + + dim3, rank = X.shape + tmax = np.max(time_lags) + tmin = np.min(time_lags) + d = time_lags.shape[0] + A0 = np.dstack([A] * d) + for k in range(d): + A0[k * rank : (k + 1) * rank, :, k] = 0 + mat0 = Lambda_x @ A.T + mat1 = np.einsum("kij, jt -> kit", A.reshape([d, rank, rank]), Lambda_x) + mat2 = np.einsum("kit, kjt -> ij", mat1, A.reshape([d, rank, rank])) + + var1 = kr_prod(V, U).T + var2 = kr_prod(var1, var1) + var3 = (var2 @ ten2mat(tau_ind[:, :, -back_step:], 2).T).reshape( + [rank, rank, back_step] + ) + Lambda_x[:, :, None] + var4 = var1 @ 
ten2mat(tau_sparse_tensor[:, :, -back_step:], 2).T + for t in range(dim3 - back_step, dim3): + Mt = np.zeros((rank, rank)) + Nt = np.zeros(rank) + Qt = mat0 @ X[t - time_lags, :].reshape(rank * d) + index = list(range(0, d)) + if dim3 - tmax <= t < dim3 - tmin: + index = list(np.where(t + time_lags < dim3))[0] + if t < dim3 - tmin: + Mt = mat2.copy() + temp = np.zeros((rank * d, len(index))) + n = 0 + for k in index: + temp[:, n] = X[t + time_lags[k] - time_lags, :].reshape(rank * d) + n += 1 + temp0 = X[t + time_lags[index], :].T - np.einsum( + "ijk, ik -> jk", A0[:, :, index], temp + ) + Nt = np.einsum("kij, jk -> i", mat1[index, :, :], temp0) + var3[:, :, t + back_step - dim3] = var3[:, :, t + back_step - dim3] + Mt + X[t, :] = mvnrnd_pre( + solve( + var3[:, :, t + back_step - dim3], + var4[:, t + back_step - dim3] + Nt + Qt, + ), + var3[:, :, t + back_step - dim3], + ) + return X + + +def _BTTF_partial( + sparse_tensor, init, rank, time_lags, gibbs_iter, multi_step=1, gamma=10 +): + """Bayesian Temporal Tensor Factorization, BTTF.""" + + dim1, dim2, dim3 = sparse_tensor.shape + U_plus = init["U_plus"] + V_plus = init["V_plus"] + X_plus = init["X_plus"] + A_plus = init["A_plus"] + Sigma_plus = init["Sigma_plus"] + tau_plus = init["tau_plus"] + if not np.isnan(sparse_tensor).any(): + ind = sparse_tensor != 0 + elif np.isnan(sparse_tensor).any(): + ind = ~np.isnan(sparse_tensor) + sparse_tensor[np.isnan(sparse_tensor)] = 0 + X_new_plus = np.zeros((dim3 + multi_step, rank, gibbs_iter)) + tensor_new_plus = np.zeros((dim1, dim2, multi_step)) + back_step = gamma * multi_step + for it in range(gibbs_iter): + tau_ind = tau_plus[it] * ind + tau_sparse_tensor = tau_plus[it] * sparse_tensor + X = sample_factor_x_partial( + tau_sparse_tensor, + tau_ind, + time_lags, + U_plus[:, :, it], + V_plus[:, :, it], + X_plus[:, :, it], + A_plus[:, :, it], + inv(Sigma_plus[:, :, it]), + back_step, + ) + X0 = ar4cast(A_plus[:, :, it], X, Sigma_plus[:, :, it], time_lags, multi_step) + 
X_new_plus[:, :, it] = X0 + tensor_new_plus += np.einsum( + "is, js, ts -> ijt", U_plus[:, :, it], V_plus[:, :, it], X0[-multi_step:, :] + ) + tensor_hat = tensor_new_plus / gibbs_iter + tensor_hat[tensor_hat < 0] = 0 + + return tensor_hat, U_plus, V_plus, X_new_plus, A_plus, Sigma_plus, tau_plus + + +def BTTF_forecast( + dense_tensor, + sparse_tensor, + pred_step, + multi_step, + rank, + time_lags, + burn_iter, + gibbs_iter, + gamma=10, +): + dim1, dim2, T = dense_tensor.shape + start_time = T - pred_step + max_count = int(np.ceil(pred_step / multi_step)) + tensor_hat = np.zeros((dim1, dim2, max_count * multi_step)) + + # t==0 + init = { + "U": 0.1 * np.random.randn(dim1, rank), + "V": 0.1 * np.random.randn(dim2, rank), + "X": 0.1 * np.random.randn(start_time, rank), + } + tensor, U, V, X_new, A, Sigma, tau = _BTTF( + dense_tensor[:, :, :start_time], + sparse_tensor[:, :, :start_time], + init, + rank, + time_lags, + burn_iter, + gibbs_iter, + multi_step, + ) + tensor_hat[:, :, 0:multi_step] = tensor[:, :, -multi_step:] + # 1<= t dict: + raise NotImplementedError + @abstractmethod def impute( self, diff --git a/pypots/imputation/brits/model.py b/pypots/imputation/brits/model.py index 54a47331..dda114fa 100644 --- a/pypots/imputation/brits/model.py +++ b/pypots/imputation/brits/model.py @@ -14,339 +14,19 @@ # Created by Wenjie Du # License: GPL-v3 -from typing import Tuple, Union, Optional +from typing import Union, Optional import h5py import numpy as np import torch -import torch.nn as nn from torch.utils.data import DataLoader from .data import DatasetForBRITS -from .modules import TemporalDecay, FeatureRegression +from .modules import _BRITS from ..base import BaseNNImputer from ...optim.adam import Adam from ...optim.base import Optimizer from ...utils.logging import logger -from ...utils.metrics import cal_mae - - -class RITS(nn.Module): - """model RITS: Recurrent Imputation for Time Series - - Attributes - ---------- - n_steps : - sequence length (number of 
time steps) - - n_features : - number of features (input dimensions) - - rnn_hidden_size : - the hidden size of the RNN cell - - device : - specify running the model on which device, CPU/GPU - - rnn_cell : - the LSTM cell to model temporal data - - temp_decay_h : - the temporal decay module to decay RNN hidden state - - temp_decay_x : - the temporal decay module to decay data in the raw feature space - - hist_reg : - the temporal-regression module to project RNN hidden state into the raw feature space - - feat_reg : - the feature-regression module - - combining_weight : - the module used to generate the weight to combine history regression and feature regression - - Parameters - ---------- - n_steps : - sequence length (number of time steps) - - n_features : - number of features (input dimensions) - - rnn_hidden_size : - the hidden size of the RNN cell - - device : - specify running the model on which device, CPU/GPU - - """ - - def __init__( - self, - n_steps: int, - n_features: int, - rnn_hidden_size: int, - device: Union[str, torch.device], - ): - super().__init__() - self.n_steps = n_steps - self.n_features = n_features - self.rnn_hidden_size = rnn_hidden_size - self.device = device - - self.rnn_cell = nn.LSTMCell(self.n_features * 2, self.rnn_hidden_size) - self.temp_decay_h = TemporalDecay( - input_size=self.n_features, output_size=self.rnn_hidden_size, diag=False - ) - self.temp_decay_x = TemporalDecay( - input_size=self.n_features, output_size=self.n_features, diag=True - ) - self.hist_reg = nn.Linear(self.rnn_hidden_size, self.n_features) - self.feat_reg = FeatureRegression(self.n_features) - self.combining_weight = nn.Linear(self.n_features * 2, self.n_features) - - def impute( - self, inputs: dict, direction: str - ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - """The imputation function. - Parameters - ---------- - inputs : - Input data, a dictionary includes feature values, missing masks, and time-gap values. 
- - direction : - A keyword to extract data from parameter `data`. - - Returns - ------- - imputed_data : - [batch size, sequence length, feature number] - - hidden_states: tensor, - [batch size, RNN hidden size] - - reconstruction_loss : - reconstruction loss - - """ - values = inputs[direction]["X"] # feature values - masks = inputs[direction]["missing_mask"] # missing masks - deltas = inputs[direction]["deltas"] # time-gap values - - # create hidden states and cell states for the lstm cell - hidden_states = torch.zeros( - (values.size()[0], self.rnn_hidden_size), device=values.device - ) - cell_states = torch.zeros( - (values.size()[0], self.rnn_hidden_size), device=values.device - ) - - estimations = [] - reconstruction_loss = torch.tensor(0.0).to(values.device) - - # imputation period - for t in range(self.n_steps): - # data shape: [batch, time, features] - x = values[:, t, :] # values - m = masks[:, t, :] # mask - d = deltas[:, t, :] # delta, time gap - - gamma_h = self.temp_decay_h(d) - gamma_x = self.temp_decay_x(d) - - hidden_states = hidden_states * gamma_h # decay hidden states - x_h = self.hist_reg(hidden_states) - reconstruction_loss += cal_mae(x_h, x, m) - - x_c = m * x + (1 - m) * x_h - - z_h = self.feat_reg(x_c) - reconstruction_loss += cal_mae(z_h, x, m) - - alpha = torch.sigmoid(self.combining_weight(torch.cat([gamma_x, m], dim=1))) - - c_h = alpha * z_h + (1 - alpha) * x_h - reconstruction_loss += cal_mae(c_h, x, m) - - c_c = m * x + (1 - m) * c_h - estimations.append(c_h.unsqueeze(dim=1)) - - inputs = torch.cat([c_c, m], dim=1) - hidden_states, cell_states = self.rnn_cell( - inputs, (hidden_states, cell_states) - ) - - estimations = torch.cat(estimations, dim=1) - imputed_data = masks * values + (1 - masks) * estimations - return imputed_data, hidden_states, reconstruction_loss - - def forward(self, inputs: dict, direction: str = "forward") -> dict: - """Forward processing of the NN module. - Parameters - ---------- - inputs : - The input data. 
- - direction : - A keyword to extract data from parameter `data`. - - Returns - ------- - dict, - A dictionary includes all results. - - """ - imputed_data, hidden_state, reconstruction_loss = self.impute(inputs, direction) - # for each iteration, reconstruction_loss increases its value for 3 times - reconstruction_loss /= self.n_steps * 3 - - ret_dict = { - "consistency_loss": torch.tensor( - 0.0, device=imputed_data.device - ), # single direction, has no consistency loss - "reconstruction_loss": reconstruction_loss, - "imputed_data": imputed_data, - "final_hidden_state": hidden_state, - } - return ret_dict - - -class _BRITS(nn.Module): - """model BRITS: Bidirectional RITS - BRITS consists of two RITS, which take time-series data from two directions (forward/backward) respectively. - - Attributes - ---------- - n_steps : - sequence length (number of time steps) - - n_features : - number of features (input dimensions) - - rnn_hidden_size : - the hidden size of the RNN cell - - rits_f: RITS object - the forward RITS model - - rits_b: RITS object - the backward RITS model - - """ - - def __init__( - self, - n_steps: int, - n_features: int, - rnn_hidden_size: int, - device: Union[str, torch.device], - ): - super().__init__() - # data settings - self.n_steps = n_steps - self.n_features = n_features - # imputer settings - self.rnn_hidden_size = rnn_hidden_size - # create models - self.rits_f = RITS(n_steps, n_features, rnn_hidden_size, device) - self.rits_b = RITS(n_steps, n_features, rnn_hidden_size, device) - - @staticmethod - def _get_consistency_loss( - pred_f: torch.Tensor, pred_b: torch.Tensor - ) -> torch.Tensor: - """Calculate the consistency loss between the imputation from two RITS models. - - Parameters - ---------- - pred_f : - The imputation from the forward RITS. - - pred_b : - The imputation from the backward RITS (already gets reverted). - - Returns - ------- - float tensor, - The consistency loss. 
- - """ - loss = torch.abs(pred_f - pred_b).mean() * 1e-1 - return loss - - @staticmethod - def _reverse(ret: dict) -> dict: - """Reverse the array values on the time dimension in the given dictionary. - - Parameters - ---------- - ret : - - Returns - ------- - dict, - A dictionary contains values reversed on the time dimension from the given dict. - - """ - - def reverse_tensor(tensor_): - if tensor_.dim() <= 1: - return tensor_ - indices = range(tensor_.size()[1])[::-1] - indices = torch.tensor( - indices, dtype=torch.long, device=tensor_.device, requires_grad=False - ) - return tensor_.index_select(1, indices) - - for key in ret: - ret[key] = reverse_tensor(ret[key]) - - return ret - - def forward(self, inputs: dict, training: bool = True) -> dict: - """Forward processing of BRITS. - - Parameters - ---------- - inputs : - The input data. - - Returns - ------- - dict, A dictionary includes all results. - """ - # Results from the forward RITS. - ret_f = self.rits_f(inputs, "forward") - # Results from the backward RITS. 
- ret_b = self._reverse(self.rits_b(inputs, "backward")) - - imputed_data = (ret_f["imputed_data"] + ret_b["imputed_data"]) / 2 - - if not training: - # if not in training mode, return the classification result only - return { - "imputed_data": imputed_data, - } - - consistency_loss = self._get_consistency_loss( - ret_f["imputed_data"], ret_b["imputed_data"] - ) - - # `loss` is always the item for backward propagating to update the model - loss = ( - consistency_loss - + ret_f["reconstruction_loss"] - + ret_b["reconstruction_loss"] - ) - - results = { - "imputed_data": imputed_data, - "consistency_loss": consistency_loss, - "loss": loss, # will be used for backward propagating to update the model - } - - return results class BRITS(BaseNNImputer): diff --git a/pypots/imputation/brits/modules/__init__.py b/pypots/imputation/brits/modules/__init__.py new file mode 100644 index 00000000..53e8ca2b --- /dev/null +++ b/pypots/imputation/brits/modules/__init__.py @@ -0,0 +1,14 @@ +""" + +""" + +# Created by Wenjie Du +# License: GPL-v3 + +from .core import _BRITS +from .submodules import FeatureRegression + +__all__ = [ + "_BRITS", + "FeatureRegression", +] diff --git a/pypots/imputation/brits/modules/core.py b/pypots/imputation/brits/modules/core.py new file mode 100644 index 00000000..e5c29698 --- /dev/null +++ b/pypots/imputation/brits/modules/core.py @@ -0,0 +1,342 @@ +""" +The implementation of BRITS for the partially-observed time-series imputation task. + +Refer to the paper "Cao, W., Wang, D., Li, J., Zhou, H., Li, L., & Li, Y. (2018). +BRITS: Bidirectional Recurrent Imputation for Time Series. NeurIPS 2018." + +Notes +----- +Partial implementation uses code from https://github.com/caow13/BRITS. The bugs in the original implementation +are fixed here.
+ +""" + +# Created by Wenjie Du +# License: GPL-v3 + +from typing import Tuple, Union + +import torch +import torch.nn as nn + +from .submodules import FeatureRegression +from ....classification.grud.modules import TemporalDecay +from ....utils.metrics import cal_mae + + +class RITS(nn.Module): + """model RITS: Recurrent Imputation for Time Series + + Attributes + ---------- + n_steps : + sequence length (number of time steps) + + n_features : + number of features (input dimensions) + + rnn_hidden_size : + the hidden size of the RNN cell + + device : + specify running the model on which device, CPU/GPU + + rnn_cell : + the LSTM cell to model temporal data + + temp_decay_h : + the temporal decay module to decay RNN hidden state + + temp_decay_x : + the temporal decay module to decay data in the raw feature space + + hist_reg : + the temporal-regression module to project RNN hidden state into the raw feature space + + feat_reg : + the feature-regression module + + combining_weight : + the module used to generate the weight to combine history regression and feature regression + + Parameters + ---------- + n_steps : + sequence length (number of time steps) + + n_features : + number of features (input dimensions) + + rnn_hidden_size : + the hidden size of the RNN cell + + device : + specify running the model on which device, CPU/GPU + + """ + + def __init__( + self, + n_steps: int, + n_features: int, + rnn_hidden_size: int, + device: Union[str, torch.device], + ): + super().__init__() + self.n_steps = n_steps + self.n_features = n_features + self.rnn_hidden_size = rnn_hidden_size + self.device = device + + self.rnn_cell = nn.LSTMCell(self.n_features * 2, self.rnn_hidden_size) + self.temp_decay_h = TemporalDecay( + input_size=self.n_features, output_size=self.rnn_hidden_size, diag=False + ) + self.temp_decay_x = TemporalDecay( + input_size=self.n_features, output_size=self.n_features, diag=True + ) + self.hist_reg = nn.Linear(self.rnn_hidden_size, self.n_features) + 
self.feat_reg = FeatureRegression(self.n_features) + self.combining_weight = nn.Linear(self.n_features * 2, self.n_features) + + def impute( + self, inputs: dict, direction: str + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """The imputation function. + Parameters + ---------- + inputs : + Input data, a dictionary includes feature values, missing masks, and time-gap values. + + direction : + A keyword to extract data from parameter `data`. + + Returns + ------- + imputed_data : + [batch size, sequence length, feature number] + + hidden_states: tensor, + [batch size, RNN hidden size] + + reconstruction_loss : + reconstruction loss + + """ + values = inputs[direction]["X"] # feature values + masks = inputs[direction]["missing_mask"] # missing masks + deltas = inputs[direction]["deltas"] # time-gap values + + # create hidden states and cell states for the lstm cell + hidden_states = torch.zeros( + (values.size()[0], self.rnn_hidden_size), device=values.device + ) + cell_states = torch.zeros( + (values.size()[0], self.rnn_hidden_size), device=values.device + ) + + estimations = [] + reconstruction_loss = torch.tensor(0.0).to(values.device) + + # imputation period + for t in range(self.n_steps): + # data shape: [batch, time, features] + x = values[:, t, :] # values + m = masks[:, t, :] # mask + d = deltas[:, t, :] # delta, time gap + + gamma_h = self.temp_decay_h(d) + gamma_x = self.temp_decay_x(d) + + hidden_states = hidden_states * gamma_h # decay hidden states + x_h = self.hist_reg(hidden_states) + reconstruction_loss += cal_mae(x_h, x, m) + + x_c = m * x + (1 - m) * x_h + + z_h = self.feat_reg(x_c) + reconstruction_loss += cal_mae(z_h, x, m) + + alpha = torch.sigmoid(self.combining_weight(torch.cat([gamma_x, m], dim=1))) + + c_h = alpha * z_h + (1 - alpha) * x_h + reconstruction_loss += cal_mae(c_h, x, m) + + c_c = m * x + (1 - m) * c_h + estimations.append(c_h.unsqueeze(dim=1)) + + inputs = torch.cat([c_c, m], dim=1) + hidden_states, cell_states = 
self.rnn_cell( + inputs, (hidden_states, cell_states) + ) + + estimations = torch.cat(estimations, dim=1) + imputed_data = masks * values + (1 - masks) * estimations + return imputed_data, hidden_states, reconstruction_loss + + def forward(self, inputs: dict, direction: str = "forward") -> dict: + """Forward processing of the NN module. + Parameters + ---------- + inputs : + The input data. + + direction : + A keyword to extract data from parameter `data`. + + Returns + ------- + dict, + A dictionary includes all results. + + """ + imputed_data, hidden_state, reconstruction_loss = self.impute(inputs, direction) + # for each iteration, reconstruction_loss increases its value for 3 times + reconstruction_loss /= self.n_steps * 3 + + ret_dict = { + "consistency_loss": torch.tensor( + 0.0, device=imputed_data.device + ), # single direction, has no consistency loss + "reconstruction_loss": reconstruction_loss, + "imputed_data": imputed_data, + "final_hidden_state": hidden_state, + } + return ret_dict + + +class _BRITS(nn.Module): + """model BRITS: Bidirectional RITS + BRITS consists of two RITS, which take time-series data from two directions (forward/backward) respectively. 
+ + Attributes + ---------- + n_steps : + sequence length (number of time steps) + + n_features : + number of features (input dimensions) + + rnn_hidden_size : + the hidden size of the RNN cell + + rits_f: RITS object + the forward RITS model + + rits_b: RITS object + the backward RITS model + + """ + + def __init__( + self, + n_steps: int, + n_features: int, + rnn_hidden_size: int, + device: Union[str, torch.device], + ): + super().__init__() + # data settings + self.n_steps = n_steps + self.n_features = n_features + # imputer settings + self.rnn_hidden_size = rnn_hidden_size + # create models + self.rits_f = RITS(n_steps, n_features, rnn_hidden_size, device) + self.rits_b = RITS(n_steps, n_features, rnn_hidden_size, device) + + @staticmethod + def _get_consistency_loss( + pred_f: torch.Tensor, pred_b: torch.Tensor + ) -> torch.Tensor: + """Calculate the consistency loss between the imputation from two RITS models. + + Parameters + ---------- + pred_f : + The imputation from the forward RITS. + + pred_b : + The imputation from the backward RITS (already gets reverted). + + Returns + ------- + float tensor, + The consistency loss. + + """ + loss = torch.abs(pred_f - pred_b).mean() * 1e-1 + return loss + + @staticmethod + def _reverse(ret: dict) -> dict: + """Reverse the array values on the time dimension in the given dictionary. + + Parameters + ---------- + ret : + + Returns + ------- + dict, + A dictionary contains values reversed on the time dimension from the given dict. + + """ + + def reverse_tensor(tensor_): + if tensor_.dim() <= 1: + return tensor_ + indices = range(tensor_.size()[1])[::-1] + indices = torch.tensor( + indices, dtype=torch.long, device=tensor_.device, requires_grad=False + ) + return tensor_.index_select(1, indices) + + for key in ret: + ret[key] = reverse_tensor(ret[key]) + + return ret + + def forward(self, inputs: dict, training: bool = True) -> dict: + """Forward processing of BRITS. 
+ + Parameters + ---------- + inputs : + The input data. + + Returns + ------- + dict, A dictionary includes all results. + """ + # Results from the forward RITS. + ret_f = self.rits_f(inputs, "forward") + # Results from the backward RITS. + ret_b = self._reverse(self.rits_b(inputs, "backward")) + + imputed_data = (ret_f["imputed_data"] + ret_b["imputed_data"]) / 2 + + if not training: + # if not in training mode, return the classification result only + return { + "imputed_data": imputed_data, + } + + consistency_loss = self._get_consistency_loss( + ret_f["imputed_data"], ret_b["imputed_data"] + ) + + # `loss` is always the item for backward propagating to update the model + loss = ( + consistency_loss + + ret_f["reconstruction_loss"] + + ret_b["reconstruction_loss"] + ) + + results = { + "imputed_data": imputed_data, + "consistency_loss": consistency_loss, + "loss": loss, # will be used for backward propagating to update the model + } + + return results diff --git a/pypots/imputation/brits/modules.py b/pypots/imputation/brits/modules/submodules.py similarity index 54% rename from pypots/imputation/brits/modules.py rename to pypots/imputation/brits/modules/submodules.py index 8de4d6ac..1058b57f 100644 --- a/pypots/imputation/brits/modules.py +++ b/pypots/imputation/brits/modules/submodules.py @@ -76,65 +76,3 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: """ output = F.linear(x, self.W * Variable(self.m), self.b) return output - - -class TemporalDecay(nn.Module): - """The module used to generate the temporal decay factor gamma in the original paper. - - Attributes - ---------- - W: tensor, - The weights (parameters) of the module. - b: tensor, - The bias of the module. 
- - Parameters - ---------- - input_size : int, - the feature dimension of the input - - output_size : int, - the feature dimension of the output - - diag : bool, - whether to product the weight with an identity matrix before forward processing - """ - - def __init__(self, input_size: int, output_size: int, diag: bool = False): - super().__init__() - self.diag = diag - self.W = Parameter(torch.Tensor(output_size, input_size)) - self.b = Parameter(torch.Tensor(output_size)) - - if self.diag: - assert input_size == output_size - m = torch.eye(input_size, input_size) - self.register_buffer("m", m) - - self._reset_parameters() - - def _reset_parameters(self) -> None: - std_dev = 1.0 / math.sqrt(self.W.size(0)) - self.W.data.uniform_(-std_dev, std_dev) - if self.b is not None: - self.b.data.uniform_(-std_dev, std_dev) - - def forward(self, delta: torch.Tensor) -> torch.Tensor: - """Forward processing of the NN module. - - Parameters - ---------- - delta : tensor, shape [batch size, sequence length, feature number] - The time gaps. - - Returns - ------- - gamma : array-like, same shape with parameter `delta`, values in (0,1] - The temporal decay factor. 
- """ - if self.diag: - gamma = F.relu(F.linear(delta, self.W * Variable(self.m), self.b)) - else: - gamma = F.relu(F.linear(delta, self.W, self.b)) - gamma = torch.exp(-gamma) - return gamma diff --git a/pypots/imputation/gpvae/model.py b/pypots/imputation/gpvae/model.py index 56602c72..fa381290 100644 --- a/pypots/imputation/gpvae/model.py +++ b/pypots/imputation/gpvae/model.py @@ -15,212 +15,16 @@ import h5py import numpy as np import torch -import torch.nn as nn from torch.utils.data import DataLoader from .data import DatasetForGPVAE -from .modules import ( - Encoder, - rbf_kernel, - diffusion_kernel, - matern_kernel, - cauchy_kernel, - Decoder, -) +from .modules import _GPVAE from ..base import BaseNNImputer from ...optim.adam import Adam from ...optim.base import Optimizer from ...utils.logging import logger -class _GPVAE(nn.Module): - """model GPVAE with Gaussian Process prior - - Parameters - ---------- - input_dim : int, - the feature dimension of the input - - time_length : int, - the length of each time series - - latent_dim : int, - the feature dimension of the latent embedding - - encoder_sizes : tuple, - the tuple of the network size in encoder - - decoder_sizes : tuple, - the tuple of the network size in decoder - - beta : float, - the weight of the KL divergence - - M : int, - the number of Monte Carlo samples for ELBO estimation - - K : int, - the number of importance weights for IWAE model - - kernel : str, - the Gaussian Process kernel ["cauchy", "diffusion", "rbf", "matern"] - - sigma : float, - the scale parameter for a kernel function - - length_scale : float, - the length scale parameter for a kernel function - - kernel_scales : int, - the number of different length scales over latent space dimensions - """ - - def __init__( - self, - input_dim, - time_length, - latent_dim, - encoder_sizes=(64, 64), - decoder_sizes=(64, 64), - beta=1, - M=1, - K=1, - kernel="cauchy", - sigma=1.0, - length_scale=7.0, - kernel_scales=1, - window_size=24, - ): 
- super().__init__() - self.kernel = kernel - self.sigma = sigma - self.length_scale = length_scale - self.kernel_scales = kernel_scales - - self.input_dim = input_dim - self.time_length = time_length - self.latent_dim = latent_dim - self.beta = beta - self.encoder = Encoder(input_dim, latent_dim, encoder_sizes, window_size) - self.decoder = Decoder(latent_dim, input_dim, decoder_sizes) - self.M = M - self.K = K - - # Precomputed KL components for efficiency - self.prior = self._init_prior() - # self.pz_scale_inv = None - # self.pz_scale_log_abs_determinant = None - - def encode(self, x): - return self.encoder(x) - - def decode(self, z): - if not torch.is_tensor(z): - z = torch.tensor(z).float() - num_dim = len(z.shape) - assert num_dim > 2 - return self.decoder(torch.transpose(z, num_dim - 1, num_dim - 2)) - - def forward(self, inputs, training=True): - x = inputs["X"] - m_mask = inputs["missing_mask"] - x = x.repeat(self.M * self.K, 1, 1) - if m_mask is not None: - m_mask = m_mask.repeat(self.M * self.K, 1, 1) - m_mask = m_mask.type(torch.bool) - - # pz = self.prior() - qz_x = self.encode(x) - z = qz_x.rsample() - px_z = self.decode(z) - - nll = -px_z.log_prob(x) - nll = torch.where(torch.isfinite(nll), nll, torch.zeros_like(nll)) - if m_mask is not None: - nll = torch.where(m_mask, nll, torch.zeros_like(nll)) - nll = nll.sum(dim=(1, 2)) - - if self.K > 1: - kl = qz_x.log_prob(z) - self.prior.log_prob(z) - kl = torch.where(torch.isfinite(kl), kl, torch.zeros_like(kl)) - kl = kl.sum(1) - - weights = -nll - kl - weights = torch.reshape(weights, [self.M, self.K, -1]) - - elbo = torch.logsumexp(weights, dim=1) - elbo = elbo.mean() - else: - kl = self.kl_divergence(qz_x, self.prior) - kl = torch.where(torch.isfinite(kl), kl, torch.zeros_like(kl)) - kl = kl.sum(1) - - elbo = -nll - self.beta * kl - elbo = elbo.mean() - - imputed_data = self.decode(self.encode(x).mean).mean * ~m_mask + x * m_mask - - if not training: - # if not in training mode, return the 
classification result only - return { - "imputed_data": imputed_data, - } - - results = { - "loss": -elbo.mean(), - "imputed_data": imputed_data, - } - return results - - @staticmethod - def kl_divergence(a, b): - return torch.distributions.kl.kl_divergence(a, b) - - def _init_prior(self): - # Compute kernel matrices for each latent dimension - kernel_matrices = [] - for i in range(self.kernel_scales): - if self.kernel == "rbf": - kernel_matrices.append( - rbf_kernel(self.time_length, self.length_scale / 2**i) - ) - elif self.kernel == "diffusion": - kernel_matrices.append( - diffusion_kernel(self.time_length, self.length_scale / 2**i) - ) - elif self.kernel == "matern": - kernel_matrices.append( - matern_kernel(self.time_length, self.length_scale / 2**i) - ) - elif self.kernel == "cauchy": - kernel_matrices.append( - cauchy_kernel( - self.time_length, self.sigma, self.length_scale / 2**i - ) - ) - - # Combine kernel matrices for each latent dimension - tiled_matrices = [] - total = 0 - for i in range(self.kernel_scales): - if i == self.kernel_scales - 1: - multiplier = self.latent_dim - total - else: - multiplier = int(np.ceil(self.latent_dim / self.kernel_scales)) - total += multiplier - tiled_matrices.append( - torch.unsqueeze(kernel_matrices[i], 0).repeat(multiplier, 1, 1) - ) - kernel_matrix_tiled = torch.cat(tiled_matrices) - assert len(kernel_matrix_tiled) == self.latent_dim - prior = torch.distributions.MultivariateNormal( - loc=torch.zeros(self.latent_dim, self.time_length), - covariance_matrix=kernel_matrix_tiled, - ) - - return prior - - class GPVAE(BaseNNImputer): """The PyTorch implementation of the GPVAE model :cite:`fortuin2020GPVAEDeep`. 
diff --git a/pypots/imputation/gpvae/modules/__init__.py b/pypots/imputation/gpvae/modules/__init__.py new file mode 100644 index 00000000..6e24f8e7 --- /dev/null +++ b/pypots/imputation/gpvae/modules/__init__.py @@ -0,0 +1,13 @@ +""" + +""" + +# Created by Wenjie Du +# License: GLP-v3 + + +from .core import _GPVAE + +__all__ = [ + "_GPVAE", +] diff --git a/pypots/imputation/gpvae/modules/core.py b/pypots/imputation/gpvae/modules/core.py new file mode 100644 index 00000000..395493e7 --- /dev/null +++ b/pypots/imputation/gpvae/modules/core.py @@ -0,0 +1,213 @@ +""" +The implementation of GP-VAE for the partially-observed time-series imputation task. + +Refer to the paper Fortuin V, Baranchuk D, Rätsch G, et al. +GP-VAE: Deep probabilistic time series imputation. AISTATS. PMLR, 2020: 1651-1661. + +""" + +# Created by Jun Wang and Wenjie Du +# License: GPL-v3 + + +import numpy as np +import torch +import torch.nn as nn + +from .submodules import ( + Encoder, + rbf_kernel, + diffusion_kernel, + matern_kernel, + cauchy_kernel, + Decoder, +) + + +class _GPVAE(nn.Module): + """model GPVAE with Gaussian Process prior + + Parameters + ---------- + input_dim : int, + the feature dimension of the input + + time_length : int, + the length of each time series + + latent_dim : int, + the feature dimension of the latent embedding + + encoder_sizes : tuple, + the tuple of the network size in encoder + + decoder_sizes : tuple, + the tuple of the network size in decoder + + beta : float, + the weight of the KL divergence + + M : int, + the number of Monte Carlo samples for ELBO estimation + + K : int, + the number of importance weights for IWAE model + + kernel : str, + the Gaussian Process kernel ["cauchy", "diffusion", "rbf", "matern"] + + sigma : float, + the scale parameter for a kernel function + + length_scale : float, + the length scale parameter for a kernel function + + kernel_scales : int, + the number of different length scales over latent space dimensions + """ + + def 
__init__( + self, + input_dim, + time_length, + latent_dim, + encoder_sizes=(64, 64), + decoder_sizes=(64, 64), + beta=1, + M=1, + K=1, + kernel="cauchy", + sigma=1.0, + length_scale=7.0, + kernel_scales=1, + window_size=24, + ): + super().__init__() + self.kernel = kernel + self.sigma = sigma + self.length_scale = length_scale + self.kernel_scales = kernel_scales + + self.input_dim = input_dim + self.time_length = time_length + self.latent_dim = latent_dim + self.beta = beta + self.encoder = Encoder(input_dim, latent_dim, encoder_sizes, window_size) + self.decoder = Decoder(latent_dim, input_dim, decoder_sizes) + self.M = M + self.K = K + + self.prior = None + + def encode(self, x): + return self.encoder(x) + + def decode(self, z): + if not torch.is_tensor(z): + z = torch.tensor(z).float() + num_dim = len(z.shape) + assert num_dim > 2 + return self.decoder(torch.transpose(z, num_dim - 1, num_dim - 2)) + + def forward(self, inputs, training=True): + x = inputs["X"] + m_mask = inputs["missing_mask"] + x = x.repeat(self.M * self.K, 1, 1) + + if self.prior is None: + self.prior = self._init_prior(device=x.device) + + if m_mask is not None: + m_mask = m_mask.repeat(self.M * self.K, 1, 1) + m_mask = m_mask.type(torch.bool) + + # pz = self.prior() + qz_x = self.encode(x) + z = qz_x.rsample() + px_z = self.decode(z) + + nll = -px_z.log_prob(x) + nll = torch.where(torch.isfinite(nll), nll, torch.zeros_like(nll)) + if m_mask is not None: + nll = torch.where(m_mask, nll, torch.zeros_like(nll)) + nll = nll.sum(dim=(1, 2)) + + if self.K > 1: + kl = qz_x.log_prob(z) - self.prior.log_prob(z) + kl = torch.where(torch.isfinite(kl), kl, torch.zeros_like(kl)) + kl = kl.sum(1) + + weights = -nll - kl + weights = torch.reshape(weights, [self.M, self.K, -1]) + + elbo = torch.logsumexp(weights, dim=1) + elbo = elbo.mean() + else: + kl = self.kl_divergence(qz_x, self.prior) + kl = torch.where(torch.isfinite(kl), kl, torch.zeros_like(kl)) + kl = kl.sum(1) + + elbo = -nll - self.beta * kl 
+ elbo = elbo.mean() + + imputed_data = self.decode(self.encode(x).mean).mean * ~m_mask + x * m_mask + + if not training: + # if not in training mode, return the classification result only + return { + "imputed_data": imputed_data, + } + + results = { + "loss": -elbo.mean(), + "imputed_data": imputed_data, + } + return results + + @staticmethod + def kl_divergence(a, b): + return torch.distributions.kl.kl_divergence(a, b) + + def _init_prior(self, device="cpu"): + # Compute kernel matrices for each latent dimension + kernel_matrices = [] + for i in range(self.kernel_scales): + if self.kernel == "rbf": + kernel_matrices.append( + rbf_kernel(self.time_length, self.length_scale / 2**i) + ) + elif self.kernel == "diffusion": + kernel_matrices.append( + diffusion_kernel(self.time_length, self.length_scale / 2**i) + ) + elif self.kernel == "matern": + kernel_matrices.append( + matern_kernel(self.time_length, self.length_scale / 2**i) + ) + elif self.kernel == "cauchy": + kernel_matrices.append( + cauchy_kernel( + self.time_length, self.sigma, self.length_scale / 2**i + ) + ) + + # Combine kernel matrices for each latent dimension + tiled_matrices = [] + total = 0 + for i in range(self.kernel_scales): + if i == self.kernel_scales - 1: + multiplier = self.latent_dim - total + else: + multiplier = int(np.ceil(self.latent_dim / self.kernel_scales)) + total += multiplier + tiled_matrices.append( + torch.unsqueeze(kernel_matrices[i], 0).repeat(multiplier, 1, 1) + ) + kernel_matrix_tiled = torch.cat(tiled_matrices) + assert len(kernel_matrix_tiled) == self.latent_dim + prior = torch.distributions.MultivariateNormal( + loc=torch.zeros(self.latent_dim, self.time_length, device=device), + covariance_matrix=kernel_matrix_tiled.to(device), + ) + + return prior diff --git a/pypots/imputation/gpvae/modules.py b/pypots/imputation/gpvae/modules/submodules.py similarity index 100% rename from pypots/imputation/gpvae/modules.py rename to pypots/imputation/gpvae/modules/submodules.py diff 
--git a/pypots/imputation/mrnn/model.py b/pypots/imputation/mrnn/model.py index afdde3b4..afff8223 100644 --- a/pypots/imputation/mrnn/model.py +++ b/pypots/imputation/mrnn/model.py @@ -13,98 +13,14 @@ import h5py import numpy as np import torch -import torch.nn as nn from torch.utils.data import DataLoader from .data import DatasetForMRNN -from .module import FCN_Regression +from .modules import _MRNN from ..base import BaseNNImputer from ...optim.adam import Adam from ...optim.base import Optimizer from ...utils.logging import logger -from ...utils.metrics import cal_rmse - - -class _MRNN(nn.Module): - def __init__(self, seq_len, feature_num, rnn_hidden_size, device): - super().__init__() - # data settings - self.seq_len = seq_len - self.feature_num = feature_num - self.rnn_hidden_size = rnn_hidden_size - self.device = device - - self.f_rnn = nn.GRUCell(self.feature_num * 3, self.rnn_hidden_size) - self.b_rnn = nn.GRUCell(self.feature_num * 3, self.rnn_hidden_size) - self.concated_hidden_project = nn.Linear( - self.rnn_hidden_size * 2, self.feature_num - ) - self.fcn_regression = FCN_Regression(feature_num, rnn_hidden_size) - - def gene_hidden_states(self, inputs, direction): - X = inputs[direction]["X"] - masks = inputs[direction]["missing_mask"] - deltas = inputs[direction]["deltas"] - device = X.device - - hidden_states_collector = [] - hidden_state = torch.zeros((X.size()[0], self.rnn_hidden_size), device=device) - - for t in range(self.seq_len): - x = X[:, t, :] - m = masks[:, t, :] - d = deltas[:, t, :] - inputs = torch.cat([x, m, d], dim=1) - if direction == "forward": - hidden_state = self.f_rnn(inputs, hidden_state) - else: - hidden_state = self.b_rnn(inputs, hidden_state) - hidden_states_collector.append(hidden_state) - return hidden_states_collector - - def forward(self, inputs, training=True): - hidden_states_f = self.gene_hidden_states(inputs, "forward") - hidden_states_b = self.gene_hidden_states(inputs, "backward")[::-1] - - X = 
inputs["forward"]["X"] - masks = inputs["forward"]["missing_mask"] - - reconstruction_loss = 0 - estimations = [] - for i in range( - self.seq_len - ): # calculating estimation loss for times can obtain better results than once - x = X[:, i, :] - m = masks[:, i, :] - h_f = hidden_states_f[i] - h_b = hidden_states_b[i] - h = torch.cat([h_f, h_b], dim=1) - RNN_estimation = self.concated_hidden_project(h) # x̃_t - RNN_imputed_data = m * x + (1 - m) * RNN_estimation - FCN_estimation = self.fcn_regression( - x, m, RNN_imputed_data - ) # FCN estimation is output estimation - reconstruction_loss += cal_rmse(FCN_estimation, x, m) + cal_rmse( - RNN_estimation, x, m - ) - estimations.append(FCN_estimation.unsqueeze(dim=1)) - - estimations = torch.cat(estimations, dim=1) - imputed_data = masks * X + (1 - masks) * estimations - - if not training: - # if not in training mode, return the classification result only - return { - "imputed_data": imputed_data, - } - - reconstruction_loss /= self.seq_len - - ret_dict = { - "loss": reconstruction_loss, - "imputed_data": imputed_data, - } - return ret_dict class MRNN(BaseNNImputer): diff --git a/pypots/imputation/mrnn/modules/__init__.py b/pypots/imputation/mrnn/modules/__init__.py new file mode 100644 index 00000000..27f7ffc6 --- /dev/null +++ b/pypots/imputation/mrnn/modules/__init__.py @@ -0,0 +1,13 @@ +""" + +""" + +# Created by Wenjie Du +# License: GLP-v3 + + +from .core import _MRNN + +__all__ = [ + "_MRNN", +] diff --git a/pypots/imputation/mrnn/modules/core.py b/pypots/imputation/mrnn/modules/core.py new file mode 100644 index 00000000..1c2e759b --- /dev/null +++ b/pypots/imputation/mrnn/modules/core.py @@ -0,0 +1,97 @@ +""" +PyTorch MRNN model for the time-series imputation task. +Some part of the code is from https://github.com/WenjieDu/SAITS. 
+ +""" + +# Created by Wenjie Du +# License: GLP-v3 + + +import torch +import torch.nn as nn + +from .submodules import FCN_Regression +from ....utils.metrics import cal_rmse + + +class _MRNN(nn.Module): + def __init__(self, seq_len, feature_num, rnn_hidden_size, device): + super().__init__() + # data settings + self.seq_len = seq_len + self.feature_num = feature_num + self.rnn_hidden_size = rnn_hidden_size + self.device = device + + self.f_rnn = nn.GRUCell(self.feature_num * 3, self.rnn_hidden_size) + self.b_rnn = nn.GRUCell(self.feature_num * 3, self.rnn_hidden_size) + self.concated_hidden_project = nn.Linear( + self.rnn_hidden_size * 2, self.feature_num + ) + self.fcn_regression = FCN_Regression(feature_num, rnn_hidden_size) + + def gene_hidden_states(self, inputs, direction): + X = inputs[direction]["X"] + masks = inputs[direction]["missing_mask"] + deltas = inputs[direction]["deltas"] + device = X.device + + hidden_states_collector = [] + hidden_state = torch.zeros((X.size()[0], self.rnn_hidden_size), device=device) + + for t in range(self.seq_len): + x = X[:, t, :] + m = masks[:, t, :] + d = deltas[:, t, :] + inputs = torch.cat([x, m, d], dim=1) + if direction == "forward": + hidden_state = self.f_rnn(inputs, hidden_state) + else: + hidden_state = self.b_rnn(inputs, hidden_state) + hidden_states_collector.append(hidden_state) + return hidden_states_collector + + def forward(self, inputs, training=True): + hidden_states_f = self.gene_hidden_states(inputs, "forward") + hidden_states_b = self.gene_hidden_states(inputs, "backward")[::-1] + + X = inputs["forward"]["X"] + masks = inputs["forward"]["missing_mask"] + + reconstruction_loss = 0 + estimations = [] + for i in range( + self.seq_len + ): # calculating estimation loss for times can obtain better results than once + x = X[:, i, :] + m = masks[:, i, :] + h_f = hidden_states_f[i] + h_b = hidden_states_b[i] + h = torch.cat([h_f, h_b], dim=1) + RNN_estimation = self.concated_hidden_project(h) # x̃_t + 
RNN_imputed_data = m * x + (1 - m) * RNN_estimation + FCN_estimation = self.fcn_regression( + x, m, RNN_imputed_data + ) # FCN estimation is output estimation + reconstruction_loss += cal_rmse(FCN_estimation, x, m) + cal_rmse( + RNN_estimation, x, m + ) + estimations.append(FCN_estimation.unsqueeze(dim=1)) + + estimations = torch.cat(estimations, dim=1) + imputed_data = masks * X + (1 - masks) * estimations + + if not training: + # if not in training mode, return the classification result only + return { + "imputed_data": imputed_data, + } + + reconstruction_loss /= self.seq_len + + ret_dict = { + "loss": reconstruction_loss, + "imputed_data": imputed_data, + } + return ret_dict diff --git a/pypots/imputation/mrnn/module.py b/pypots/imputation/mrnn/modules/submodules.py similarity index 96% rename from pypots/imputation/mrnn/module.py rename to pypots/imputation/mrnn/modules/submodules.py index a143d121..a0e695c5 100644 --- a/pypots/imputation/mrnn/module.py +++ b/pypots/imputation/mrnn/modules/submodules.py @@ -13,7 +13,7 @@ import torch.nn.functional as F from torch.nn.parameter import Parameter -from ...imputation.brits.modules import FeatureRegression +from ...brits.modules import FeatureRegression class FCN_Regression(nn.Module): diff --git a/pypots/imputation/saits/model.py b/pypots/imputation/saits/model.py index 21e72cff..8424e7bc 100644 --- a/pypots/imputation/saits/model.py +++ b/pypots/imputation/saits/model.py @@ -13,186 +13,23 @@ # Created by Wenjie Du # License: GPL-v3 -from typing import Tuple, Union, Optional +from typing import Union, Optional, Callable import h5py import numpy as np import torch -import torch.nn as nn -import torch.nn.functional as F from torch.utils.data import DataLoader from .data import DatasetForSAITS +from .modules import _SAITS from ..base import BaseNNImputer from ...data.base import BaseDataset -from ...modules.self_attention import EncoderLayer, PositionalEncoding from ...optim.adam import Adam from ...optim.base import 
Optimizer from ...utils.logging import logger from ...utils.metrics import cal_mae -class _SAITS(nn.Module): - def __init__( - self, - n_layers: int, - n_steps: int, - n_features: int, - d_model: int, - d_inner: int, - n_heads: int, - d_k: int, - d_v: int, - dropout: float, - attn_dropout: float, - diagonal_attention_mask: bool = True, - ORT_weight: float = 1, - MIT_weight: float = 1, - ): - super().__init__() - self.n_layers = n_layers - self.n_steps = n_steps - # concatenate the feature vector and missing mask, hence double the number of features - actual_n_features = n_features * 2 - self.diagonal_attention_mask = diagonal_attention_mask - self.ORT_weight = ORT_weight - self.MIT_weight = MIT_weight - - self.layer_stack_for_first_block = nn.ModuleList( - [ - EncoderLayer( - d_model, - d_inner, - n_heads, - d_k, - d_v, - dropout, - attn_dropout, - ) - for _ in range(n_layers) - ] - ) - self.layer_stack_for_second_block = nn.ModuleList( - [ - EncoderLayer( - d_model, - d_inner, - n_heads, - d_k, - d_v, - dropout, - attn_dropout, - ) - for _ in range(n_layers) - ] - ) - - self.dropout = nn.Dropout(p=dropout) - self.position_enc = PositionalEncoding(d_model, n_position=n_steps) - # for the 1st block - self.embedding_1 = nn.Linear(actual_n_features, d_model) - self.reduce_dim_z = nn.Linear(d_model, n_features) - # for the 2nd block - self.embedding_2 = nn.Linear(actual_n_features, d_model) - self.reduce_dim_beta = nn.Linear(d_model, n_features) - self.reduce_dim_gamma = nn.Linear(n_features, n_features) - # for delta decay factor - self.weight_combine = nn.Linear(n_features + n_steps, n_features) - - def _process( - self, - inputs: dict, - diagonal_attention_mask: Optional[torch.Tensor] = None, - ) -> Tuple[torch.Tensor, list]: - X, masks = inputs["X"], inputs["missing_mask"] - - # first DMSA block - input_X_for_first = torch.cat([X, masks], dim=2) - input_X_for_first = self.embedding_1(input_X_for_first) - enc_output = self.dropout( - 
self.position_enc(input_X_for_first) - ) # namely, term e in the math equation - for encoder_layer in self.layer_stack_for_first_block: - enc_output, _ = encoder_layer(enc_output, diagonal_attention_mask) - - X_tilde_1 = self.reduce_dim_z(enc_output) - X_prime = masks * X + (1 - masks) * X_tilde_1 - - # second DMSA block - input_X_for_second = torch.cat([X_prime, masks], dim=2) - input_X_for_second = self.embedding_2(input_X_for_second) - enc_output = self.position_enc( - input_X_for_second - ) # namely term alpha in math algo - attn_weights = None - for encoder_layer in self.layer_stack_for_second_block: - enc_output, attn_weights = encoder_layer(enc_output) - - X_tilde_2 = self.reduce_dim_gamma(F.relu(self.reduce_dim_beta(enc_output))) - - # attention-weighted combine - attn_weights = attn_weights.squeeze(dim=1) # namely term A_hat in Eq. - if len(attn_weights.shape) == 4: - # if having more than 1 head, then average attention weights from all heads - attn_weights = torch.transpose(attn_weights, 1, 3) - attn_weights = attn_weights.mean(dim=3) - attn_weights = torch.transpose(attn_weights, 1, 2) - - # namely term eta - combining_weights = torch.sigmoid( - self.weight_combine(torch.cat([masks, attn_weights], dim=2)) - ) - # combine X_tilde_1 and X_tilde_2 - X_tilde_3 = (1 - combining_weights) * X_tilde_2 + combining_weights * X_tilde_1 - # replace non-missing part with original data - X_c = masks * X + (1 - masks) * X_tilde_3 - - return X_c, [X_tilde_1, X_tilde_2, X_tilde_3] - - def forward( - self, inputs: dict, diagonal_attention_mask: bool = False, training: bool = True - ) -> dict: - X, masks = inputs["X"], inputs["missing_mask"] - - if (training and self.diagonal_attention_mask) or ( - (not training) and diagonal_attention_mask - ): - diagonal_attention_mask = (1 - torch.eye(self.n_steps)).to(X.device) - # then broadcast on the batch axis - diagonal_attention_mask = diagonal_attention_mask.unsqueeze(0) - else: - diagonal_attention_mask = None - - imputed_data, 
[X_tilde_1, X_tilde_2, X_tilde_3] = self._process( - inputs, diagonal_attention_mask - ) - - if not training: - # if not in training mode, return the classification result only - return { - "imputed_data": imputed_data, - } - - ORT_loss = 0 - ORT_loss += cal_mae(X_tilde_1, X, masks) - ORT_loss += cal_mae(X_tilde_2, X, masks) - ORT_loss += cal_mae(X_tilde_3, X, masks) - ORT_loss /= 3 - - MIT_loss = cal_mae(X_tilde_3, inputs["X_intact"], inputs["indicating_mask"]) - - # `loss` is always the item for backward propagating to update the model - loss = self.ORT_weight * ORT_loss + self.MIT_weight * MIT_loss - - results = { - "imputed_data": imputed_data, - "ORT_loss": ORT_loss, - "MIT_loss": MIT_loss, - "loss": loss, # will be used for backward propagating to update the model - } - return results - - class SAITS(BaseNNImputer): """The PyTorch implementation of the SAITS model :cite:`du2023SAITS`. @@ -308,6 +145,7 @@ def __init__( batch_size: int = 32, epochs: int = 100, patience: Optional[int] = None, + customized_loss_func: Callable = cal_mae, optimizer: Optional[Optimizer] = Adam(), num_workers: int = 0, device: Optional[Union[str, torch.device, list]] = None, @@ -358,6 +196,9 @@ def __init__( self._print_model_size() self._send_model_to_given_device() + # set up the loss function + self.customized_loss_func = customized_loss_func + # set up the optimizer self.optimizer = optimizer self.optimizer.init_optimizer(self.model.parameters()) diff --git a/pypots/imputation/saits/modules/__init__.py b/pypots/imputation/saits/modules/__init__.py new file mode 100644 index 00000000..8b93f9a8 --- /dev/null +++ b/pypots/imputation/saits/modules/__init__.py @@ -0,0 +1,13 @@ +""" + +""" + +# Created by Wenjie Du +# License: GLP-v3 + + +from .core import _SAITS + +__all__ = [ + "_SAITS", +] diff --git a/pypots/imputation/saits/modules/core.py b/pypots/imputation/saits/modules/core.py new file mode 100644 index 00000000..cab54772 --- /dev/null +++ 
# Source file: pypots/imputation/saits/modules/core.py
"""
The implementation of SAITS for the partially-observed time-series imputation task.

Refer to the paper "Du, W., Cote, D., & Liu, Y. (2023). SAITS: Self-Attention-based Imputation for Time Series.
Expert systems with applications."

Notes
-----
Partial implementation uses code from https://github.com/WenjieDu/SAITS.

"""

# Created by Wenjie Du
# License: GPL-v3

from typing import Tuple, Optional, Callable

import torch
import torch.nn as nn
import torch.nn.functional as F

from ....modules.self_attention import EncoderLayer, PositionalEncoding
from ....utils.metrics import cal_mae


class _SAITS(nn.Module):
    """Core SAITS network: two stacked self-attention blocks and a learned combination.

    The first block imputes from ``X`` concatenated with its missing mask, the
    second block refines that estimate, and a sigmoid-gated linear layer blends
    both estimates using the missing mask and the second block's attention
    weights.

    Parameters
    ----------
    n_layers :
        Number of ``EncoderLayer`` modules in each block. Must be at least 1
        because the combination step needs the last layer's attention weights.

    n_steps :
        Number of time steps per sample; sizes the positional encoding, the
        diagonal attention mask and the combining layer.

    n_features :
        Number of features per time step.

    d_model, d_inner, n_heads, d_k, d_v, attn_dropout :
        Passed unchanged to every ``EncoderLayer``.

    dropout :
        Passed to every ``EncoderLayer`` and used for the dropout applied to
        the first block's position-encoded embedding.

    diagonal_attention_mask :
        Whether the first block uses the diagonal attention mask in training.

    ORT_weight, MIT_weight :
        Weights of the ORT and MIT loss terms in the total loss.

    customized_loss_func :
        Callable invoked as ``f(prediction, target, mask)``;
        defaults to ``cal_mae``.

    Raises
    ------
    ValueError
        If ``n_layers`` is smaller than 1.

    """

    def __init__(
        self,
        n_layers: int,
        n_steps: int,
        n_features: int,
        d_model: int,
        d_inner: int,
        n_heads: int,
        d_k: int,
        d_v: int,
        dropout: float,
        attn_dropout: float,
        diagonal_attention_mask: bool = True,
        ORT_weight: float = 1,
        MIT_weight: float = 1,
        customized_loss_func: Callable = cal_mae,
    ):
        super().__init__()
        if n_layers < 1:
            # with no encoder layer there are no attention weights to combine,
            # so fail here with a clear message instead of deep inside forward
            raise ValueError(f"n_layers must be >= 1, but got {n_layers}")

        self.n_layers = n_layers
        self.n_steps = n_steps
        # concatenate the feature vector and missing mask, hence double the number of features
        actual_n_features = n_features * 2
        self.diagonal_attention_mask = diagonal_attention_mask
        self.ORT_weight = ORT_weight
        self.MIT_weight = MIT_weight
        self.customized_loss_func = customized_loss_func

        def build_encoder_stack() -> nn.ModuleList:
            # both blocks share the same architecture but own separate weights
            return nn.ModuleList(
                [
                    EncoderLayer(
                        d_model,
                        d_inner,
                        n_heads,
                        d_k,
                        d_v,
                        dropout,
                        attn_dropout,
                    )
                    for _ in range(n_layers)
                ]
            )

        # NOTE: sub-modules are created in the same order as before so that
        # parameter initialisation and state_dict key order are unchanged
        self.layer_stack_for_first_block = build_encoder_stack()
        self.layer_stack_for_second_block = build_encoder_stack()

        self.dropout = nn.Dropout(p=dropout)
        self.position_enc = PositionalEncoding(d_model, n_position=n_steps)
        # for the 1st block
        self.embedding_1 = nn.Linear(actual_n_features, d_model)
        self.reduce_dim_z = nn.Linear(d_model, n_features)
        # for the 2nd block
        self.embedding_2 = nn.Linear(actual_n_features, d_model)
        self.reduce_dim_beta = nn.Linear(d_model, n_features)
        self.reduce_dim_gamma = nn.Linear(n_features, n_features)
        # for delta decay factor
        self.weight_combine = nn.Linear(n_features + n_steps, n_features)

    def _process(
        self,
        inputs: dict,
        diagonal_attention_mask: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, list]:
        """Run both blocks and combine their estimates.

        Parameters
        ----------
        inputs :
            Dict providing ``"X"`` and ``"missing_mask"``; they are
            concatenated on dim 2, so both are expected to be 3-D.

        diagonal_attention_mask :
            Attention mask handed to the encoder layers of the first block,
            or ``None`` to attend without a mask.

        Returns
        -------
        X_c :
            ``X`` with the positions where ``missing_mask`` is 0 replaced by
            the combined estimate.

        X_tildes :
            ``[X_tilde_1, X_tilde_2, X_tilde_3]``: the first-block estimate,
            the second-block estimate and their weighted combination.

        """
        X, masks = inputs["X"], inputs["missing_mask"]

        # first DMSA block
        input_X_for_first = torch.cat([X, masks], dim=2)
        input_X_for_first = self.embedding_1(input_X_for_first)
        enc_output = self.dropout(
            self.position_enc(input_X_for_first)
        )  # namely, term e in the math equation
        for encoder_layer in self.layer_stack_for_first_block:
            enc_output, _ = encoder_layer(enc_output, diagonal_attention_mask)

        X_tilde_1 = self.reduce_dim_z(enc_output)
        X_prime = masks * X + (1 - masks) * X_tilde_1

        # second DMSA block
        # NOTE(review): unlike the first block, no dropout and no attention
        # mask are applied here - confirm this asymmetry is intended
        input_X_for_second = torch.cat([X_prime, masks], dim=2)
        input_X_for_second = self.embedding_2(input_X_for_second)
        enc_output = self.position_enc(
            input_X_for_second
        )  # namely term alpha in math algo
        attn_weights = None
        for encoder_layer in self.layer_stack_for_second_block:
            enc_output, attn_weights = encoder_layer(enc_output)

        X_tilde_2 = self.reduce_dim_gamma(F.relu(self.reduce_dim_beta(enc_output)))

        # attention-weighted combine
        attn_weights = attn_weights.squeeze(dim=1)  # namely term A_hat in Eq.
        if attn_weights.dim() == 4:
            # if having more than 1 head, then average attention weights from all heads;
            # equivalent to the former transpose(1, 3) -> mean(dim=3) -> transpose(1, 2)
            attn_weights = attn_weights.mean(dim=1)

        # namely term eta
        combining_weights = torch.sigmoid(
            self.weight_combine(torch.cat([masks, attn_weights], dim=2))
        )
        # combine X_tilde_1 and X_tilde_2
        X_tilde_3 = (1 - combining_weights) * X_tilde_2 + combining_weights * X_tilde_1
        # replace non-missing part with original data
        X_c = masks * X + (1 - masks) * X_tilde_3

        return X_c, [X_tilde_1, X_tilde_2, X_tilde_3]

    def forward(
        self, inputs: dict, diagonal_attention_mask: bool = False, training: bool = True
    ) -> dict:
        """Impute ``inputs["X"]`` and, in training mode, compute the losses.

        Parameters
        ----------
        inputs :
            Dict with ``"X"`` and ``"missing_mask"``; in training mode it must
            also contain ``"X_intact"`` and ``"indicating_mask"``.

        diagonal_attention_mask :
            Only consulted when ``training`` is False: whether to apply the
            diagonal attention mask. In training mode the value given to the
            constructor is used instead.

        training :
            If False, only the imputation is returned and no loss is computed.

        Returns
        -------
        dict
            ``{"imputed_data"}`` when not training, otherwise additionally
            ``"ORT_loss"``, ``"MIT_loss"`` and ``"loss"``.

        """
        X, masks = inputs["X"], inputs["missing_mask"]

        # training follows the constructor setting, inference follows the argument
        use_diagonal_mask = (
            self.diagonal_attention_mask if training else diagonal_attention_mask
        )
        if use_diagonal_mask:
            # zeros on the diagonal, ones elsewhere; created directly on the
            # input's device and broadcast on the batch axis
            attn_mask = (1 - torch.eye(self.n_steps, device=X.device)).unsqueeze(0)
        else:
            attn_mask = None

        imputed_data, [X_tilde_1, X_tilde_2, X_tilde_3] = self._process(
            inputs, attn_mask
        )

        if not training:
            # if not in training mode, return the imputation result only
            return {
                "imputed_data": imputed_data,
            }

        # ORT: average loss of the three estimates against X under missing_mask
        ORT_loss = (
            sum(
                self.customized_loss_func(X_tilde, X, masks)
                for X_tilde in (X_tilde_1, X_tilde_2, X_tilde_3)
            )
            / 3
        )

        # MIT: loss of the combined estimate against X_intact under indicating_mask
        MIT_loss = self.customized_loss_func(
            X_tilde_3, inputs["X_intact"], inputs["indicating_mask"]
        )

        # `loss` is always the item for backward propagating to update the model
        loss = self.ORT_weight * ORT_loss + self.MIT_weight * MIT_loss

        results = {
            "imputed_data": imputed_data,
            "ORT_loss": ORT_loss,
            "MIT_loss": MIT_loss,
            "loss": loss,  # will be used for backward propagating to update the model
        }
        return results
encoder_layer(enc_output) - - learned_presentation = self.reduce_dim(enc_output) - imputed_data = ( - masks * X + (1 - masks) * learned_presentation - ) # replace non-missing part with original data - return imputed_data, learned_presentation - - def forward(self, inputs: dict, training: bool = True) -> dict: - X, masks = inputs["X"], inputs["missing_mask"] - imputed_data, learned_presentation = self._process(inputs) - - if not training: - # if not in training mode, return the classification result only - return { - "imputed_data": imputed_data, - } - - ORT_loss = cal_mae(learned_presentation, X, masks) - MIT_loss = cal_mae( - learned_presentation, inputs["X_intact"], inputs["indicating_mask"] - ) - - # `loss` is always the item for backward propagating to update the model - loss = self.ORT_weight * ORT_loss + self.MIT_weight * MIT_loss - - results = { - "imputed_data": imputed_data, - "ORT_loss": ORT_loss, - "MIT_loss": MIT_loss, - "loss": loss, - } - return results class Transformer(BaseNNImputer): diff --git a/pypots/imputation/transformer/modules/__init__.py b/pypots/imputation/transformer/modules/__init__.py new file mode 100644 index 00000000..d05cfea2 --- /dev/null +++ b/pypots/imputation/transformer/modules/__init__.py @@ -0,0 +1,13 @@ +""" + +""" + +# Created by Wenjie Du +# License: GLP-v3 + + +from .core import _TransformerEncoder + +__all__ = [ + "_TransformerEncoder", +] diff --git a/pypots/imputation/transformer/modules/core.py b/pypots/imputation/transformer/modules/core.py new file mode 100644 index 00000000..c017b4cc --- /dev/null +++ b/pypots/imputation/transformer/modules/core.py @@ -0,0 +1,106 @@ +""" +The implementation of Transformer for the partially-observed time-series imputation task. + +Refer to the paper "Du, W., Cote, D., & Liu, Y. (2023). SAITS: Self-Attention-based Imputation for Time Series. +Expert systems with applications." + +Notes +----- +Partial implementation uses code from https://github.com/WenjieDu/SAITS. 
# Source file: pypots/imputation/transformer/modules/core.py
# Created by Wenjie Du
# License: GPL-v3

from typing import Callable, Tuple

import torch
import torch.nn as nn

from ....modules.self_attention import EncoderLayer, PositionalEncoding
from ....utils.metrics import cal_mae


class _TransformerEncoder(nn.Module):
    """Transformer encoder used for imputing partially-observed time series.

    ``X`` and its missing mask are concatenated, embedded, position-encoded,
    passed through a stack of ``EncoderLayer`` modules and projected back to
    the feature dimension.

    Parameters
    ----------
    n_layers :
        Number of ``EncoderLayer`` modules in the stack.

    d_time :
        Number of positions given to the positional encoding.

    d_feature :
        Number of features per time step.

    d_model, d_inner, n_heads, d_k, d_v, attn_dropout :
        Passed unchanged to every ``EncoderLayer``.

    dropout :
        Passed to every ``EncoderLayer`` and used for the dropout applied to
        the position-encoded embedding.

    ORT_weight, MIT_weight :
        Weights of the ORT and MIT loss terms in the total loss.

    customized_loss_func :
        Callable invoked as ``f(prediction, target, mask)``; defaults to
        ``cal_mae``, which keeps the previous behaviour.

    """

    def __init__(
        self,
        n_layers: int,
        d_time: int,
        d_feature: int,
        d_model: int,
        d_inner: int,
        n_heads: int,
        d_k: int,
        d_v: int,
        dropout: float,
        attn_dropout: float,
        ORT_weight: float = 1,
        MIT_weight: float = 1,
        customized_loss_func: Callable = cal_mae,
    ):
        super().__init__()
        self.n_layers = n_layers
        # concatenate the feature vector and missing mask, hence double the number of features
        actual_d_feature = d_feature * 2
        self.ORT_weight = ORT_weight
        self.MIT_weight = MIT_weight
        self.customized_loss_func = customized_loss_func

        self.layer_stack = nn.ModuleList(
            [
                EncoderLayer(
                    d_model,
                    d_inner,
                    n_heads,
                    d_k,
                    d_v,
                    dropout,
                    attn_dropout,
                )
                for _ in range(n_layers)
            ]
        )

        self.embedding = nn.Linear(actual_d_feature, d_model)
        self.position_enc = PositionalEncoding(d_model, n_position=d_time)
        self.dropout = nn.Dropout(p=dropout)
        self.reduce_dim = nn.Linear(d_model, d_feature)

    def _process(self, inputs: dict) -> Tuple[torch.Tensor, torch.Tensor]:
        """Encode the inputs and build the imputation.

        Parameters
        ----------
        inputs :
            Dict providing ``"X"`` and ``"missing_mask"``; they are
            concatenated on dim 2, so both are expected to be 3-D.

        Returns
        -------
        imputed_data :
            ``X`` with the positions where ``missing_mask`` is 0 replaced by
            the model output.

        learned_presentation :
            The raw model output for every position.

        """
        X, masks = inputs["X"], inputs["missing_mask"]
        input_X = torch.cat([X, masks], dim=2)
        input_X = self.embedding(input_X)
        enc_output = self.dropout(self.position_enc(input_X))

        for encoder_layer in self.layer_stack:
            # the attention weights returned by each layer are not needed here
            enc_output, _ = encoder_layer(enc_output)

        learned_presentation = self.reduce_dim(enc_output)
        imputed_data = (
            masks * X + (1 - masks) * learned_presentation
        )  # replace non-missing part with original data
        return imputed_data, learned_presentation

    def forward(self, inputs: dict, training: bool = True) -> dict:
        """Impute ``inputs["X"]`` and, in training mode, compute the losses.

        Parameters
        ----------
        inputs :
            Dict with ``"X"`` and ``"missing_mask"``; in training mode it must
            also contain ``"X_intact"`` and ``"indicating_mask"``.

        training :
            If False, only the imputation is returned and no loss is computed.

        Returns
        -------
        dict
            ``{"imputed_data"}`` when not training, otherwise additionally
            ``"ORT_loss"``, ``"MIT_loss"`` and ``"loss"``.

        """
        X, masks = inputs["X"], inputs["missing_mask"]
        imputed_data, learned_presentation = self._process(inputs)

        if not training:
            # if not in training mode, return the imputation result only
            return {
                "imputed_data": imputed_data,
            }

        ORT_loss = self.customized_loss_func(learned_presentation, X, masks)
        MIT_loss = self.customized_loss_func(
            learned_presentation, inputs["X_intact"], inputs["indicating_mask"]
        )

        # `loss` is always the item for backward propagating to update the model
        loss = self.ORT_weight * ORT_loss + self.MIT_weight * MIT_loss

        results = {
            "imputed_data": imputed_data,
            "ORT_loss": ORT_loss,
            "MIT_loss": MIT_loss,
            "loss": loss,
        }
        return results
- - Parameters - ---------- - imputed_X : torch.Tensor, - The original X with missing parts already imputed. - - missing_mask : torch.Tensor, - The missing mask of X. - - Returns - ------- - logits : torch.Tensor, - the logits of the probability of being the true value. - - """ - - hint = ( - torch.rand_like(missing_mask, dtype=torch.float, device=self.device) - < self.hint_rate - ) - hint = hint.int() - h = hint * missing_mask + (1 - hint) * 0.5 - x_in = torch.cat([imputed_X, h], dim=-1) - - out, _ = self.biRNN(x_in) - logits = self.read_out(self.dropout(out)) - return logits - - -class _USGAN(nn.Module): - """USGAN model""" - - def __init__( - self, - n_steps: int, - n_features: int, - rnn_hidden_size: int, - lambda_mse: float, - hint_rate: float = 0.7, - dropout_rate: float = 0.0, - device: Union[str, torch.device] = "cpu", - ): - super().__init__() - self.generator = _BRITS(n_steps, n_features, rnn_hidden_size, device) - self.discriminator = Discriminator( - n_features, - rnn_hidden_size, - hint_rate=hint_rate, - dropout_rate=dropout_rate, - device=device, - ) - - self.lambda_mse = lambda_mse - self.device = device - - def forward( - self, - inputs: dict, - training_object: str = "generator", - training: bool = True, - ) -> dict: - assert training_object in [ - "generator", - "discriminator", - ], 'training_object should be "generator" or "discriminator"' - - forward_X = inputs["forward"]["X"] - forward_missing_mask = inputs["forward"]["missing_mask"] - losses = {} - results = self.generator(inputs, training=training) - inputs["discrimination"] = self.discriminator(forward_X, forward_missing_mask) - if not training: - # if only run imputation operation, then no need to calculate loss - return results - - if training_object == "discriminator": - l_D = F.binary_cross_entropy_with_logits( - inputs["discrimination"], forward_missing_mask - ) - losses["discrimination_loss"] = l_D - else: - inputs["discrimination"] = inputs["discrimination"].detach() - l_G = 
F.binary_cross_entropy_with_logits( - inputs["discrimination"], - 1 - forward_missing_mask, - weight=1 - forward_missing_mask, - ) - loss_gene = l_G + self.lambda_mse * results["loss"] - losses["generation_loss"] = loss_gene - - losses["imputed_data"] = results["imputed_data"] - return losses - - class USGAN(BaseNNImputer): """The PyTorch implementation of the USGAN model. Refer to :cite:`miao2021SSGAN`. diff --git a/pypots/imputation/usgan/modules/__init__.py b/pypots/imputation/usgan/modules/__init__.py new file mode 100644 index 00000000..2b148245 --- /dev/null +++ b/pypots/imputation/usgan/modules/__init__.py @@ -0,0 +1,13 @@ +""" + +""" + +# Created by Wenjie Du +# License: GLP-v3 + + +from .core import _USGAN + +__all__ = [ + "_USGAN", +] diff --git a/pypots/imputation/usgan/modules/core.py b/pypots/imputation/usgan/modules/core.py new file mode 100644 index 00000000..3f1406f5 --- /dev/null +++ b/pypots/imputation/usgan/modules/core.py @@ -0,0 +1,84 @@ +""" +The implementation of USGAN for the partially-observed time-series imputation task. + +Refer to the paper "Miao, X., Wu, Y., Wang, J., Gao, Y., Mao, X., & Yin, J. (2021). +Generative Semi-supervised Learning for Multivariate Time Series Imputation. AAAI 2021." 
# Source file: pypots/imputation/usgan/modules/core.py
# Created by Jun Wang and Wenjie Du
# License: GPL-v3

from typing import Union

import torch
import torch.nn as nn
import torch.nn.functional as F

from .submodules import Discriminator
from ...brits.modules import _BRITS


class _USGAN(nn.Module):
    """USGAN model: a ``_BRITS`` generator paired with a ``Discriminator``.

    Parameters
    ----------
    n_steps :
        Passed to the ``_BRITS`` generator.

    n_features :
        Passed to both the generator and the discriminator.

    rnn_hidden_size :
        Passed to both the generator and the discriminator.

    lambda_mse :
        Weight applied to the generator's own ``"loss"`` inside the
        generation loss.

    hint_rate :
        Passed to the discriminator.

    dropout_rate :
        Passed to the discriminator.

    device :
        Passed to both sub-modules and stored on ``self.device``.

    """

    def __init__(
        self,
        n_steps: int,
        n_features: int,
        rnn_hidden_size: int,
        lambda_mse: float,
        hint_rate: float = 0.7,
        dropout_rate: float = 0.0,
        device: Union[str, torch.device] = "cpu",
    ):
        super().__init__()
        self.generator = _BRITS(n_steps, n_features, rnn_hidden_size, device)
        self.discriminator = Discriminator(
            n_features,
            rnn_hidden_size,
            hint_rate=hint_rate,
            dropout_rate=dropout_rate,
            device=device,
        )

        self.lambda_mse = lambda_mse
        self.device = device

    def forward(
        self,
        inputs: dict,
        training_object: str = "generator",
        training: bool = True,
    ) -> dict:
        """Run the generator and, in training mode, compute the requested loss.

        Parameters
        ----------
        inputs :
            Dict consumed by the generator; ``inputs["forward"]["X"]`` and
            ``inputs["forward"]["missing_mask"]`` are also read here. In
            training mode the key ``"discrimination"`` is written into it.

        training_object :
            ``"generator"`` or ``"discriminator"``: which loss to compute.

        training :
            If False, the generator's result dict is returned unchanged and
            the discriminator is not run.

        Returns
        -------
        dict
            When not training, whatever the generator returns. Otherwise a
            dict with ``"imputed_data"`` plus ``"discrimination_loss"`` or
            ``"generation_loss"`` depending on ``training_object``.

        """
        assert training_object in [
            "generator",
            "discriminator",
        ], 'training_object should be "generator" or "discriminator"'

        forward_X = inputs["forward"]["X"]
        forward_missing_mask = inputs["forward"]["missing_mask"]
        results = self.generator(inputs, training=training)
        if not training:
            # if only run imputation operation, then no need to calculate loss,
            # hence the discriminator pass is skipped entirely
            return results

        # NOTE(review): the discriminator is fed the raw `forward_X`, not the
        # generator's `results["imputed_data"]` - confirm this is intended
        inputs["discrimination"] = self.discriminator(forward_X, forward_missing_mask)

        losses = {}
        if training_object == "discriminator":
            l_D = F.binary_cross_entropy_with_logits(
                inputs["discrimination"], forward_missing_mask
            )
            losses["discrimination_loss"] = l_D
        else:
            # NOTE(review): the logits are detached, so `l_G` carries no
            # gradient; only the `lambda_mse * results["loss"]` term does
            inputs["discrimination"] = inputs["discrimination"].detach()
            l_G = F.binary_cross_entropy_with_logits(
                inputs["discrimination"],
                1 - forward_missing_mask,
                weight=1 - forward_missing_mask,
            )
            loss_gene = l_G + self.lambda_mse * results["loss"]
            losses["generation_loss"] = loss_gene

        losses["imputed_data"] = results["imputed_data"]
        return losses
# Source file: pypots/imputation/usgan/modules/submodules.py
"""
Sub-modules of the USGAN model: the discriminator network.

"""

# Created by Jun Wang and Wenjie Du
# License: GPL-v3

from typing import Union

import torch
import torch.nn as nn


class Discriminator(nn.Module):
    """model Discriminator: built on BiRNN

    Parameters
    ----------
    n_features :
        the feature dimension of the input

    rnn_hidden_size :
        the hidden size of the RNN cell

    hint_rate :
        the hint rate for the input imputed_data

    dropout_rate :
        the dropout rate for the output layer

    device :
        the device the sub-layers are moved to at construction time, CPU/GPU

    """

    def __init__(
        self,
        n_features: int,
        rnn_hidden_size: int,
        hint_rate: float = 0.7,
        dropout_rate: float = 0.0,
        device: Union[str, torch.device] = "cpu",
    ):
        super().__init__()
        self.hint_rate = hint_rate
        # kept for backward compatibility; forward() no longer relies on it
        self.device = device
        # the GRU input is the imputed data concatenated with the hint matrix,
        # hence twice the number of features
        self.biRNN = nn.GRU(
            n_features * 2, rnn_hidden_size, bidirectional=True, batch_first=True
        ).to(device)
        self.dropout = nn.Dropout(dropout_rate).to(device)
        # a bidirectional GRU emits 2 * rnn_hidden_size values per step
        self.read_out = nn.Linear(rnn_hidden_size * 2, n_features).to(device)

    def forward(
        self,
        imputed_X: torch.Tensor,
        missing_mask: torch.Tensor,
    ) -> torch.Tensor:
        """Forward processing of USGAN Discriminator.

        Parameters
        ----------
        imputed_X : torch.Tensor,
            The original X with missing parts already imputed.

        missing_mask : torch.Tensor,
            The missing mask of X.

        Returns
        -------
        logits : torch.Tensor,
            the logits of the probability of being the true value.

        """

        # Sample the hint on the same device as the incoming mask rather than
        # on the device remembered at construction time, which goes stale once
        # the module or its inputs are moved elsewhere.
        hint = torch.rand_like(missing_mask, dtype=torch.float) < self.hint_rate
        hint = hint.int()
        # where hinted, reveal the true mask value; elsewhere use the neutral 0.5
        h = hint * missing_mask + (1 - hint) * 0.5
        x_in = torch.cat([imputed_X, h], dim=-1)

        out, _ = self.biRNN(x_in)
        logits = self.read_out(self.dropout(out))
        return logits