Skip to content

Commit

Permalink
Merge pull request #176 from AugustJW/main
Browse files Browse the repository at this point in the history
add models GP-VAE/USGAN
  • Loading branch information
WenjieDu authored Sep 21, 2023
2 parents a329f79 + 0a6b37a commit 9bfffa1
Show file tree
Hide file tree
Showing 18 changed files with 1,608 additions and 14 deletions.
2 changes: 1 addition & 1 deletion docs/about_us.rst
Original file line number Diff line number Diff line change
Expand Up @@ -33,5 +33,5 @@ PyPOTS exists thanks to all the nice people (sorted by contribution time) who co

.. raw:: html

<object data="https://pypots.com/figs/PyPOTS_contributors.svg">
<object data="https://pypots.com/figs/pypots_logos/PyPOTS_contributors.svg">
</object>
5 changes: 3 additions & 2 deletions pypots/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,9 @@ def _setup_device(self, device: Union[None, str, torch.device, list]):
self.device = device
elif isinstance(device, list):
if len(device) == 0:
raise ValueError("The list of devices should have at least 1 device, but got 0.")
raise ValueError(
"The list of devices should have at least 1 device, but got 0."
)
elif len(device) == 1:
return self._setup_device(device[0])
# parallely training on multiple CUDA devices
Expand Down Expand Up @@ -176,7 +178,6 @@ def _send_data_to_given_device(self, data):
if isinstance(self.device, torch.device): # single device
data = map(lambda x: x.to(self.device), data)
else: # parallely training on multiple devices

# randomly choose one device to balance the workload
# device = np.random.choice(self.device)

Expand Down
1 change: 0 additions & 1 deletion pypots/classification/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,7 +256,6 @@ def _train_model(
training_loader: DataLoader,
val_loader: DataLoader = None,
) -> None:

# each training starts from the very beginning, so reset the loss and model dict here
self.best_loss = float("inf")
self.best_model_dict = None
Expand Down
1 change: 0 additions & 1 deletion pypots/classification/raindrop/modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,6 @@ def forward(
edge_attr: OptTensor = None,
return_attention_weights=None,
) -> Tuple[torch.Tensor, Any]:

r"""
Args:
return_attention_weights (bool, optional): If set to :obj:`True`,
Expand Down
1 change: 0 additions & 1 deletion pypots/clustering/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,7 +244,6 @@ def _train_model(
training_loader: DataLoader,
val_loader: DataLoader = None,
) -> None:

"""
Parameters
Expand Down
1 change: 0 additions & 1 deletion pypots/clustering/crli/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,7 +226,6 @@ def __init__(
saving_path: Optional[str] = None,
model_saving_strategy: Optional[str] = "best",
):

super().__init__(
n_clusters,
batch_size,
Expand Down
2 changes: 0 additions & 2 deletions pypots/clustering/vader/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,6 @@ def forward(
) = self.get_results(X, missing_mask)

if not training and not pretrain:

results = {
"mu_tilde": mu_tilde,
"mu": mu_c,
Expand Down Expand Up @@ -403,7 +402,6 @@ def _train_model(
training_loader: DataLoader,
val_loader: DataLoader = None,
) -> None:

# each training starts from the very beginning, so reset the loss and model dict here
self.best_loss = float("inf")
self.best_model_dict = None
Expand Down
1 change: 0 additions & 1 deletion pypots/forecasting/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,7 +242,6 @@ def _train_model(
training_loader: DataLoader,
val_loader: DataLoader = None,
) -> None:

# each training starts from the very beginning, so reset the loss and model dict here
self.best_loss = float("inf")
self.best_model_dict = None
Expand Down
6 changes: 5 additions & 1 deletion pypots/imputation/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,19 @@
# License: GPL-v3

from .brits import BRITS
from .gpvae import GPVAE
from .locf import LOCF
from .mrnn import MRNN
from .saits import SAITS
from .transformer import Transformer
from .mrnn import MRNN
from .usgan import USGAN

__all__ = [
"SAITS",
"Transformer",
"BRITS",
"MRNN",
"LOCF",
"GPVAE",
"USGAN",
]
12 changes: 12 additions & 0 deletions pypots/imputation/gpvae/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
"""
The package of the partially-observed time-series imputation method GP-VAE.
"""

# Created by Jun Wang <[email protected]>
# License: GLP-v3

from .model import GPVAE

__all__ = [
"GPVAE",
]
133 changes: 133 additions & 0 deletions pypots/imputation/gpvae/data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
"""
Dataset class for model GP-VAE.
"""

# Created by Jun Wang <[email protected]> and Wenjie Du <[email protected]>
# License: GLP-v3

from typing import Union, Iterable

import torch

from ...data.base import BaseDataset
from ...data.utils import torch_parse_delta


class DatasetForGPVAE(BaseDataset):
"""Dataset class for GP-VAE.
Parameters
----------
data : dict or str,
The dataset for model input, should be a dictionary including keys as 'X' and 'y',
or a path string locating a data file.
If it is a dict, X should be array-like of shape [n_samples, sequence length (time steps), n_features],
which is time-series data for input, can contain missing values, and y should be array-like of shape
[n_samples], which is classification labels of X.
If it is a path string, the path should point to a data file, e.g. a h5 file, which contains
key-value pairs like a dict, and it has to include keys as 'X' and 'y'.
return_labels : bool, default = True,
Whether to return labels in function __getitem__() if they exist in the given data. If `True`, for example,
during training of classification models, the Dataset class will return labels in __getitem__() for model input.
Otherwise, labels won't be included in the data returned by __getitem__(). This parameter exists because we
need the defined Dataset class for all training/validating/testing stages. For those big datasets stored in h5
files, they already have both X and y saved. But we don't read labels from the file for validating and testing
with function _fetch_data_from_file(), which works for all three stages. Therefore, we need this parameter for
distinction.
file_type : str, default = "h5py"
The type of the given file if train_set and val_set are path strings.
"""

def __init__(
self,
data: Union[dict, str],
return_labels: bool = True,
file_type: str = "h5py",
):
super().__init__(data, return_labels, file_type)

if not isinstance(self.data, str):
# calculate all delta here.
missing_mask = (~torch.isnan(self.X)).type(torch.float32)
X = torch.nan_to_num(self.X)

self.processed_data = {
"X": X,
"missing_mask": missing_mask,
}

def _fetch_data_from_array(self, idx: int) -> Iterable:
"""Fetch data from self.X if it is given.
Parameters
----------
idx : int,
The index of the sample to be return.
Returns
-------
sample : list,
A list contains
index : int tensor,
The index of the sample.
X : tensor,
The feature vector for model input.
missing_mask : tensor,
The mask indicates all missing values in X.
delta : tensor,
The delta matrix contains time gaps of missing values.
label (optional) : tensor,
The target label of the time-series sample.
"""
sample = [
torch.tensor(idx),
# for forward
self.processed_data["X"][idx].to(torch.float32),
self.processed_data["missing_mask"][idx].to(torch.float32),
]

if self.y is not None and self.return_labels:
sample.append(self.y[idx].to(torch.long))

return sample

def _fetch_data_from_file(self, idx: int) -> Iterable:
"""Fetch data with the lazy-loading strategy, i.e. only loading data from the file while requesting for samples.
Here the opened file handle doesn't load the entire dataset into RAM but only load the currently accessed slice.
Parameters
----------
idx : int,
The index of the sample to be return.
Returns
-------
sample : list,
The collated data sample, a list including all necessary sample info.
"""

if self.file_handle is None:
self.file_handle = self._open_file_handle()

X = torch.from_numpy(self.file_handle["X"][idx])
missing_mask = (~torch.isnan(X)).to(torch.float32)
X = torch.nan_to_num(X)

sample = [
torch.tensor(idx),
X,
missing_mask,
]

# if the dataset has labels and is for training, then fetch it from the file
if "y" in self.file_handle.keys() and self.return_labels:
sample.append(torch.tensor(self.file_handle["y"][idx], dtype=torch.long))

return sample
Loading

0 comments on commit 9bfffa1

Please sign in to comment.