Merge pull request #316 from WenjieDu/(feat)add_mean_and_median
Add mean and median as imputation methods
WenjieDu authored Mar 19, 2024
2 parents 6975b28 + 7777efa commit 529aed5
Showing 8 changed files with 472 additions and 2 deletions.
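Both new classes are exported from pypots.imputation and added to MODEL_NO_NEED_TO_SAVE, so they skip fit() and model saving. A minimal usage sketch, assuming the dict-style input described in the predict() docstrings in the diffs below:

import numpy as np
from pypots.imputation import Mean, Median

# toy data: 4 samples, 5 time steps, 3 features, with some values missing
X = np.random.randn(4, 5, 3)
X[X < -0.5] = np.nan

mean_imputer = Mean()
median_imputer = Median()

# no fit() is needed for these naive methods; predict() returns a dict keyed by task
X_mean_imputed = mean_imputer.predict({"X": X})["imputation"]
X_median_imputed = median_imputer.predict({"X": X})["imputation"]

print(X_mean_imputed.shape)  # (4, 5, 3), NaNs replaced by per-feature means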
2 changes: 2 additions & 0 deletions pypots/base.py
@@ -146,6 +146,8 @@ def _setup_device(self, device: Union[None, str, torch.device, list]) -> None:
def _setup_path(self, saving_path) -> None:
MODEL_NO_NEED_TO_SAVE = [
"LOCF",
"Median",
"Mean",
]
# if the model does not need to be saved (e.g. LOCF), then skip the following steps
if self.__class__.__name__ in MODEL_NO_NEED_TO_SAVE:
13 changes: 11 additions & 2 deletions pypots/imputation/__init__.py
@@ -5,24 +5,33 @@
# Created by Wenjie Du <[email protected]>
# License: BSD-3-Clause

# neural network imputation methods
from .brits import BRITS
from .csdi import CSDI
from .gpvae import GPVAE
from .locf import LOCF
from .mrnn import MRNN
from .saits import SAITS
from .timesnet import TimesNet
from .transformer import Transformer
from .usgan import USGAN

# naive imputation methods
from .locf import LOCF
from .mean import Mean
from .median import Median

__all__ = [
# neural network imputation methods
"SAITS",
"Transformer",
"TimesNet",
"BRITS",
"MRNN",
"LOCF",
"GPVAE",
"USGAN",
"CSDI",
# naive imputation methods
"LOCF",
"Mean",
"Median",
]
12 changes: 12 additions & 0 deletions pypots/imputation/mean/__init__.py
@@ -0,0 +1,12 @@
"""
The package of the partially-observed time-series imputation method Mean.
"""

# Created by Wenjie Du <[email protected]>
# License: BSD-3-Clause

from .model import Mean

__all__ = [
"Mean",
]
143 changes: 143 additions & 0 deletions pypots/imputation/mean/model.py
@@ -0,0 +1,143 @@
"""
The implementation of Mean value imputation.
"""

# Created by Wenjie Du <[email protected]>
# License: BSD-3-Clause

import warnings
from typing import Union, Optional

import h5py
import numpy as np
import torch

from ..base import BaseImputer
from ...utils.logging import logger


class Mean(BaseImputer):
"""Mean value imputation method."""

def __init__(
self,
):
super().__init__()

def fit(
self,
train_set: Union[dict, str],
val_set: Optional[Union[dict, str]] = None,
file_type: str = "h5py",
) -> None:
"""Train the imputer on the given data.
Warnings
--------
Mean imputation class does not need to run fit().
Please run func ``predict()`` directly.
"""
warnings.warn(
"Mean imputation class has no parameter to train. "
"Please run func `predict()` directly."
)

def predict(
self,
test_set: Union[dict, str],
file_type: str = "h5py",
) -> dict:
"""Make predictions for the input data with the trained model.
Parameters
----------
test_set : dict or str
The dataset for model validating, should be a dictionary including keys as 'X',
or a path string locating a data file supported by PyPOTS (e.g. h5 file).
If it is a dict, X should be array-like of shape [n_samples, sequence length (time steps), n_features],
which is time-series data for validating, can contain missing values, and y should be array-like of shape
[n_samples], which is classification labels of X.
If it is a path string, the path should point to a data file, e.g. a h5 file, which contains
key-value pairs like a dict, and it has to include keys as 'X' and 'y'.
file_type : str
The type of the given file if test_set is a path string.
Returns
-------
result_dict: dict
Prediction results in a Python Dictionary for the given samples.
It should be a dictionary including keys as 'imputation', 'classification', 'clustering', and 'forecasting'.
For sure, only the keys that relevant tasks are supported by the model will be returned.
"""
if isinstance(test_set, str):
with h5py.File(test_set, "r") as f:
X = f["X"][:]
else:
X = test_set["X"]

# a plain Python list has no .shape attribute, so convert it before checking dimensions
if isinstance(X, list):
    X = np.asarray(X)

assert len(X.shape) == 3, (
    f"Input X should have 3 dimensions [n_samples, n_steps, n_features], "
    f"but the actual shape of X: {X.shape}"
)

n_samples, n_steps, n_features = X.shape

if isinstance(X, np.ndarray):
X_imputed_reshaped = np.copy(X).reshape(-1, n_features)
mean_values = np.nanmean(X_imputed_reshaped, axis=0)
for i, v in enumerate(mean_values):
X_imputed_reshaped[:, i] = np.nan_to_num(
X_imputed_reshaped[:, i], nan=v
)
imputed_data = X_imputed_reshaped.reshape(n_samples, n_steps, n_features)
elif isinstance(X, torch.Tensor):
X_imputed_reshaped = torch.clone(X).reshape(-1, n_features)
mean_values = torch.nanmean(X_imputed_reshaped, dim=0).numpy()
for i, v in enumerate(mean_values):
X_imputed_reshaped[:, i] = torch.nan_to_num(
X_imputed_reshaped[:, i], nan=v
)
imputed_data = X_imputed_reshaped.reshape(n_samples, n_steps, n_features)
else:
    raise ValueError(f"X should be a numpy.ndarray or torch.Tensor, but got {type(X)}")

result_dict = {
"imputation": imputed_data,
}
return result_dict

def impute(
self,
X: Union[dict, str],
file_type="h5py",
) -> np.ndarray:
"""Impute missing values in the given data with the trained model.
Warnings
--------
The method impute is deprecated. Please use `predict()` instead.
Parameters
----------
X :
The data samples for testing, should be array-like of shape [n_samples, sequence length (time steps),
n_features], or a path string locating a data file, e.g. h5 file.
file_type :
The type of the given file if X is a path string.
Returns
-------
array-like, shape [n_samples, sequence length (time steps), n_features],
Imputed data.
"""
logger.warning(
"🚨DeprecationWarning: The method impute is deprecated. Please use `predict` instead."
)
results_dict = self.predict(X, file_type=file_type)
return results_dict["imputation"]
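The loop in predict() above reshapes the 3-D input to 2-D so that each column collects every observed value of one feature across all samples and time steps, computes per-feature means while ignoring NaNs, and fills each column's NaNs with that mean. The same idea in a standalone, vectorized NumPy sketch (using np.where instead of the per-column nan_to_num loop):

import numpy as np

X = np.array([
    [[1.0, np.nan], [3.0, 4.0]],    # sample 1: 2 steps, 2 features
    [[np.nan, 6.0], [7.0, 8.0]],    # sample 2
])
n_samples, n_steps, n_features = X.shape

X_2d = X.reshape(-1, n_features)                      # (n_samples * n_steps, n_features)
col_means = np.nanmean(X_2d, axis=0)                  # per-feature means, NaNs ignored
X_filled = np.where(np.isnan(X_2d), col_means, X_2d)  # broadcast the fill values
X_imputed = X_filled.reshape(n_samples, n_steps, n_features)

print(col_means)        # approx [3.6667, 6.0]
print(X_imputed[0, 0])  # [1.0, 6.0] -- the NaN became the mean of feature index 1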
12 changes: 12 additions & 0 deletions pypots/imputation/median/__init__.py
@@ -0,0 +1,12 @@
"""
The package of the partially-observed time-series imputation method Median.
"""

# Created by Wenjie Du <[email protected]>
# License: BSD-3-Clause

from .model import Median

__all__ = [
"Median",
]
144 changes: 144 additions & 0 deletions pypots/imputation/median/model.py
@@ -0,0 +1,144 @@
"""
The implementation of Median value imputation.
"""

# Created by Wenjie Du <[email protected]>
# License: BSD-3-Clause

import warnings
from typing import Union, Optional

import h5py
import numpy as np
import torch

from ..base import BaseImputer
from ...utils.logging import logger


class Median(BaseImputer):
"""Median value imputation method."""

def __init__(
self,
):
super().__init__()

def fit(
self,
train_set: Union[dict, str],
val_set: Optional[Union[dict, str]] = None,
file_type: str = "h5py",
) -> None:
"""Train the imputer on the given data.
Warnings
--------
Median imputation class does not need to run fit().
Please run func ``predict()`` directly.
"""
warnings.warn(
"Median imputation class has no parameter to train. "
"Please run func `predict()` directly."
)

def predict(
self,
test_set: Union[dict, str],
file_type: str = "h5py",
) -> dict:
"""Make predictions for the input data with the trained model.
Parameters
----------
test_set : dict or str
The dataset for model validating, should be a dictionary including keys as 'X',
or a path string locating a data file supported by PyPOTS (e.g. h5 file).
If it is a dict, X should be array-like of shape [n_samples, sequence length (time steps), n_features],
which is time-series data for validating, can contain missing values, and y should be array-like of shape
[n_samples], which is classification labels of X.
If it is a path string, the path should point to a data file, e.g. a h5 file, which contains
key-value pairs like a dict, and it has to include keys as 'X' and 'y'.
file_type : str
The type of the given file if test_set is a path string.
Returns
-------
result_dict: dict
Prediction results in a Python Dictionary for the given samples.
It should be a dictionary including keys as 'imputation', 'classification', 'clustering', and 'forecasting'.
For sure, only the keys that relevant tasks are supported by the model will be returned.
"""
if isinstance(test_set, str):
with h5py.File(test_set, "r") as f:
X = f["X"][:]
else:
X = test_set["X"]

# a plain Python list has no .shape attribute, so convert it before checking dimensions
if isinstance(X, list):
    X = np.asarray(X)

assert len(X.shape) == 3, (
    f"Input X should have 3 dimensions [n_samples, n_steps, n_features], "
    f"but the actual shape of X: {X.shape}"
)

n_samples, n_steps, n_features = X.shape

if isinstance(X, np.ndarray):
X_imputed_reshaped = np.copy(X).reshape(-1, n_features)
median_values = np.nanmedian(X_imputed_reshaped, axis=0)
for i, v in enumerate(median_values):
X_imputed_reshaped[:, i] = np.nan_to_num(
X_imputed_reshaped[:, i], nan=v
)
imputed_data = X_imputed_reshaped.reshape(n_samples, n_steps, n_features)
elif isinstance(X, torch.Tensor):
X_imputed_reshaped = torch.clone(X).reshape(-1, n_features)
median_values = torch.nanmedian(X_imputed_reshaped, dim=0).values.numpy()
for i, v in enumerate(median_values):
X_imputed_reshaped[:, i] = torch.nan_to_num(
X_imputed_reshaped[:, i], nan=v
)
imputed_data = X_imputed_reshaped.reshape(n_samples, n_steps, n_features)

else:
    raise ValueError(f"X should be a numpy.ndarray or torch.Tensor, but got {type(X)}")

result_dict = {
"imputation": imputed_data,
}
return result_dict

def impute(
self,
X: Union[dict, str],
file_type="h5py",
) -> np.ndarray:
"""Impute missing values in the given data with the trained model.
Warnings
--------
The method impute is deprecated. Please use `predict()` instead.
Parameters
----------
X :
The data samples for testing, should be array-like of shape [n_samples, sequence length (time steps),
n_features], or a path string locating a data file, e.g. h5 file.
file_type :
The type of the given file if X is a path string.
Returns
-------
array-like, shape [n_samples, sequence length (time steps), n_features],
Imputed data.
"""
logger.warning(
"🚨DeprecationWarning: The method impute is deprecated. Please use `predict` instead."
)
results_dict = self.predict(X, file_type=file_type)
return results_dict["imputation"]
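One subtle difference from the Mean model above: with a dim argument, torch.nanmedian returns a (values, indices) namedtuple rather than a plain tensor, which is why the torch branch takes .values before calling .numpy(). A small sketch of that behaviour (assuming a CPU tensor, since .numpy() requires one):

import torch

x = torch.tensor([[1.0, float("nan")],
                  [3.0, 4.0],
                  [float("nan"), 8.0]])

print(torch.nanmean(x, dim=0))           # tensor([2., 6.]) -- a plain tensor
print(torch.nanmedian(x, dim=0))         # a namedtuple: torch.return_types.nanmedian(values=..., indices=...)
print(torch.nanmedian(x, dim=0).values)  # tensor([1., 4.]) -- what Median converts with .numpy()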
(The 2 remaining changed files are not shown here.)
