Skip to content

Commit

Permalink
Merge pull request #354 from WenjieDu/(feat)csdi_forecasting
Browse files Browse the repository at this point in the history
Implement CSDI as a forecasting model
  • Loading branch information
WenjieDu authored Apr 18, 2024
2 parents ec8bee3 + fa32fb2 commit 88ec17b
Show file tree
Hide file tree
Showing 97 changed files with 2,494 additions and 1,036 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/testing_ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -78,8 +78,8 @@ jobs:
- name: Test with pytest
run: |
python tests/global_test_config.py
rm -rf testing_results && rm -rf tests/__pycache__ && rm -rf tests/*/__pycache__
python tests/global_test_config.py
python -m pytest -rA tests/*/* -s -n auto --cov=pypots --dist=loadgroup --cov-config=.coveragerc
- name: Generate the LCOV report
Expand Down
7 changes: 5 additions & 2 deletions .readthedocs.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,9 @@ build:
- pip install ./TSDB_repo && pip install ./PyGrinder_repo && pip install .

post_install:
# To fix the exception: This documentation is not using `furo.css` as the stylesheet.
# If you have set `html_style` in your conf.py file, remove it.
- pip install sphinx==7.2.6
# this docutils version fixes issue#102, put it in post_install to avoid being
# overwritten by other versions (like 0.19) while installing other packages
- pip install docutils==0.20
# this version fixes issue#102, put it in post_install to avoid being
# overwritten by other versions (like 0.19) while installing other packages
4 changes: 2 additions & 2 deletions docs/pypots.data.rst
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
pypots.data package
===================

pypots.data.base
pypots.data.dataset
-----------------------

.. automodule:: pypots.data.base
.. automodule:: pypots.data.dataset
:members:
:undoc-members:
:show-inheritance:
Expand Down
20 changes: 10 additions & 10 deletions pypots/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -337,13 +337,13 @@ def fit(
self,
train_set: Union[dict, str],
val_set: Optional[Union[dict, str]] = None,
file_type: str = "h5py",
file_type: str = "hdf5",
) -> None:
"""Train the classifier on the given data.
Parameters
----------
train_set : dict or str
train_set :
The dataset for model training, should be a dictionary including keys as 'X',
or a path string locating a data file supported by PyPOTS (e.g. h5 file).
If it is a dict, X should be array-like of shape [n_samples, sequence length (time steps), n_features],
Expand All @@ -352,7 +352,7 @@ def fit(
If it is a path string, the path should point to a data file, e.g. a h5 file, which contains
key-value pairs like a dict, and it has to include keys as 'X' and 'y'.
val_set : dict or str
val_set :
The dataset for model validating, should be a dictionary including keys as 'X',
or a path string locating a data file supported by PyPOTS (e.g. h5 file).
If it is a dict, X should be array-like of shape [n_samples, sequence length (time steps), n_features],
Expand All @@ -361,7 +361,7 @@ def fit(
If it is a path string, the path should point to a data file, e.g. a h5 file, which contains
key-value pairs like a dict, and it has to include keys as 'X' and 'y'.
file_type : str
file_type :
The type of the given file if train_set and val_set are path strings.
"""
Expand All @@ -371,13 +371,13 @@ def fit(
def predict(
self,
test_set: Union[dict, str],
file_type: str = "h5py",
file_type: str = "hdf5",
) -> dict:
"""Make predictions for the input data with the trained model.
Parameters
----------
test_set : dict or str
test_set :
The dataset for model validating, should be a dictionary including keys as 'X',
or a path string locating a data file supported by PyPOTS (e.g. h5 file).
If it is a dict, X should be array-like of shape [n_samples, sequence length (time steps), n_features],
Expand All @@ -386,12 +386,12 @@ def predict(
If it is a path string, the path should point to a data file, e.g. a h5 file, which contains
key-value pairs like a dict, and it has to include keys as 'X' and 'y'.
file_type : str
file_type :
The type of the given file if test_set is a path string.
Returns
-------
result_dict: dict
result_dict :
Prediction results in a Python Dictionary for the given samples.
It should be a dictionary including keys as 'imputation', 'classification', 'clustering', and 'forecasting'.
For sure, only the keys that relevant tasks are supported by the model will be returned.
Expand Down Expand Up @@ -512,14 +512,14 @@ def fit(
self,
train_set: Union[dict, str],
val_set: Optional[Union[dict, str]] = None,
file_type: str = "h5py",
file_type: str = "hdf5",
) -> None:
raise NotImplementedError

@abstractmethod
def predict(
self,
test_set: Union[dict, str],
file_type: str = "h5py",
file_type: str = "hdf5",
) -> dict:
raise NotImplementedError
24 changes: 12 additions & 12 deletions pypots/classification/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ def fit(
self,
train_set: Union[dict, str],
val_set: Optional[Union[dict, str]] = None,
file_type: str = "h5py",
file_type: str = "hdf5",
) -> None:
"""Train the classifier on the given data.
Expand Down Expand Up @@ -106,15 +106,15 @@ def fit(
def predict(
self,
test_set: Union[dict, str],
file_type: str = "h5py",
file_type: str = "hdf5",
) -> dict:
raise NotImplementedError

@abstractmethod
def classify(
self,
X: Union[dict, str],
file_type: str = "h5py",
file_type: str = "hdf5",
) -> np.ndarray:
"""Classify the input data with the trained model.
Expand Down Expand Up @@ -214,12 +214,12 @@ def __init__(
self.n_classes = n_classes

@abstractmethod
def _assemble_input_for_training(self, data) -> dict:
def _assemble_input_for_training(self, data: list) -> dict:
"""Assemble the given data into a dictionary for training input.
Parameters
----------
data : list,
data :
Input data from dataloader, should be list.
Returns
Expand All @@ -230,12 +230,12 @@ def _assemble_input_for_training(self, data) -> dict:
raise NotImplementedError

@abstractmethod
def _assemble_input_for_validating(self, data) -> dict:
def _assemble_input_for_validating(self, data: list) -> dict:
"""Assemble the given data into a dictionary for validating input.
Parameters
----------
data : list,
data :
Data output from dataloader, should be list.
Returns
Expand All @@ -246,7 +246,7 @@ def _assemble_input_for_validating(self, data) -> dict:
raise NotImplementedError

@abstractmethod
def _assemble_input_for_testing(self, data) -> dict:
def _assemble_input_for_testing(self, data: list) -> dict:
"""Assemble the given data into a dictionary for testing input.
Notes
Expand All @@ -259,7 +259,7 @@ def _assemble_input_for_testing(self, data) -> dict:
Parameters
----------
data : list,
data :
Data output from dataloader, should be list.
Returns
Expand Down Expand Up @@ -386,7 +386,7 @@ def fit(
self,
train_set: Union[dict, str],
val_set: Optional[Union[dict, str]] = None,
file_type: str = "h5py",
file_type: str = "hdf5",
) -> None:
"""Train the classifier on the given data.
Expand Down Expand Up @@ -420,15 +420,15 @@ def fit(
def predict(
self,
test_set: Union[dict, str],
file_type: str = "h5py",
file_type: str = "hdf5",
) -> dict:
raise NotImplementedError

@abstractmethod
def classify(
self,
X: Union[dict, str],
file_type: str = "h5py",
file_type: str = "hdf5",
) -> np.ndarray:
"""Classify the input data with the trained model.
Expand Down
17 changes: 11 additions & 6 deletions pypots/classification/brits/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ class DatasetForBRITS(DatasetForBRITS_Imputation):
Parameters
----------
data : dict or str,
data :
The dataset for model input, should be a dictionary including keys as 'X' and 'y',
or a path string locating a data file.
If it is a dict, X should be array-like of shape [n_samples, sequence length (time steps), n_features],
Expand All @@ -26,7 +26,7 @@ class DatasetForBRITS(DatasetForBRITS_Imputation):
If it is a path string, the path should point to a data file, e.g. a h5 file, which contains
key-value pairs like a dict, and it has to include keys as 'X' and 'y'.
return_labels : bool, default = True,
return_y :
Whether to return labels in function __getitem__() if they exist in the given data. If `True`, for example,
during training of classification models, the Dataset class will return labels in __getitem__() for model input.
Otherwise, labels won't be included in the data returned by __getitem__(). This parameter exists because we
Expand All @@ -35,14 +35,19 @@ class DatasetForBRITS(DatasetForBRITS_Imputation):
with function _fetch_data_from_file(), which works for all three stages. Therefore, we need this parameter for
distinction.
file_type : str, default = "h5py"
file_type :
The type of the given file if train_set and val_set are path strings.
"""

def __init__(
self,
data: Union[dict, str],
return_labels: bool = True,
file_type: str = "h5py",
return_y: bool = True,
file_type: str = "hdf5",
):
super().__init__(data, False, return_labels, file_type)
super().__init__(
data=data,
return_X_ori=False,
return_y=return_y,
file_type=file_type,
)
8 changes: 4 additions & 4 deletions pypots/classification/brits/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,7 @@ def fit(
self,
train_set: Union[dict, str],
val_set: Optional[Union[dict, str]] = None,
file_type: str = "h5py",
file_type: str = "hdf5",
) -> None:
# Step 1: wrap the input data with classes Dataset and DataLoader
training_set = DatasetForBRITS(train_set, file_type=file_type)
Expand Down Expand Up @@ -239,10 +239,10 @@ def fit(
def predict(
self,
test_set: Union[dict, str],
file_type: str = "h5py",
file_type: str = "hdf5",
) -> dict:
self.model.eval() # set the model as eval status to freeze it.
test_set = DatasetForBRITS(test_set, return_labels=False, file_type=file_type)
test_set = DatasetForBRITS(test_set, return_y=False, file_type=file_type)
test_loader = DataLoader(
test_set,
batch_size=self.batch_size,
Expand All @@ -267,7 +267,7 @@ def predict(
def classify(
self,
X: Union[dict, str],
file_type: str = "h5py",
file_type: str = "hdf5",
) -> np.ndarray:
"""Classify the input data with the trained model.
Expand Down
32 changes: 19 additions & 13 deletions pypots/classification/grud/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

import torch

from ...data.base import BaseDataset
from ...data.dataset import BaseDataset
from ...data.utils import _parse_delta_torch
from ...imputation.locf import locf_torch

Expand All @@ -20,7 +20,7 @@ class DatasetForGRUD(BaseDataset):
Parameters
----------
data : dict or str,
data :
The dataset for model input, should be a dictionary including keys as 'X' and 'y',
or a path string locating a data file.
If it is a dict, X should be array-like of shape [n_samples, sequence length (time steps), n_features],
Expand All @@ -29,7 +29,7 @@ class DatasetForGRUD(BaseDataset):
If it is a path string, the path should point to a data file, e.g. a h5 file, which contains
key-value pairs like a dict, and it has to include keys as 'X' and 'y'.
return_labels : bool, default = True,
return_y :
Whether to return labels in function __getitem__() if they exist in the given data. If `True`, for example,
during training of classification models, the Dataset class will return labels in __getitem__() for model input.
Otherwise, labels won't be included in the data returned by __getitem__(). This parameter exists because we
Expand All @@ -38,17 +38,23 @@ class DatasetForGRUD(BaseDataset):
with function _fetch_data_from_file(), which works for all three stages. Therefore, we need this parameter for
distinction.
file_type : str, default = "h5py"
file_type :
The type of the given file if train_set and val_set are path strings.
"""

def __init__(
self,
data: Union[dict, str],
return_labels: bool = True,
file_type: str = "h5py",
return_y: bool = True,
file_type: str = "hdf5",
):
super().__init__(data, False, return_labels, file_type)
super().__init__(
data=data,
return_X_ori=False,
return_X_pred=False,
return_y=return_y,
file_type=file_type,
)
if not isinstance(self.data, str): # data from array
self.missing_mask = (~torch.isnan(self.X)).to(torch.float32)
self.X_filledLOCF = locf_torch(self.X)
Expand All @@ -63,12 +69,12 @@ def _fetch_data_from_array(self, idx: int) -> Iterable:
Parameters
----------
idx : int,
idx :
The index to fetch the specified sample.
Returns
-------
sample : list,
sample :
A list contains
index : int tensor,
Expand Down Expand Up @@ -98,7 +104,7 @@ def _fetch_data_from_array(self, idx: int) -> Iterable:
self.empirical_mean.to(torch.float32),
]

if self.y is not None and self.return_labels:
if self.return_y:
sample.append(self.y[idx].to(torch.long))

return sample
Expand All @@ -109,12 +115,12 @@ def _fetch_data_from_file(self, idx: int) -> Iterable:
Parameters
----------
idx : int,
idx :
The index of the sample to be return.
Returns
-------
sample : list,
sample :
The collated data sample, a list including all necessary sample info.
"""

Expand All @@ -140,7 +146,7 @@ def _fetch_data_from_file(self, idx: int) -> Iterable:
]

# if the dataset has labels and is for training, then fetch it from the file
if "y" in self.file_handle.keys() and self.return_labels:
if self.return_y:
sample.append(torch.tensor(self.file_handle["y"][idx], dtype=torch.long))

return sample
Loading

0 comments on commit 88ec17b

Please sign in to comment.