Skip to content

Commit

Permalink
Merge pull request #308 from WenjieDu/dev
Browse files Browse the repository at this point in the history
Refactor LOCF, fix Raindrop on multiple cuda devices, and update docs
  • Loading branch information
WenjieDu authored Mar 13, 2024
2 parents 3ff6887 + 2c8013a commit 3f120e8
Show file tree
Hide file tree
Showing 10 changed files with 196 additions and 141 deletions.
4 changes: 3 additions & 1 deletion .github/ISSUE_TEMPLATE/bug-report.yml
Original file line number Diff line number Diff line change
Expand Up @@ -49,4 +49,6 @@ body:
required: true
attributes:
label: 4. Expected behavior
description: "A clear and concise description of what error you would expect to happen."
description: |
A clear and concise description of what error you would expect to happen.
Please provide the whole "Traceback" of the error message printed in the console.
17 changes: 12 additions & 5 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
<img src="https://pypots.com/figs/pypots_logos/PyPOTS/logo_FFBG.svg" width="200" align="right">
</a>

<h2 align="center">Welcome to PyPOTS</h2>
<h3 align="center">Welcome to PyPOTS</h3>

<p align="center"><i>a Python toolbox for machine learning on Partially-Observed Time Series</i></p>

Expand Down Expand Up @@ -213,10 +213,6 @@ This functionality is implemented with the [Microsoft NNI](https://github.com/mi


## ❖ Citing PyPOTS
**[Updates in Jun 2023]** 🎉A short version of the PyPOTS paper is accepted by the 9th SIGKDD international workshop on
Mining and Learning from Time Series ([MiLeTS'23](https://kdd-milets.github.io/milets2023/))).
Besides, PyPOTS has been included as a [PyTorch Ecosystem](https://pytorch.org/ecosystem/) project.
The paper introducing PyPOTS is available on arXiv at [this URL](https://arxiv.org/abs/2305.18811),
and we are pursuing to publish it in prestigious academic venues, e.g. JMLR (track for
[Machine Learning Open Source Software](https://www.jmlr.org/mloss/)). If you use PyPOTS in your work,
Expand All @@ -243,6 +239,17 @@ doi={10.48550/arXiv.2305.18811},
> arXiv, abs/2305.18811.https://arxiv.org/abs/2305.18811


> [!TIP]
> **[Updates in Feb 2024]** 😎 Our survey paper [Deep Learning for Multivariate Time Series Imputation: A Survey](https://arxiv.org/abs/2402.04059) has been released on arXiv.
The code is open source in the GitHub repo [Awesome_Imputation](https://github.com/WenjieDu/Awesome_Imputation).
We comprehensively review the literature of the state-of-the-art deep-learning imputation methods for time series,
provide a taxonomy for them, and discuss the challenges and future directions in this field.
>
> **[Updates in Jun 2023]** 🎉 A short version of the PyPOTS paper is accepted by the 9th SIGKDD international workshop on
Mining and Learning from Time Series ([MiLeTS'23](https://kdd-milets.github.io/milets2023/))).
Besides, PyPOTS has been included as a [PyTorch Ecosystem](https://pytorch.org/ecosystem/) project.
## ❖ Contribution
You're very welcome to contribute to this exciting project!

Expand Down
14 changes: 7 additions & 7 deletions docs/about_us.rst
Original file line number Diff line number Diff line change
Expand Up @@ -6,22 +6,22 @@ Core Development Team

Wenjie Du
**********
- Initialized the project in March 2022
- Founded the organization in March 2022
- `GitHub (WenjieDu) <https://github.com/WenjieDu>`_
- `LinkedIn (Wenjie Du) <https://www.linkedin.com/in/wenjie-du>`_

Maciej Skrabski
***************
- Joined in May 2023
- `GitHub (MaciejSkrabski) <https://github.com/MaciejSkrabski>`_
- `LinkedIn (Maciej Skrabski) <https://www.linkedin.com/in/maciej-skrabski-75595525a>`_

Jun Wang
********
- Joined in August 2023
- `GitHub (AugustJW) <https://github.com/AugustJW>`_
- `LinkedIn (Jun Wang) <https://www.linkedin.com/in/wang-jun-35323b193>`_

Linglong Qian
********
- Joined in February 2024
- `GitHub (LinglongQian) <https://github.com/LinglongQian>`_
- `LinkedIn (Linglong Qian) <https://www.linkedin.com/in/linglongqian>`_


All Contributors
""""""""""""""""
Expand Down
4 changes: 2 additions & 2 deletions docs/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -235,10 +235,10 @@ Your star is your recognition to PyPOTS, and it matters!

The lists of PyPOTS stargazers and forkers are shown below, and we're so proud to have more and more awesome users, as well as more bright ✨stars:

.. image:: http://reporoster.com/stars/dark/WenjieDu/PyPOTS
.. image:: https://bytecrank.com/nastyox/reporoster/php/stargazersSVG.php?theme=dark&user=WenjieDu&repo=PyPOTS
:alt: PyPOTS stargazers
:target: https://github.com/WenjieDu/PyPOTS/stargazers
.. image:: http://reporoster.com/forks/dark/WenjieDu/PyPOTS
.. image:: https://bytecrank.com/nastyox/reporoster/php/forkersSVG.php?theme=dark&user=WenjieDu&repo=PyPOTS
:alt: PyPOTS forkers
:target: https://github.com/WenjieDu/PyPOTS/network/members

Expand Down
8 changes: 3 additions & 5 deletions pypots/classification/grud/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

from ...data.base import BaseDataset
from ...data.utils import _parse_delta_torch
from ...imputation.locf import LOCF
from ...imputation.locf import locf_torch


class DatasetForGRUD(BaseDataset):
Expand Down Expand Up @@ -49,11 +49,9 @@ def __init__(
file_type: str = "h5py",
):
super().__init__(data, False, return_labels, file_type)
self.locf = LOCF()

if not isinstance(self.data, str): # data from array
self.missing_mask = (~torch.isnan(self.X)).to(torch.float32)
self.X_filledLOCF = self.locf._locf_torch(self.X)
self.X_filledLOCF = locf_torch(self.X)
self.X = torch.nan_to_num(self.X)
self.deltas = _parse_delta_torch(self.missing_mask)
self.empirical_mean = torch.sum(
Expand Down Expand Up @@ -125,7 +123,7 @@ def _fetch_data_from_file(self, idx: int) -> Iterable:

X = torch.from_numpy(self.file_handle["X"][idx]).to(torch.float32)
missing_mask = (~torch.isnan(X)).to(torch.float32)
X_filledLOCF = self.locf._locf_torch(X.unsqueeze(dim=0)).squeeze()
X_filledLOCF = locf_torch(X.unsqueeze(dim=0)).squeeze()
X = torch.nan_to_num(X)
deltas = _parse_delta_torch(missing_mask)
empirical_mean = torch.sum(missing_mask * X, dim=[0]) / torch.sum(
Expand Down
4 changes: 2 additions & 2 deletions pypots/imputation/gpvae/modules/submodules.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,14 +106,14 @@ def forward(self, x):
if len(x.shape) > 2:
shape = list(np.arange(len(x.shape)))
new_shape = [0, shape[-1]] + shape[1:-1]
out = super(CustomConv1d, self).forward(x.permute(*new_shape))
out = super().forward(x.permute(*new_shape))
shape = list(np.arange(len(out.shape)))
new_shape = [0, shape[-1]] + shape[1:-1]
if self.kernel_size[0] % 2 == 0:
out = F.pad(out, (0, -1), "constant", 0)
return out.permute(new_shape)

return super(CustomConv1d, self).forward(x)
return super().forward(x)


def make_cnn(input_size, output_size, hidden_sizes, kernel_size=3):
Expand Down
4 changes: 3 additions & 1 deletion pypots/imputation/locf/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,10 @@
# Created by Wenjie Du <[email protected]>
# License: BSD-3-Clause

from .model import LOCF
from .model import LOCF, locf_numpy, locf_torch

__all__ = [
"LOCF",
"locf_numpy",
"locf_torch",
]
121 changes: 3 additions & 118 deletions pypots/imputation/locf/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import numpy as np
import torch

from .modules.core import locf_numpy, locf_torch
from ..base import BaseImputer
from ...utils.logging import logger

Expand Down Expand Up @@ -68,122 +69,6 @@ def fit(
"Please run func `predict()` directly."
)

def _locf_numpy(
self,
X: np.ndarray,
first_step_imputation: str = "backward",
) -> np.ndarray:
"""Numpy implementation of LOCF.
Parameters
----------
X : np.ndarray,
Time series containing missing values (NaN) to be imputed.
Returns
-------
X_imputed : array,
Imputed time series.
Notes
-----
This implementation gets inspired by the question on StackOverflow:
https://stackoverflow.com/questions/41190852/most-efficient-way-to-forward-fill-nan-values-in-numpy-array
"""
trans_X = X.transpose((0, 2, 1))
mask = np.isnan(trans_X)
n_samples, n_steps, n_features = mask.shape
idx = np.where(~mask, np.arange(n_features), 0)
idx = np.maximum.accumulate(idx, axis=2)

collector = []
for x, i in zip(trans_X, idx):
collector.append(x[np.arange(n_steps)[:, None], i])
X_imputed = np.asarray(collector)
X_imputed = X_imputed.transpose((0, 2, 1))

# If there are values still missing, they are missing at the beginning of the time-series sequence.
if np.isnan(X_imputed).any():
if first_step_imputation == "nan":
pass
elif first_step_imputation == "zero":
X_imputed = np.nan_to_num(X_imputed, nan=0)
elif first_step_imputation == "backward":
# imputed by last observation carried backward (LOCB)
X_imputed_transpose = np.copy(X_imputed)
X_imputed_transpose = np.flip(X_imputed_transpose, axis=1)
X_LOCB = self._locf_numpy(
X_imputed_transpose,
"zero",
)
X_imputed = np.flip(X_LOCB, axis=1)
elif first_step_imputation == "median":
bz, n_steps, n_features = X_imputed.shape
X_imputed_reshaped = np.copy(X_imputed).reshape(-1, n_features)
median_values = np.nanmedian(X_imputed_reshaped, axis=0)
for i, v in enumerate(median_values):
X_imputed[:, :, i] = np.nan_to_num(X_imputed[:, :, i], nan=v)
if np.isnan(X_imputed).any() and self.keep_trying:
X_imputed = np.nan_to_num(X_imputed, nan=0)

return X_imputed

def _locf_torch(
self,
X: torch.Tensor,
first_step_imputation: str = "backward",
) -> torch.Tensor:
"""Torch implementation of LOCF.
Parameters
----------
X : tensor,
Time series containing missing values (NaN) to be imputed.
Returns
-------
X_imputed : tensor,
Imputed time series.
"""
trans_X = X.permute((0, 2, 1))
mask = torch.isnan(trans_X)
n_samples, n_steps, n_features = mask.shape
idx = torch.where(~mask, torch.arange(n_features, device=mask.device), 0)
idx = np.maximum.accumulate(idx, axis=2)

collector = []
for x, i in zip(trans_X, idx):
collector.append(x[torch.arange(n_steps)[:, None], i])
X_imputed = torch.stack(collector)
X_imputed = X_imputed.permute((0, 2, 1))

# If there are values still missing, they are missing at the beginning of the time-series sequence.
if torch.isnan(X_imputed).any():
if first_step_imputation == "nan":
pass
elif first_step_imputation == "zero":
X_imputed = torch.nan_to_num(X_imputed, nan=0)
elif first_step_imputation == "backward":
# imputed by last observation carried backward (LOCB)
X_imputed_transpose = X_imputed.clone()
X_imputed_transpose = torch.flip(X_imputed_transpose, dims=[1])
X_LOCB = self._locf_torch(
X_imputed_transpose,
"zero",
)
X_imputed = torch.flip(X_LOCB, dims=[1])
elif first_step_imputation == "median":
bz, n_steps, n_features = X_imputed.shape
X_imputed_reshaped = X_imputed.clone().reshape(-1, n_features)
median_values = torch.nanmedian(X_imputed_reshaped, dim=0)
for i, v in enumerate(median_values.values):
X_imputed[:, :, i] = torch.nan_to_num(X_imputed[:, :, i], nan=v)
if torch.isnan(X_imputed).any() and self.keep_trying:
X_imputed = torch.nan_to_num(X_imputed, nan=0)

return X_imputed

def predict(
self,
test_set: Union[dict, str],
Expand Down Expand Up @@ -226,9 +111,9 @@ def predict(
X = np.asarray(X)

if isinstance(X, np.ndarray):
imputed_data = self._locf_numpy(X, self.first_step_imputation)
imputed_data = locf_numpy(X, self.first_step_imputation)
elif isinstance(X, torch.Tensor):
imputed_data = self._locf_torch(X, self.first_step_imputation)
imputed_data = locf_torch(X, self.first_step_imputation)
else:
raise TypeError(
"X must be type of list/np.ndarray/torch.Tensor, " f"but got {type(X)}"
Expand Down
6 changes: 6 additions & 0 deletions pypots/imputation/locf/modules/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
"""
"""

# Created by Wenjie Du <[email protected]>
# License: BSD-3-Clause
Loading

0 comments on commit 3f120e8

Please sign in to comment.