diff --git a/.github/ISSUE_TEMPLATE/bug-report.yml b/.github/ISSUE_TEMPLATE/bug-report.yml index cba82727..dcd8490a 100644 --- a/.github/ISSUE_TEMPLATE/bug-report.yml +++ b/.github/ISSUE_TEMPLATE/bug-report.yml @@ -49,4 +49,6 @@ body: required: true attributes: label: 4. Expected behavior - description: "A clear and concise description of what error you would expect to happen." + description: | + A clear and concise description of what error you would expect to happen. + Please provide the full traceback of the error message printed in the console. diff --git a/README.md b/README.md index d34d53c8..4186ce2e 100644 --- a/README.md +++ b/README.md @@ -2,7 +2,7 @@ -[Welcome to PyPOTS logo HTML]

+[Welcome to PyPOTS logo HTML]

a Python toolbox for machine learning on Partially-Observed Time Series

@@ -213,10 +213,6 @@ This functionality is implemented with the [Microsoft NNI](https://github.com/mi ## ❖ Citing PyPOTS -**[Updates in Jun 2023]** πŸŽ‰A short version of the PyPOTS paper is accepted by the 9th SIGKDD international workshop on -Mining and Learning from Time Series ([MiLeTS'23](https://kdd-milets.github.io/milets2023/))). -Besides, PyPOTS has been included as a [PyTorch Ecosystem](https://pytorch.org/ecosystem/) project. - The paper introducing PyPOTS is available on arXiv at [this URL](https://arxiv.org/abs/2305.18811), and we are pursuing its publication in prestigious academic venues, e.g. JMLR (track for [Machine Learning Open Source Software](https://www.jmlr.org/mloss/)). If you use PyPOTS in your work, @@ -243,6 +239,17 @@ doi={10.48550/arXiv.2305.18811}, > arXiv, abs/2305.18811. https://arxiv.org/abs/2305.18811 +> [!TIP] +> **[Updates in Feb 2024]** 😎 Our survey paper [Deep Learning for Multivariate Time Series Imputation: A Survey](https://arxiv.org/abs/2402.04059) has been released on arXiv. +> The code is open-sourced in the GitHub repo [Awesome_Imputation](https://github.com/WenjieDu/Awesome_Imputation). +> We comprehensively review the literature on state-of-the-art deep-learning imputation methods for time series, +> provide a taxonomy for them, and discuss the challenges and future directions in this field. +> +> **[Updates in Jun 2023]** πŸŽ‰ A short version of the PyPOTS paper is accepted by the 9th SIGKDD international workshop on +> Mining and Learning from Time Series ([MiLeTS'23](https://kdd-milets.github.io/milets2023/)). +> Besides, PyPOTS has been included as a [PyTorch Ecosystem](https://pytorch.org/ecosystem/) project. + + ## ❖ Contribution You're very welcome to contribute to this exciting project! diff --git a/docs/about_us.rst b/docs/about_us.rst index 3a2e43f8..e5fdc1fb 100644 --- a/docs/about_us.rst +++ b/docs/about_us.rst @@ -6,22 +6,22 @@ Core Development Team Wenjie Du ********** -- Initialized the project in March 2022 +- Founded the organization in March 2022 - `GitHub (WenjieDu) `_ - `LinkedIn (Wenjie Du) `_ -Maciej Skrabski -*************** -- Joined in May 2023 -- `GitHub (MaciejSkrabski) `_ -- `LinkedIn (Maciej Skrabski) `_ - Jun Wang ******** - Joined in August 2023 - `GitHub (AugustJW) `_ - `LinkedIn (Jun Wang) `_ +Linglong Qian +************* +- Joined in February 2024 +- `GitHub (LinglongQian) `_ +- `LinkedIn (Linglong Qian) `_ + All Contributors """""""""""""""" diff --git a/docs/index.rst b/docs/index.rst index b98c384e..59bc6a11 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -235,10 +235,10 @@ Your star is your recognition of PyPOTS, and it matters! The lists of PyPOTS stargazers and forkers are shown below, and we're so proud to have more and more awesome users, as well as more bright ✨stars: -.. image:: http://reporoster.com/stars/dark/WenjieDu/PyPOTS +.. image:: https://bytecrank.com/nastyox/reporoster/php/stargazersSVG.php?theme=dark&user=WenjieDu&repo=PyPOTS :alt: PyPOTS stargazers :target: https://github.com/WenjieDu/PyPOTS/stargazers -.. image:: http://reporoster.com/forks/dark/WenjieDu/PyPOTS +.. image:: https://bytecrank.com/nastyox/reporoster/php/forkersSVG.php?theme=dark&user=WenjieDu&repo=PyPOTS
:alt: PyPOTS forkers :target: https://github.com/WenjieDu/PyPOTS/network/members diff --git a/pypots/classification/grud/data.py b/pypots/classification/grud/data.py index a4e4a163..34865428 100644 --- a/pypots/classification/grud/data.py +++ b/pypots/classification/grud/data.py @@ -12,7 +12,7 @@ from ...data.base import BaseDataset from ...data.utils import _parse_delta_torch -from ...imputation.locf import LOCF +from ...imputation.locf import locf_torch class DatasetForGRUD(BaseDataset): @@ -49,11 +49,9 @@ def __init__( file_type: str = "h5py", ): super().__init__(data, False, return_labels, file_type) - self.locf = LOCF() - if not isinstance(self.data, str): # data from array self.missing_mask = (~torch.isnan(self.X)).to(torch.float32) - self.X_filledLOCF = self.locf._locf_torch(self.X) + self.X_filledLOCF = locf_torch(self.X) self.X = torch.nan_to_num(self.X) self.deltas = _parse_delta_torch(self.missing_mask) self.empirical_mean = torch.sum( @@ -125,7 +123,7 @@ def _fetch_data_from_file(self, idx: int) -> Iterable: X = torch.from_numpy(self.file_handle["X"][idx]).to(torch.float32) missing_mask = (~torch.isnan(X)).to(torch.float32) - X_filledLOCF = self.locf._locf_torch(X.unsqueeze(dim=0)).squeeze() + X_filledLOCF = locf_torch(X.unsqueeze(dim=0)).squeeze() X = torch.nan_to_num(X) deltas = _parse_delta_torch(missing_mask) empirical_mean = torch.sum(missing_mask * X, dim=[0]) / torch.sum( diff --git a/pypots/imputation/gpvae/modules/submodules.py b/pypots/imputation/gpvae/modules/submodules.py index 3a63fcee..98d1f00a 100644 --- a/pypots/imputation/gpvae/modules/submodules.py +++ b/pypots/imputation/gpvae/modules/submodules.py @@ -106,14 +106,14 @@ def forward(self, x): if len(x.shape) > 2: shape = list(np.arange(len(x.shape))) new_shape = [0, shape[-1]] + shape[1:-1] - out = super(CustomConv1d, self).forward(x.permute(*new_shape)) + out = super().forward(x.permute(*new_shape)) shape = list(np.arange(len(out.shape))) new_shape = [0, shape[-1]] + shape[1:-1] if self.kernel_size[0] % 2 == 0: out = F.pad(out, (0, -1), "constant", 0) return out.permute(new_shape) - return super(CustomConv1d, self).forward(x) + return super().forward(x) def make_cnn(input_size, output_size, hidden_sizes, kernel_size=3): diff --git a/pypots/imputation/locf/__init__.py b/pypots/imputation/locf/__init__.py index 9536f6cb..82f7772f 100644 --- a/pypots/imputation/locf/__init__.py +++ b/pypots/imputation/locf/__init__.py @@ -5,8 +5,10 @@ # Created by Wenjie Du # License: BSD-3-Clause -from .model import LOCF +from .model import LOCF, locf_numpy, locf_torch __all__ = [ "LOCF", + "locf_numpy", + "locf_torch", ] diff --git a/pypots/imputation/locf/model.py b/pypots/imputation/locf/model.py index 38d19d6b..f634f87f 100644 --- a/pypots/imputation/locf/model.py +++ b/pypots/imputation/locf/model.py @@ -13,6 +13,7 @@ import numpy as np import torch +from .modules.core import locf_numpy, locf_torch from ..base import BaseImputer from ...utils.logging import logger @@ -68,122 +69,6 @@ def fit( "Please run func `predict()` directly." ) - def _locf_numpy( - self, - X: np.ndarray, - first_step_imputation: str = "backward", - ) -> np.ndarray: - """Numpy implementation of LOCF. - - Parameters - ---------- - X : np.ndarray, - Time series containing missing values (NaN) to be imputed. - - Returns - ------- - X_imputed : array, - Imputed time series.
- - Notes - ----- - This implementation gets inspired by the question on StackOverflow: - https://stackoverflow.com/questions/41190852/most-efficient-way-to-forward-fill-nan-values-in-numpy-array - - """ - trans_X = X.transpose((0, 2, 1)) - mask = np.isnan(trans_X) - n_samples, n_steps, n_features = mask.shape - idx = np.where(~mask, np.arange(n_features), 0) - idx = np.maximum.accumulate(idx, axis=2) - - collector = [] - for x, i in zip(trans_X, idx): - collector.append(x[np.arange(n_steps)[:, None], i]) - X_imputed = np.asarray(collector) - X_imputed = X_imputed.transpose((0, 2, 1)) - - # If there are values still missing, they are missing at the beginning of the time-series sequence. - if np.isnan(X_imputed).any(): - if first_step_imputation == "nan": - pass - elif first_step_imputation == "zero": - X_imputed = np.nan_to_num(X_imputed, nan=0) - elif first_step_imputation == "backward": - # imputed by last observation carried backward (LOCB) - X_imputed_transpose = np.copy(X_imputed) - X_imputed_transpose = np.flip(X_imputed_transpose, axis=1) - X_LOCB = self._locf_numpy( - X_imputed_transpose, - "zero", - ) - X_imputed = np.flip(X_LOCB, axis=1) - elif first_step_imputation == "median": - bz, n_steps, n_features = X_imputed.shape - X_imputed_reshaped = np.copy(X_imputed).reshape(-1, n_features) - median_values = np.nanmedian(X_imputed_reshaped, axis=0) - for i, v in enumerate(median_values): - X_imputed[:, :, i] = np.nan_to_num(X_imputed[:, :, i], nan=v) - if np.isnan(X_imputed).any() and self.keep_trying: - X_imputed = np.nan_to_num(X_imputed, nan=0) - - return X_imputed - - def _locf_torch( - self, - X: torch.Tensor, - first_step_imputation: str = "backward", - ) -> torch.Tensor: - """Torch implementation of LOCF. - - Parameters - ---------- - X : tensor, - Time series containing missing values (NaN) to be imputed. - - Returns - ------- - X_imputed : tensor, - Imputed time series. - """ - trans_X = X.permute((0, 2, 1)) - mask = torch.isnan(trans_X) - n_samples, n_steps, n_features = mask.shape - idx = torch.where(~mask, torch.arange(n_features, device=mask.device), 0) - idx = np.maximum.accumulate(idx, axis=2) - - collector = [] - for x, i in zip(trans_X, idx): - collector.append(x[torch.arange(n_steps)[:, None], i]) - X_imputed = torch.stack(collector) - X_imputed = X_imputed.permute((0, 2, 1)) - - # If there are values still missing, they are missing at the beginning of the time-series sequence. 
- if torch.isnan(X_imputed).any(): - if first_step_imputation == "nan": - pass - elif first_step_imputation == "zero": - X_imputed = torch.nan_to_num(X_imputed, nan=0) - elif first_step_imputation == "backward": - # imputed by last observation carried backward (LOCB) - X_imputed_transpose = X_imputed.clone() - X_imputed_transpose = torch.flip(X_imputed_transpose, dims=[1]) - X_LOCB = self._locf_torch( - X_imputed_transpose, - "zero", - ) - X_imputed = torch.flip(X_LOCB, dims=[1]) - elif first_step_imputation == "median": - bz, n_steps, n_features = X_imputed.shape - X_imputed_reshaped = X_imputed.clone().reshape(-1, n_features) - median_values = torch.nanmedian(X_imputed_reshaped, dim=0) - for i, v in enumerate(median_values.values): - X_imputed[:, :, i] = torch.nan_to_num(X_imputed[:, :, i], nan=v) - if torch.isnan(X_imputed).any() and self.keep_trying: - X_imputed = torch.nan_to_num(X_imputed, nan=0) - - return X_imputed - def predict( self, test_set: Union[dict, str], @@ -226,9 +111,9 @@ def predict( X = np.asarray(X) if isinstance(X, np.ndarray): - imputed_data = self._locf_numpy(X, self.first_step_imputation) + imputed_data = locf_numpy(X, self.first_step_imputation) elif isinstance(X, torch.Tensor): - imputed_data = self._locf_torch(X, self.first_step_imputation) + imputed_data = locf_torch(X, self.first_step_imputation) else: raise TypeError( "X must be type of list/np.ndarray/torch.Tensor, " f"but got {type(X)}" diff --git a/pypots/imputation/locf/modules/__init__.py b/pypots/imputation/locf/modules/__init__.py new file mode 100644 index 00000000..ceaa7ee3 --- /dev/null +++ b/pypots/imputation/locf/modules/__init__.py @@ -0,0 +1,6 @@ +""" + +""" + +# Created by Wenjie Du +# License: BSD-3-Clause diff --git a/pypots/imputation/locf/modules/core.py b/pypots/imputation/locf/modules/core.py new file mode 100644 index 00000000..8c12ffc9 --- /dev/null +++ b/pypots/imputation/locf/modules/core.py @@ -0,0 +1,155 @@ +""" + +""" + +# Created by Wenjie Du +# License: BSD-3-Clause + +import numpy as np +import torch + + +def locf_numpy( + X: np.ndarray, + first_step_imputation: str = "backward", +) -> np.ndarray: + """Numpy implementation of LOCF. + + Parameters + ---------- + X : np.ndarray, + Time series containing missing values (NaN) to be imputed. + + first_step_imputation : str, default='backward' + With LOCF, the observed values are carried forward to impute the missing ones. But if the first value + is missing, there is no value to carry forward. This parameter is used to determine the strategy to + impute the missing values at the beginning of the time-series sequence after LOCF is applied. + It can be one of ['backward', 'zero', 'median', 'nan']. + If 'nan', the missing values at the sequence beginning will be left as NaNs. + If 'zero', the missing values at the sequence beginning will be imputed with 0. + If 'backward', the missing values at the beginning of the time-series sequence will be imputed with the + first observed value in the sequence, i.e. the first observed value will be carried backward to impute + the missing values at the beginning of the sequence. This method is also known as NOCB (Next Observation + Carried Backward). If 'median', the missing values at the sequence beginning will be imputed with the overall + median values of features in the dataset. + If `first_step_imputation` is not "nan" and missing values still exist after applying it (usually because + an entire feature is missing), they will be filled with 0.
+ + Returns + ------- + X_imputed : array, + Imputed time series. + + Notes + ----- + This implementation is inspired by the question on StackOverflow: + https://stackoverflow.com/questions/41190852/most-efficient-way-to-forward-fill-nan-values-in-numpy-array + + """ + trans_X = X.transpose((0, 2, 1)) # shape becomes (n_samples, n_features, n_steps) + mask = np.isnan(trans_X) + n_samples, n_features, n_steps = mask.shape + idx = np.where(~mask, np.arange(n_steps), 0) + idx = np.maximum.accumulate(idx, axis=2) + + collector = [] + for x, i in zip(trans_X, idx): + collector.append(x[np.arange(n_features)[:, None], i]) + X_imputed = np.asarray(collector) + X_imputed = X_imputed.transpose((0, 2, 1)) + + # If there are values still missing, they are missing at the beginning of the time-series sequence. + if np.isnan(X_imputed).any(): + if first_step_imputation == "nan": + pass + elif first_step_imputation == "zero": + X_imputed = np.nan_to_num(X_imputed, nan=0) + elif first_step_imputation == "backward": + # imputed by last observation carried backward (LOCB) + X_imputed_transpose = np.copy(X_imputed) + X_imputed_transpose = np.flip(X_imputed_transpose, axis=1) + X_LOCB = locf_numpy( + X_imputed_transpose, + "zero", + ) + X_imputed = np.flip(X_LOCB, axis=1) + elif first_step_imputation == "median": + bz, n_steps, n_features = X_imputed.shape + X_imputed_reshaped = np.copy(X_imputed).reshape(-1, n_features) + median_values = np.nanmedian(X_imputed_reshaped, axis=0) + for i, v in enumerate(median_values): + X_imputed[:, :, i] = np.nan_to_num(X_imputed[:, :, i], nan=v) + if np.isnan(X_imputed).any(): + X_imputed = np.nan_to_num(X_imputed, nan=0) + + return X_imputed + + +def locf_torch( + X: torch.Tensor, + first_step_imputation: str = "backward", +) -> torch.Tensor: + """Torch implementation of LOCF. + + Parameters + ---------- + X : tensor, + Time series containing missing values (NaN) to be imputed. + + first_step_imputation : str, default='backward' + With LOCF, the observed values are carried forward to impute the missing ones. But if the first value + is missing, there is no value to carry forward. This parameter is used to determine the strategy to + impute the missing values at the beginning of the time-series sequence after LOCF is applied. + It can be one of ['backward', 'zero', 'median', 'nan']. + If 'nan', the missing values at the sequence beginning will be left as NaNs. + If 'zero', the missing values at the sequence beginning will be imputed with 0. + If 'backward', the missing values at the beginning of the time-series sequence will be imputed with the + first observed value in the sequence, i.e. the first observed value will be carried backward to impute + the missing values at the beginning of the sequence. This method is also known as NOCB (Next Observation + Carried Backward). If 'median', the missing values at the sequence beginning will be imputed with the overall + median values of features in the dataset. + If `first_step_imputation` is not "nan" and missing values still exist after applying it (usually because + an entire feature is missing), they will be filled with 0. + + Returns + ------- + X_imputed : tensor, + Imputed time series.
+ """ + trans_X = X.permute((0, 2, 1)) + mask = torch.isnan(trans_X) + n_samples, n_steps, n_features = mask.shape + idx = torch.where(~mask, torch.arange(n_features, device=mask.device), 0) + idx = np.maximum.accumulate(idx, axis=2) + + collector = [] + for x, i in zip(trans_X, idx): + collector.append(x[torch.arange(n_steps)[:, None], i]) + X_imputed = torch.stack(collector) + X_imputed = X_imputed.permute((0, 2, 1)) + + # If there are values still missing, they are missing at the beginning of the time-series sequence. + if torch.isnan(X_imputed).any(): + if first_step_imputation == "nan": + pass + elif first_step_imputation == "zero": + X_imputed = torch.nan_to_num(X_imputed, nan=0) + elif first_step_imputation == "backward": + # imputed by last observation carried backward (LOCB) + X_imputed_transpose = X_imputed.clone() + X_imputed_transpose = torch.flip(X_imputed_transpose, dims=[1]) + X_LOCB = locf_torch( + X_imputed_transpose, + "zero", + ) + X_imputed = torch.flip(X_LOCB, dims=[1]) + elif first_step_imputation == "median": + bz, n_steps, n_features = X_imputed.shape + X_imputed_reshaped = X_imputed.clone().reshape(-1, n_features) + median_values = torch.nanmedian(X_imputed_reshaped, dim=0) + for i, v in enumerate(median_values.values): + X_imputed[:, :, i] = torch.nan_to_num(X_imputed[:, :, i], nan=v) + if torch.isnan(X_imputed).any(): + X_imputed = torch.nan_to_num(X_imputed, nan=0) + + return X_imputed