Merge pull request #317 from WenjieDu/dev
Add mean and median as imputation methods, and update docs
WenjieDu authored Mar 19, 2024
2 parents 03e618d + 13f2caf commit a0470b2
Showing 11 changed files with 495 additions and 17 deletions.
7 changes: 4 additions & 3 deletions README.md
@@ -8,7 +8,7 @@

<p align="center">
<a href="https://docs.pypots.com/en/latest/install.html#reasons-of-version-limitations-on-dependencies">
<img alt="Python version" src="https://img.shields.io/badge/Python->=v3.7-E97040?logo=python&logoColor=white">
<img alt="Python version" src="https://img.shields.io/badge/Python-v3.7+-E97040?logo=python&logoColor=white">
</a>
<a href="https://github.com/WenjieDu/PyPOTS">
<img alt="powered by Pytorch" src="https://img.shields.io/badge/PyTorch-❤️-F8C6B5?logo=pytorch&logoColor=white">
@@ -144,7 +144,7 @@ Alternatively, you can install from the latest source code with the latest featu

## ❖ Usage
Besides [BrewPOTS](https://github.com/WenjieDu/BrewPOTS), you can also find a simple and quick-start tutorial notebook
on Google Colab with [this link](https://colab.research.google.com/drive/1HEFjylEy05-r47jRy0H9jiS_WhD0UWmQ?usp=sharing).
on Google Colab <a href="https://colab.research.google.com/drive/1HEFjylEy05-r47jRy0H9jiS_WhD0UWmQ"><img src="https://img.shields.io/badge/GoogleColab-PyPOTS_Tutorials-F9AB00?logo=googlecolab&logoColor=white" alt="Colab tutorials" align="center"/></a>.
If you have further questions, please refer to PyPOTS documentation [docs.pypots.com](https://docs.pypots.com).
You can also [raise an issue](https://github.com/WenjieDu/PyPOTS/issues) or [ask in our community](#-community).

@@ -265,7 +265,8 @@ By committing your code, you'll
Take a look at our [inclusion criteria](https://docs.pypots.com/en/latest/faq.html#inclusion-criteria).
You can utilize the `template` folder in each task package (e.g.
[pypots/imputation/template](https://github.com/WenjieDu/PyPOTS/tree/main/pypots/imputation/template)) to quickly start;
2. be listed as one of [PyPOTS contributors](https://pypots.com/about/#all-contributors);
2. become one of [PyPOTS contributors](https://github.com/WenjieDu/PyPOTS/graphs/contributors) and
be listed as a volunteer developer [on the PyPOTS website](https://pypots.com/about/#volunteer-developers);
3. get mentioned in our [release notes](https://github.com/WenjieDu/PyPOTS/releases);
You can also contribute to PyPOTS by simply starring🌟 this repo to help more people notice it.
4 changes: 2 additions & 2 deletions docs/examples.rst
@@ -10,11 +10,11 @@ Quick-start Examples
We put some examples here to help our users to get started quickly.

Please refer to `BrewPOTS <https://github.com/WenjieDu/BrewPOTS>`_ for detailed PyPOTS tutorials.
You can also find a simple and quick-start tutorial notebook on Google Colab with
`this link <https://colab.research.google.com/drive/1HEFjylEy05-r47jRy0H9jiS_WhD0UWmQ?usp=sharing>`_.
You can also find a simple and quick-start tutorial notebook on Google Colab

.. raw:: html

<a href="https://colab.research.google.com/drive/1HEFjylEy05-r47jRy0H9jiS_WhD0UWmQ" target="_blank"><img src="https://img.shields.io/badge/GoogleColab-PyPOTS_Tutorials-F9AB00?logo=googlecolab&logoColor=white"></a>
<br clear="right">


27 changes: 17 additions & 10 deletions docs/index.rst
@@ -12,7 +12,7 @@ Welcome to PyPOTS docs!

**A Python Toolbox for Data Mining on Partially-Observed Time Series**

.. image:: https://img.shields.io/badge/Python->=v3.7-E97040?logo=python&logoColor=white
.. image:: https://img.shields.io/badge/Python-v3.7+-E97040?logo=python&logoColor=white
:alt: Python version
:target: https://docs.pypots.com/en/latest/install.html#reasons-of-version-limitations-on-dependencies

@@ -88,12 +88,12 @@ if it helps with your research. This really means a lot to our open-source resea

The rest of this readme file is organized as follows:
`❖ PyPOTS Ecosystem <#id1>`_,
`❖ Installation <#id2>`_,
`❖ Usage <#id4>`_,
`❖ Available Algorithms <#id6>`_,
`❖ Citing PyPOTS <#id19>`_,
`❖ Contribution <#id20>`_,
`❖ Community <#id21>`_.
`❖ Installation <#id3>`_,
`❖ Usage <#id5>`_,
`❖ Available Algorithms <#id7>`_,
`❖ Citing PyPOTS <#id22>`_,
`❖ Contribution <#id23>`_,
`❖ Community <#id24>`_.


❖ PyPOTS Ecosystem
@@ -136,7 +136,13 @@ Considering the future workload, PyPOTS tutorials is released in a single repo,
and you can find them in `BrewPOTS <https://github.com/WenjieDu/BrewPOTS>`_.
Take a look at it now, and learn how to brew your POTS datasets!

☕️ Welcome to the universe of PyPOTS. Enjoy it and have fun!
**☕️ Welcome to the universe of PyPOTS. Enjoy it and have fun!**

.. image:: https://pypots.com/figs/pypots_logos/Ecosystem/PyPOTS_Ecosystem_Pipeline.png
:width: 95%
:alt: PyPOTS Ecosystem Pipeline
:align: center
:target: https://pypots.com/ecosystem/


❖ Installation
@@ -149,7 +155,7 @@ Refer to the page `Installation <install.html>`_ to see different ways of instal
❖ Usage
^^^^^^^^
Besides `BrewPOTS <https://github.com/WenjieDu/BrewPOTS>`_, you can also find a simple and quick-start tutorial notebook
on Google Colab with `this link <https://colab.research.google.com/drive/1HEFjylEy05-r47jRy0H9jiS_WhD0UWmQ?usp=sharing>`_.
on Google Colab with `this link <https://colab.research.google.com/drive/1HEFjylEy05-r47jRy0H9jiS_WhD0UWmQ>`_.
You can also `raise an issue <https://github.com/WenjieDu/PyPOTS/issues>`_ or `ask in our community <#id21>`_.

Additionally, we present you a usage example of imputing missing values in time series with PyPOTS in
@@ -227,7 +233,8 @@ By committing your code, you'll
Take a look at our `inclusion criteria <https://docs.pypots.com/en/latest/faq.html#inclusion-criteria>`_.
You can utilize the ``template`` folder in each task package (e.g.
`pypots/imputation/template <https://github.com/WenjieDu/PyPOTS/tree/main/pypots/imputation/template>`_) to quickly start;
2. be listed as one of `PyPOTS contributors <https://github.com/WenjieDu/PyPOTS/graphs/contributors>`_:
2. become one of `PyPOTS contributors <https://github.com/WenjieDu/PyPOTS/graphs/contributors>`_ and
be listed as a volunteer developer `on the PyPOTS website <https://pypots.com/about/#volunteer-developers>`_;
3. get mentioned in our `release notes <https://github.com/WenjieDu/PyPOTS/releases>`_;

You can also contribute to PyPOTS by simply starring🌟 this repo to help more people notice it.
2 changes: 2 additions & 0 deletions pypots/base.py
@@ -146,6 +146,8 @@ def _setup_device(self, device: Union[None, str, torch.device, list]) -> None:
    def _setup_path(self, saving_path) -> None:
        MODEL_NO_NEED_TO_SAVE = [
            "LOCF",
            "Median",
            "Mean",
        ]
        # if the model does not need to be saved (e.g. LOCF, Mean, Median), then skip the following steps
        if self.__class__.__name__ in MODEL_NO_NEED_TO_SAVE:
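Note: the `Mean` and `Median` entries added here mean those classes, like `LOCF`, skip checkpoint-path setup because they have no trainable parameters. A minimal standalone sketch of how this class-name check plays out; the early `return` inside the branch is an assumption, since the body of the `if` is not shown in this hunk:

```python
# Minimal sketch (assumption: simplified from pypots/base.py).
MODEL_NO_NEED_TO_SAVE = [
    "LOCF",
    "Median",
    "Mean",
]


class DemoBase:
    def _setup_path(self, saving_path) -> None:
        # Naive models have nothing to checkpoint, so path setup is skipped.
        if self.__class__.__name__ in MODEL_NO_NEED_TO_SAVE:
            return
        print(f"would set up saving path: {saving_path}")


class Mean(DemoBase):
    pass


Mean()._setup_path("runs/mean")  # returns immediately; no directory is set up
```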
13 changes: 11 additions & 2 deletions pypots/imputation/__init__.py
@@ -5,24 +5,33 @@
# Created by Wenjie Du <[email protected]>
# License: BSD-3-Clause

# neural network imputation methods
from .brits import BRITS
from .csdi import CSDI
from .gpvae import GPVAE
from .locf import LOCF
from .mrnn import MRNN
from .saits import SAITS
from .timesnet import TimesNet
from .transformer import Transformer
from .usgan import USGAN

# naive imputation methods
from .locf import LOCF
from .mean import Mean
from .median import Median

__all__ = [
    # neural network imputation methods
    "SAITS",
    "Transformer",
    "TimesNet",
    "BRITS",
    "MRNN",
    "LOCF",
    "GPVAE",
    "USGAN",
    "CSDI",
    # naive imputation methods
    "LOCF",
    "Mean",
    "Median",
]
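With these exports in place, the naive imputers are importable straight from `pypots.imputation`. Below is a hedged usage sketch based on the `predict()` API of the new `Mean` model shown later in this diff; it assumes `Median` follows the same interface, and the toy array is purely illustrative:

```python
import numpy as np
from pypots.imputation import Mean, Median

# toy data: 2 samples, 3 time steps, 2 features, with some missing values
X = np.array(
    [
        [[1.0, 10.0], [np.nan, 12.0], [3.0, np.nan]],
        [[4.0, np.nan], [5.0, 14.0], [np.nan, 16.0]],
    ]
)

mean_imputer = Mean()  # no fit() needed: there is nothing to train
X_mean_imputed = mean_imputer.predict({"X": X})["imputation"]  # shape (2, 3, 2)

median_imputer = Median()
X_median_imputed = median_imputer.predict({"X": X})["imputation"]
```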
12 changes: 12 additions & 0 deletions pypots/imputation/mean/__init__.py
@@ -0,0 +1,12 @@
"""
The package of the partially-observed time-series imputation method Mean.
"""

# Created by Wenjie Du <[email protected]>
# License: BSD-3-Clause

from .model import Mean

__all__ = [
"Mean",
]
143 changes: 143 additions & 0 deletions pypots/imputation/mean/model.py
@@ -0,0 +1,143 @@
"""
The implementation of Mean value imputation.
"""

# Created by Wenjie Du <[email protected]>
# License: BSD-3-Clause

import warnings
from typing import Union, Optional

import h5py
import numpy as np
import torch

from ..base import BaseImputer
from ...utils.logging import logger


class Mean(BaseImputer):
"""Mean value imputation method."""

def __init__(
self,
):
super().__init__()

def fit(
self,
train_set: Union[dict, str],
val_set: Optional[Union[dict, str]] = None,
file_type: str = "h5py",
) -> None:
"""Train the imputer on the given data.
Warnings
--------
Mean imputation class does not need to run fit().
Please run func ``predict()`` directly.
"""
warnings.warn(
"Mean imputation class has no parameter to train. "
"Please run func `predict()` directly."
)

def predict(
self,
test_set: Union[dict, str],
file_type: str = "h5py",
) -> dict:
"""Make predictions for the input data with the trained model.
Parameters
----------
test_set : dict or str
The dataset for model validating, should be a dictionary including keys as 'X',
or a path string locating a data file supported by PyPOTS (e.g. h5 file).
If it is a dict, X should be array-like of shape [n_samples, sequence length (time steps), n_features],
which is time-series data for validating, can contain missing values, and y should be array-like of shape
[n_samples], which is classification labels of X.
If it is a path string, the path should point to a data file, e.g. a h5 file, which contains
key-value pairs like a dict, and it has to include keys as 'X' and 'y'.
file_type : str
The type of the given file if test_set is a path string.
Returns
-------
result_dict: dict
Prediction results in a Python Dictionary for the given samples.
It should be a dictionary including keys as 'imputation', 'classification', 'clustering', and 'forecasting'.
For sure, only the keys that relevant tasks are supported by the model will be returned.
"""
if isinstance(test_set, str):
with h5py.File(test_set, "r") as f:
X = f["X"][:]
else:
X = test_set["X"]

assert len(X.shape) == 3, (
f"Input X should have 3 dimensions [n_samples, n_steps, n_features], "
f"but the actual shape of X: {X.shape}"
)
if isinstance(X, list):
X = np.asarray(X)

n_samples, n_steps, n_features = X.shape

if isinstance(X, np.ndarray):
X_imputed_reshaped = np.copy(X).reshape(-1, n_features)
mean_values = np.nanmean(X_imputed_reshaped, axis=0)
for i, v in enumerate(mean_values):
X_imputed_reshaped[:, i] = np.nan_to_num(
X_imputed_reshaped[:, i], nan=v
)
imputed_data = X_imputed_reshaped.reshape(n_samples, n_steps, n_features)
elif isinstance(X, torch.Tensor):
X_imputed_reshaped = torch.clone(X).reshape(-1, n_features)
mean_values = torch.nanmean(X_imputed_reshaped, dim=0).numpy()
for i, v in enumerate(mean_values):
X_imputed_reshaped[:, i] = torch.nan_to_num(
X_imputed_reshaped[:, i], nan=v
)
imputed_data = X_imputed_reshaped.reshape(n_samples, n_steps, n_features)
else:
raise ValueError()

result_dict = {
"imputation": imputed_data,
}
return result_dict

def impute(
self,
X: Union[dict, str],
file_type="h5py",
) -> np.ndarray:
"""Impute missing values in the given data with the trained model.
Warnings
--------
The method impute is deprecated. Please use `predict()` instead.
Parameters
----------
X :
The data samples for testing, should be array-like of shape [n_samples, sequence length (time steps),
n_features], or a path string locating a data file, e.g. h5 file.
file_type :
The type of the given file if X is a path string.
Returns
-------
array-like, shape [n_samples, sequence length (time steps), n_features],
Imputed data.
"""
logger.warning(
"🚨DeprecationWarning: The method impute is deprecated. Please use `predict` instead."
)
results_dict = self.predict(X, file_type=file_type)
return results_dict["imputation"]
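The per-feature loop in `predict()` above is equivalent to a single vectorized fill. A small illustrative check of that equivalence on a 2-D NumPy array (not part of the committed code):

```python
import numpy as np

X = np.array([[1.0, np.nan], [np.nan, 4.0], [3.0, 6.0]])  # shape (n_rows, n_features)

# loop-style fill, as done feature by feature in Mean.predict()
filled_loop = X.copy()
means = np.nanmean(filled_loop, axis=0)
for i, v in enumerate(means):
    filled_loop[:, i] = np.nan_to_num(filled_loop[:, i], nan=v)

# vectorized equivalent: broadcast the per-feature means into the NaN positions
filled_vec = np.where(np.isnan(X), np.nanmean(X, axis=0), X)

assert np.allclose(filled_loop, filled_vec)
```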
12 changes: 12 additions & 0 deletions pypots/imputation/median/__init__.py
@@ -0,0 +1,12 @@
"""
The package of the partially-observed time-series imputation method Median.
"""

# Created by Wenjie Du <[email protected]>
# License: BSD-3-Clause

from .model import Median

__all__ = [
"Median",
]
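The remaining changed files, including `pypots/imputation/median/model.py`, are not rendered above. Presumably `Median` mirrors `Mean` with `nanmedian` in place of `nanmean`; the helper below is a hedged sketch of that core step only, not the committed implementation:

```python
import numpy as np


def median_impute(X: np.ndarray) -> np.ndarray:
    """Fill NaNs feature-wise with the median over all samples and time steps.

    X is expected to have shape [n_samples, n_steps, n_features].
    """
    n_samples, n_steps, n_features = X.shape
    X_flat = X.reshape(-1, n_features).copy()
    medians = np.nanmedian(X_flat, axis=0)  # per-feature medians over observed values
    for i, v in enumerate(medians):
        X_flat[:, i] = np.nan_to_num(X_flat[:, i], nan=v)
    return X_flat.reshape(n_samples, n_steps, n_features)
```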

