Skip to content

Commit

Permalink
refactor: removing modules from tsdb and pygrinder, now directly usin…
Browse files Browse the repository at this point in the history
…g them;
  • Loading branch information
WenjieDu committed Oct 13, 2023
1 parent bad8ab5 commit dc3976b
Show file tree
Hide file tree
Showing 7 changed files with 79 additions and 134 deletions.
3 changes: 2 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,8 @@ We present you a usage example of imputing missing values in time series with Py
``` python
import numpy as np
from sklearn.preprocessing import StandardScaler
from pypots.data import load_specific_dataset, mcar, masked_fill
from pygrinder import mcar, masked_fill
from pypots.data import load_specific_dataset
from pypots.imputation import SAITS
from pypots.utils.metrics import cal_mae
# Data preprocessing. Tedious, but PyPOTS can help.
Expand Down
3 changes: 2 additions & 1 deletion docs/examples.rst
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,8 @@ You can also find a simple and quick-start tutorial notebook on Google Colab wit
import numpy as np
from sklearn.preprocessing import StandardScaler
from pypots.data import load_specific_dataset, mcar, masked_fill
from pygrinder import mcar, masked_fill
from pypots.data import load_specific_dataset
from pypots.imputation import SAITS
from pypots.utils.metrics import cal_mae
Expand Down
6 changes: 3 additions & 3 deletions pypots/classification/grud/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
import torch

from ...data.base import BaseDataset
from ...data.utils import torch_parse_delta
from ...data.utils import _parse_delta_torch
from ...imputation.locf import LOCF


Expand Down Expand Up @@ -55,7 +55,7 @@ def __init__(
self.missing_mask = (~torch.isnan(self.X)).to(torch.float32)
self.X_filledLOCF = self.locf._locf_torch(self.X)
self.X = torch.nan_to_num(self.X)
self.deltas = torch_parse_delta(self.missing_mask)
self.deltas = _parse_delta_torch(self.missing_mask)
self.empirical_mean = torch.sum(
self.missing_mask * self.X, dim=[0, 1]
) / torch.sum(self.missing_mask, dim=[0, 1])
Expand Down Expand Up @@ -127,7 +127,7 @@ def _fetch_data_from_file(self, idx: int) -> Iterable:
missing_mask = (~torch.isnan(X)).to(torch.float32)
X_filledLOCF = self.locf._locf_torch(X.unsqueeze(dim=0)).squeeze()
X = torch.nan_to_num(X)
deltas = torch_parse_delta(missing_mask)
deltas = _parse_delta_torch(missing_mask)
empirical_mean = torch.sum(missing_mask * X, dim=[0]) / torch.sum(
missing_mask, dim=[0]
)
Expand Down
13 changes: 3 additions & 10 deletions pypots/data/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,8 @@
list_supported_datasets,
load_specific_dataset,
)
from .utils import (
masked_fill,
mcar,
pickle_load,
pickle_dump,
)
from .saving import save_dict_into_h5
from .utils import parse_delta, sliding_window

__all__ = [
# datasets
Expand All @@ -38,10 +33,8 @@
"list_supported_datasets",
"load_specific_dataset",
# utils
"masked_fill",
"mcar",
"pickle_load",
"pickle_dump",
"parse_delta",
"sliding_window",
# saving
"save_dict_into_h5",
]
2 changes: 1 addition & 1 deletion pypots/data/generating.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,12 @@

import numpy as np
import torch
from pygrinder import mcar, masked_fill
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.utils import check_random_state

from .load_specific_datasets import load_specific_dataset
from .utils import mcar, masked_fill


def gene_complete_random_walk(
Expand Down
176 changes: 63 additions & 113 deletions pypots/data/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,127 +5,33 @@
# Created by Wenjie Du <[email protected]>
# License: GLP-v3


from typing import Union

import numpy as np
import pygrinder
import torch
from tsdb import (
pickle_load as _pickle_load,
pickle_dump as _pickle_dump,
)

pickle_load = _pickle_load
pickle_dump = _pickle_dump


def cal_missing_rate(X: Union[np.ndarray, torch.Tensor, list]) -> float:
"""Calculate the missing rate of the given data.
Parameters
----------
X :
The data to calculate missing rate.

Returns
-------
missing_rate :
The missing rate of the given data.
"""
missing_rate = pygrinder.cal_missing_rate(X)
return missing_rate


def masked_fill(
X: Union[np.ndarray, torch.Tensor, list],
mask: Union[np.ndarray, torch.Tensor, list],
value: float,
) -> Union[np.ndarray, torch.Tensor]:
"""Fill the masked values in ``X`` according to ``mask`` with the given ``value``.

Parameters
----------
X :
The data to be filled.
mask :
The mask for filling the given data.
value :
The value to fill the masked values.
Returns
-------
filled_X :
The filled data.
"""
filled_X = pygrinder.masked_fill(X, mask, value)
return filled_X


def mcar(
X: Union[np.ndarray, torch.Tensor, list],
p: float,
nan: float = 0,
) -> Union[np.ndarray, torch.Tensor]:
"""Create completely random missing values (MCAR case).
def _parse_delta_torch(missing_mask: torch.Tensor) -> torch.Tensor:
"""Generate the time-gap matrix (i.e. the delta metrix) from the missing mask.
Please refer to :cite:`che2018GRUD` for its math definition.
Parameters
----------
X : array,
Data vector. If X has any missing values, they should be numpy.nan.
p : float, in (0,1),
The probability that values may be masked as missing completely at random.
Note that the values are randomly selected no matter if they are originally missing or observed.
If the selected values are originally missing, they will be kept as missing.
If the selected values are originally observed, they will be masked as missing.
Therefore, if the given X already contains missing data, the final missing rate in the output X could be
in range [original_missing_rate, original_missing_rate+rate], but not strictly equal to
`original_missing_rate+rate`. Because the selected values to be artificially masked out may be originally
missing, and the masking operation on the values will do nothing.
nan : int/float, optional, default=0
Value used to fill NaN values.
missing_mask : shape of [n_steps, n_features] or [n_samples, n_steps, n_features]
Binary masks indicate missing data (0 means missing values, 1 means observed values).
Returns
-------
X_intact : array,
Original data with missing values (nan) filled with given parameter `nan`, with observed values intact.
X_intact is for loss calculation in the masked imputation task.
X : array,
Original X with artificial missing values. X is for model input.
Both originally-missing and artificially-missing values are filled with given parameter `nan`.
delta :
The delta matrix indicates the time gaps between observed values.
With the same shape of missing_mask.
missing_mask : array,
The mask indicates all missing values in X.
In it, 1 indicates observed values, and 0 indicates missing values.
indicating_mask : array,
The mask indicates the artificially-missing values in X, namely missing parts different from X_intact.
In it, 1 indicates artificially missing values, and other values are indicated as 0.
"""
X = pygrinder.mcar(X, p, nan)
return X


def torch_parse_delta(missing_mask: torch.Tensor) -> torch.Tensor:
"""Generate time-gap (delta) matrix from missing masks.
Please refer to :cite:`che2018GRUD` for its math definition.
Parameters
References
----------
missing_mask :
Binary masks indicate missing values. Shape of [n_steps, n_features] or [n_samples, n_steps, n_features]
.. [1] `Che, Zhengping, Sanjay Purushotham, Kyunghyun Cho, David Sontag, and Yan Liu.
"Recurrent neural networks for multivariate time series with missing values."
Scientific reports 8, no. 1 (2018): 6085.
<https://www.nature.com/articles/s41598-018-24271-9.pdf>`_
Returns
-------
delta
Delta matrix indicates time gaps of missing values.
"""

def cal_delta_for_single_sample(mask: torch.Tensor) -> torch.Tensor:
Expand Down Expand Up @@ -156,18 +62,28 @@ def cal_delta_for_single_sample(mask: torch.Tensor) -> torch.Tensor:
return delta


def numpy_parse_delta(missing_mask: np.ndarray) -> np.ndarray:
"""Generate time-gap (delta) matrix from missing masks. Please refer to :cite:`che2018GRUD` for its math definition.
def _parse_delta_numpy(missing_mask: np.ndarray) -> np.ndarray:
"""Generate the time-gap matrix (i.e. the delta metrix) from the missing mask.
Please refer to :cite:`che2018GRUD` for its math definition.
Parameters
----------
missing_mask :
Binary masks indicate missing values. Shape of [n_steps, n_features] or [n_samples, n_steps, n_features].
missing_mask : shape of [n_steps, n_features] or [n_samples, n_steps, n_features]
Binary masks indicate missing data (0 means missing values, 1 means observed values).
Returns
-------
delta
Delta matrix indicates time gaps of missing values.
delta :
The delta matrix indicates the time gaps between observed values.
With the same shape of missing_mask.
References
----------
.. [1] `Che, Zhengping, Sanjay Purushotham, Kyunghyun Cho, David Sontag, and Yan Liu.
"Recurrent neural networks for multivariate time series with missing values."
Scientific reports 8, no. 1 (2018): 6085.
<https://www.nature.com/articles/s41598-018-24271-9.pdf>`_
"""

def cal_delta_for_single_sample(mask: np.ndarray) -> np.ndarray:
Expand All @@ -194,6 +110,40 @@ def cal_delta_for_single_sample(mask: np.ndarray) -> np.ndarray:
return delta


def parse_delta(
missing_mask: Union[np.ndarray, torch.Tensor]
) -> Union[np.ndarray, torch.Tensor]:
"""Generate the time-gap matrix (i.e. the delta metrix) from the missing mask.
Please refer to :cite:`che2018GRUD` for its math definition.
Parameters
----------
missing_mask : shape of [n_steps, n_features] or [n_samples, n_steps, n_features]
Binary masks indicate missing data (0 means missing values, 1 means observed values).
Returns
-------
delta :
The delta matrix indicates the time gaps between observed values.
With the same shape of missing_mask.
References
----------
.. [1] `Che, Zhengping, Sanjay Purushotham, Kyunghyun Cho, David Sontag, and Yan Liu.
"Recurrent neural networks for multivariate time series with missing values."
Scientific reports 8, no. 1 (2018): 6085.
<https://www.nature.com/articles/s41598-018-24271-9.pdf>`_
"""
if isinstance(missing_mask, np.ndarray):
delta = _parse_delta_numpy(missing_mask)
elif isinstance(missing_mask, torch.Tensor):
delta = _parse_delta_torch(missing_mask)
else:
raise RuntimeError
return delta


def sliding_window(time_series, window_len, sliding_len=None):
"""Generate time series samples with sliding window method, truncating windows from time-series data
with a given sequence length.
Expand Down
10 changes: 5 additions & 5 deletions pypots/imputation/brits/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import torch

from ...data.base import BaseDataset
from ...data.utils import torch_parse_delta
from ...data.utils import _parse_delta_torch


class DatasetForBRITS(BaseDataset):
Expand Down Expand Up @@ -52,10 +52,10 @@ def __init__(
# calculate all delta here.
forward_missing_mask = (~torch.isnan(self.X)).type(torch.float32)
forward_X = torch.nan_to_num(self.X)
forward_delta = torch_parse_delta(forward_missing_mask)
forward_delta = _parse_delta_torch(forward_missing_mask)
backward_X = torch.flip(forward_X, dims=[1])
backward_missing_mask = torch.flip(forward_missing_mask, dims=[1])
backward_delta = torch_parse_delta(backward_missing_mask)
backward_delta = _parse_delta_torch(backward_missing_mask)

self.processed_data = {
"forward": {
Expand Down Expand Up @@ -140,14 +140,14 @@ def _fetch_data_from_file(self, idx: int) -> Iterable:
forward = {
"X": X,
"missing_mask": missing_mask,
"deltas": torch_parse_delta(missing_mask),
"deltas": _parse_delta_torch(missing_mask),
}

backward = {
"X": torch.flip(forward["X"], dims=[0]),
"missing_mask": torch.flip(forward["missing_mask"], dims=[0]),
}
backward["deltas"] = torch_parse_delta(backward["missing_mask"])
backward["deltas"] = _parse_delta_torch(backward["missing_mask"])

sample = [
torch.tensor(idx),
Expand Down

0 comments on commit dc3976b

Please sign in to comment.