Add the method predict() for all models #199

Merged: 9 commits, Oct 6, 2023
2 changes: 1 addition & 1 deletion README.md
@@ -4,7 +4,7 @@

<h2 align="center">Welcome to PyPOTS</h2>

**<p align="center">A Python Toolbox for Data Mining on Partially-Observed Time Series</p>**
<p align="center"><i>a Python toolbox for data mining on Partially-Observed Time Series</i></p>

<p align="center">
<a href="https://docs.pypots.com/en/latest/install.html#reasons-of-version-limitations-on-dependencies">
25 changes: 2 additions & 23 deletions docs/pypots.forecasting.rst
@@ -1,31 +1,10 @@
pypots.forecasting package
==========================

Subpackages
-----------

.. toctree::
:maxdepth: 4

pypots.forecasting.bttf
pypots.forecasting.template

Submodules
----------

pypots.forecasting.base module
pypots.forecasting.bttf module
------------------------------

.. automodule:: pypots.forecasting.base
:members:
:undoc-members:
:show-inheritance:
:inherited-members:

Module contents
---------------

.. automodule:: pypots.forecasting
.. automodule:: pypots.forecasting.bttf
:members:
:undoc-members:
:show-inheritance:
36 changes: 34 additions & 2 deletions docs/references.bib
@@ -119,8 +119,8 @@ @article{du2023SAITS
pages = {119619},
year = {2023},
issn = {0957-4174},
doi = {https://doi.org/10.1016/j.eswa.2023.119619},
url = {https://www.sciencedirect.com/science/article/pii/S0957417423001203},
doi = {10.1016/j.eswa.2023.119619},
url = {https://arxiv.org/abs/2202.08516},
author = {Wenjie Du and David Cote and Yan Liu},
}
@article{fortuin2020GPVAEDeep,
@@ -418,3 +418,35 @@ @inproceedings{reddi2018OnTheConvergence
year={2018},
url={https://openreview.net/forum?id=ryQu7f-RZ},
}

@article{hubert1985,
title={Comparing partitions},
author={Hubert, Lawrence and Arabie, Phipps},
journal={Journal of classification},
volume={2},
pages={193--218},
year={1985},
publisher={Springer}
}

@article{steinley2004,
title={Properties of the Hubert-Arabie adjusted Rand index},
author={Steinley, Douglas},
journal={Psychological methods},
volume={9},
number={3},
pages={386},
year={2004},
publisher={American Psychological Association}
}

@article{calinski1974,
title={A dendrite method for cluster analysis},
author={Cali{\'n}ski, Tadeusz and Harabasz, Jerzy},
journal={Communications in Statistics-theory and Methods},
volume={3},
number={1},
pages={1--27},
year={1974},
publisher={Taylor \& Francis}
}
4 changes: 2 additions & 2 deletions environment-dev.yml
@@ -11,8 +11,8 @@ dependencies:
#- conda-forge::python
#- conda-forge::pip
#- conda-forge::scipy
#- conda-forge::numpy >=1.23.3 # numpy should , otherwise may encounter "number not available" when torch>1.11
#- conda-forge::scikit-learn >=0.24.1
#- conda-forge::numpy
#- conda-forge::scikit-learn
#- conda-forge::pandas <2.0.0
#- conda-forge::h5py
#- conda-forge::tensorboard
2 changes: 1 addition & 1 deletion pypots/__init__.py
@@ -22,7 +22,7 @@
#
# Dev branch marker is: 'X.Y.dev' or 'X.Y.devN' where N is an integer.
# 'X.Y.dev0' is the canonical version of 'X.Y.dev'
__version__ = "0.1.2"
__version__ = "0.1.3"


__all__ = [
141 changes: 113 additions & 28 deletions pypots/base.py
@@ -7,9 +7,11 @@

import os
from abc import ABC
from abc import abstractmethod
from datetime import datetime
from typing import Optional, Union

import numpy as np
import torch
from torch.utils.tensorboard import SummaryWriter

@@ -209,6 +211,33 @@ def _save_log_into_tb_file(self, step: int, stage: str, loss_dict: dict) -> None
if ("loss" in item_name) or ("error" in item_name):
self.summary_writer.add_scalar(f"{stage}/{item_name}", loss.sum(), step)

def _auto_save_model_if_necessary(
self,
training_finished: bool = True,
saving_name: str = None,
):
"""Automatically save the current model into a file if in need.

Parameters
----------
training_finished :
Whether the training has already finished when this function is invoked.
The saving_strategy "better" only works when training_finished is False.
The saving_strategy "best" only works when training_finished is True.

saving_name :
The file name of the saved model.

"""
if self.saving_path is not None and self.model_saving_strategy is not None:
name = self.__class__.__name__ if saving_name is None else saving_name
if not training_finished and self.model_saving_strategy == "better":
self.save_model(self.saving_path, name)
elif training_finished and self.model_saving_strategy == "best":
self.save_model(self.saving_path, name)
else:
return
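
Illustrative sketch, not from this PR: a runnable toy example of how the two saving strategies in the docstring above interact with a training loop. DummyModel and its loop are invented for illustration; only the decision logic mirrors the helper added here.

# Illustrative sketch only; DummyModel is a made-up stand-in.
class DummyModel:
    saving_path = "checkpoints"
    model_saving_strategy = "better"  # or "best"

    def save_model(self, saving_dir, name):
        print(f"would save {name} into {saving_dir}")

    def _auto_save_model_if_necessary(self, training_finished=True, saving_name=None):
        # same decision logic as the helper above
        if self.saving_path is not None and self.model_saving_strategy is not None:
            name = self.__class__.__name__ if saving_name is None else saving_name
            if not training_finished and self.model_saving_strategy == "better":
                self.save_model(self.saving_path, name)
            elif training_finished and self.model_saving_strategy == "best":
                self.save_model(self.saving_path, name)

model = DummyModel()
for epoch in range(3):
    # ... one epoch of training ...
    model._auto_save_model_if_necessary(training_finished=False)  # "better" saves here
model._auto_save_model_if_necessary(training_finished=True)       # "best" saves here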

def save_model(
self,
saving_dir: str,
@@ -258,33 +287,6 @@ def save_model(
f'Failed to save the model to "{saving_path}" because of the below error! \n{e}'
)

def _auto_save_model_if_necessary(
self,
training_finished: bool = True,
saving_name: str = None,
):
"""Automatically save the current model into a file if in need.

Parameters
----------
training_finished :
Whether the training is already finished when invoke this function.
The saving_strategy "better" only works when training_finished is False.
The saving_strategy "best" only works when training_finished is True.

saving_name :
The file name of the saved model.

"""
if self.saving_path is not None and self.model_saving_strategy is not None:
name = self.__class__.__name__ if saving_name is None else saving_name
if not training_finished and self.model_saving_strategy == "better":
self.save_model(self.saving_path, name)
elif training_finished and self.model_saving_strategy == "best":
self.save_model(self.saving_path, name)
else:
return

def load_model(self, model_path: str) -> None:
"""Load the saved model from a disk file.

Expand Down Expand Up @@ -317,6 +319,72 @@ def load_model(self, model_path: str) -> None:
raise e
logger.info(f"Model loaded successfully from {model_path}.")

@abstractmethod
def fit(
self,
train_set: Union[dict, str],
val_set: Optional[Union[dict, str]] = None,
file_type: str = "h5py",
) -> None:
"""Train the classifier on the given data.

Parameters
----------
train_set : dict or str
The dataset for model training, either a dictionary including the keys 'X' and 'y',
or a path string locating a data file.
If it is a dict, 'X' should be array-like of shape [n_samples, sequence length (time steps), n_features],
i.e. time-series data for training that may contain missing values, and 'y' should be array-like of shape
[n_samples], i.e. the classification labels of 'X'.
If it is a path string, the path should point to a data file, e.g. an h5 file, that contains
key-value pairs like a dict and includes the keys 'X' and 'y'.

val_set : dict or str
The dataset for model validation, either a dictionary including the keys 'X' and 'y',
or a path string locating a data file.
If it is a dict, 'X' should be array-like of shape [n_samples, sequence length (time steps), n_features],
i.e. time-series data for validation that may contain missing values, and 'y' should be array-like of shape
[n_samples], i.e. the classification labels of 'X'.
If it is a path string, the path should point to a data file, e.g. an h5 file, that contains
key-value pairs like a dict and includes the keys 'X' and 'y'.

file_type : str
The type of the given file if train_set and val_set are path strings.

"""
raise NotImplementedError
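
Illustrative sketch, not from this PR: building the dictionary form of train_set described above. The shapes and the h5 path are assumptions chosen for the example.

import numpy as np

# Hypothetical shapes: 100 samples, 24 time steps, 7 features.
n_samples, n_steps, n_features = 100, 24, 7
X = np.random.randn(n_samples, n_steps, n_features)
X[X < -1.5] = np.nan                          # partially observed: missing values are allowed
y = np.random.randint(0, 2, size=n_samples)   # labels, used by label-based tasks such as classification

train_set = {"X": X, "y": y}                  # dictionary form accepted by fit()
# Alternatively, train_set can be a path to an h5 file holding the same 'X'/'y' key-value pairs:
# model.fit(train_set="path/to/train.h5", file_type="h5py")  # `model` is a concrete subclass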

@abstractmethod
def predict(
self,
test_set: Union[dict, str],
file_type: str = "h5py",
) -> dict:
"""Make predictions for the input data with the trained model.

Parameters
----------
test_set : dict or str
The dataset for model testing, either a dictionary including the keys 'X' and 'y',
or a path string locating a data file.
If it is a dict, 'X' should be array-like of shape [n_samples, sequence length (time steps), n_features],
i.e. time-series data for testing that may contain missing values, and 'y' should be array-like of shape
[n_samples], i.e. the classification labels of 'X'.
If it is a path string, the path should point to a data file, e.g. an h5 file, that contains
key-value pairs like a dict and includes the keys 'X' and 'y'.

file_type : str
The type of the given file if test_set is a path string.

Returns
-------
result_dict : dict
Prediction results for the given samples as a Python dictionary.
Possible keys are 'imputation', 'classification', 'clustering', and 'forecasting';
only the keys whose tasks are supported by the model will be included.
"""
raise NotImplementedError
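
Illustrative sketch, not from this PR: one way the unified predict() return value described above could be consumed. The helper below is invented for the example; the key names come from the docstring.

# Hypothetical consumer of the result_dict returned by predict().
def summarize_predictions(result_dict: dict) -> None:
    for task in ("imputation", "classification", "clustering", "forecasting"):
        if task in result_dict:
            print(f"{task}: output shape {getattr(result_dict[task], 'shape', None)}")

# e.g. for a trained classifier:
# summarize_predictions(model.predict(test_set))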


class BaseNNModel(BaseModel):
"""The abstract class for all neural-network models.
@@ -400,7 +468,7 @@ def __init__(
else:
assert (
patience <= epochs
), f"patience must be smaller than epoches which is {epochs}, but got patience={patience}"
), f"patience must be smaller than epochs which is {epochs}, but got patience={patience}"

# training hyper-parameters
self.batch_size = batch_size
@@ -421,3 +489,20 @@ def _print_model_size(self) -> None:
logger.info(
f"Model initialized successfully with the number of trainable parameters: {num_params:,}"
)
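
Side note, not from this PR: the trainable-parameter count logged by _print_model_size is presumably obtained in the usual PyTorch way, e.g.:

import torch

model = torch.nn.Linear(7, 3)  # toy module; any torch.nn.Module works the same way
num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Model initialized successfully with the number of trainable parameters: {num_params:,}")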

@abstractmethod
def fit(
self,
train_set: Union[dict, str],
val_set: Optional[Union[dict, str]] = None,
file_type: str = "h5py",
) -> None:
raise NotImplementedError

@abstractmethod
def predict(
self,
test_set: Union[dict, str],
file_type: str = "h5py",
) -> dict:
raise NotImplementedError
12 changes: 12 additions & 0 deletions pypots/classification/base.py
@@ -95,6 +95,14 @@ def fit(
"""
raise NotImplementedError

@abstractmethod
def predict(
self,
test_set: Union[dict, str],
file_type: str = "h5py",
) -> dict:
raise NotImplementedError

@abstractmethod
def classify(
self,
@@ -117,6 +125,8 @@ def classify(
array-like, shape [n_samples],
Classification results of the given samples.
"""
# this is for old API compatibility, will be removed in the future.
# Please implement predict() instead.
raise NotImplementedError
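
Illustrative sketch, not from this PR: given the compatibility note above, a subclass would presumably keep classify() as a thin wrapper over the new predict(). The wrapper below is an assumption about intent, not code from this changeset.

# Hypothetical subclass method, assuming predict() returns a dict with a 'classification' key
# as documented in pypots/base.py.
def classify(self, X, file_type: str = "h5py"):
    # old API kept for backward compatibility; delegates to the unified predict()
    result_dict = self.predict(X, file_type=file_type)
    return result_dict["classification"]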


@@ -402,4 +412,6 @@ def classify(
array-like, shape [n_samples],
Classification results of the given samples.
"""
# this is for old API compatibility, will be removed in the future.
# Please implement predict() instead.
raise NotImplementedError