Merge pull request #307 from WenjieDu/refactor_locf_funcs
Refactor LOCF implementations
WenjieDu authored Mar 12, 2024
2 parents 0893fd8 + 6881edc commit 2c8013a
Showing 6 changed files with 173 additions and 125 deletions.
4 changes: 3 additions & 1 deletion .github/ISSUE_TEMPLATE/bug-report.yml
@@ -49,4 +49,6 @@ body:
required: true
attributes:
label: 4. Expected behavior
description: "A clear and concise description of what error you would expect to happen."
description: |
A clear and concise description of what error you would expect to happen.
Please provide the whole "Traceback" of the error message printed in the console.
8 changes: 3 additions & 5 deletions pypots/classification/grud/data.py
@@ -12,7 +12,7 @@

from ...data.base import BaseDataset
from ...data.utils import _parse_delta_torch
from ...imputation.locf import LOCF
from ...imputation.locf import locf_torch


class DatasetForGRUD(BaseDataset):
@@ -49,11 +49,9 @@ def __init__(
file_type: str = "h5py",
):
super().__init__(data, False, return_labels, file_type)
self.locf = LOCF()

if not isinstance(self.data, str): # data from array
self.missing_mask = (~torch.isnan(self.X)).to(torch.float32)
self.X_filledLOCF = self.locf._locf_torch(self.X)
self.X_filledLOCF = locf_torch(self.X)
self.X = torch.nan_to_num(self.X)
self.deltas = _parse_delta_torch(self.missing_mask)
self.empirical_mean = torch.sum(
@@ -125,7 +123,7 @@ def _fetch_data_from_file(self, idx: int) -> Iterable:

X = torch.from_numpy(self.file_handle["X"][idx]).to(torch.float32)
missing_mask = (~torch.isnan(X)).to(torch.float32)
X_filledLOCF = self.locf._locf_torch(X.unsqueeze(dim=0)).squeeze()
X_filledLOCF = locf_torch(X.unsqueeze(dim=0)).squeeze()
X = torch.nan_to_num(X)
deltas = _parse_delta_torch(missing_mask)
empirical_mean = torch.sum(missing_mask * X, dim=[0]) / torch.sum(
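For context on the grud/data.py change: DatasetForGRUD no longer instantiates LOCF just to reach its private _locf_torch method; it calls the module-level locf_torch function directly. A minimal sketch of the new call pattern, assuming an input tensor of shape (n_samples, n_steps, n_features) with NaNs marking missing values (toy data, not from the repository):

import torch
from pypots.imputation.locf import locf_torch

# one sample, three time steps, two features
X = torch.tensor([[[1.0, float("nan")],
                   [float("nan"), 2.0],
                   [3.0, float("nan")]]])
missing_mask = (~torch.isnan(X)).to(torch.float32)  # 1 = observed, 0 = missing
X_filledLOCF = locf_torch(X)                        # forward-fill along the time axis
X = torch.nan_to_num(X)                             # remaining NaNs become 0 for the model input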
4 changes: 3 additions & 1 deletion pypots/imputation/locf/__init__.py
@@ -5,8 +5,10 @@
# Created by Wenjie Du <[email protected]>
# License: BSD-3-Clause

from .model import LOCF
from .model import LOCF, locf_numpy, locf_torch

__all__ = [
"LOCF",
"locf_numpy",
"locf_torch",
]
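With locf_numpy and locf_torch now exported here, LOCF can be applied without instantiating the LOCF class at all. A small NumPy sketch under the same shape convention, (n_samples, n_steps, n_features):

import numpy as np
from pypots.imputation.locf import locf_numpy

# one sample, four time steps, two features; NaNs mark missing values
X = np.array([[[np.nan, 0.5],
               [1.0, np.nan],
               [np.nan, np.nan],
               [2.0, 0.7]]])
X_imputed = locf_numpy(X, first_step_imputation="backward")
# feature 0 becomes [1.0, 1.0, 1.0, 2.0]: the leading NaN is filled backward from the
# first observation, the later NaN forward from the last one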
121 changes: 3 additions & 118 deletions pypots/imputation/locf/model.py
@@ -13,6 +13,7 @@
import numpy as np
import torch

from .modules.core import locf_numpy, locf_torch
from ..base import BaseImputer
from ...utils.logging import logger

@@ -68,122 +69,6 @@ def fit(
"Please run func `predict()` directly."
)

def _locf_numpy(
self,
X: np.ndarray,
first_step_imputation: str = "backward",
) -> np.ndarray:
"""Numpy implementation of LOCF.
Parameters
----------
X : np.ndarray,
Time series containing missing values (NaN) to be imputed.
Returns
-------
X_imputed : array,
Imputed time series.
Notes
-----
This implementation gets inspired by the question on StackOverflow:
https://stackoverflow.com/questions/41190852/most-efficient-way-to-forward-fill-nan-values-in-numpy-array
"""
trans_X = X.transpose((0, 2, 1))
mask = np.isnan(trans_X)
n_samples, n_steps, n_features = mask.shape
idx = np.where(~mask, np.arange(n_features), 0)
idx = np.maximum.accumulate(idx, axis=2)

collector = []
for x, i in zip(trans_X, idx):
collector.append(x[np.arange(n_steps)[:, None], i])
X_imputed = np.asarray(collector)
X_imputed = X_imputed.transpose((0, 2, 1))

# If there are values still missing, they are missing at the beginning of the time-series sequence.
if np.isnan(X_imputed).any():
if first_step_imputation == "nan":
pass
elif first_step_imputation == "zero":
X_imputed = np.nan_to_num(X_imputed, nan=0)
elif first_step_imputation == "backward":
# imputed by last observation carried backward (LOCB)
X_imputed_transpose = np.copy(X_imputed)
X_imputed_transpose = np.flip(X_imputed_transpose, axis=1)
X_LOCB = self._locf_numpy(
X_imputed_transpose,
"zero",
)
X_imputed = np.flip(X_LOCB, axis=1)
elif first_step_imputation == "median":
bz, n_steps, n_features = X_imputed.shape
X_imputed_reshaped = np.copy(X_imputed).reshape(-1, n_features)
median_values = np.nanmedian(X_imputed_reshaped, axis=0)
for i, v in enumerate(median_values):
X_imputed[:, :, i] = np.nan_to_num(X_imputed[:, :, i], nan=v)
if np.isnan(X_imputed).any() and self.keep_trying:
X_imputed = np.nan_to_num(X_imputed, nan=0)

return X_imputed

def _locf_torch(
self,
X: torch.Tensor,
first_step_imputation: str = "backward",
) -> torch.Tensor:
"""Torch implementation of LOCF.
Parameters
----------
X : tensor,
Time series containing missing values (NaN) to be imputed.
Returns
-------
X_imputed : tensor,
Imputed time series.
"""
trans_X = X.permute((0, 2, 1))
mask = torch.isnan(trans_X)
n_samples, n_steps, n_features = mask.shape
idx = torch.where(~mask, torch.arange(n_features, device=mask.device), 0)
idx = np.maximum.accumulate(idx, axis=2)

collector = []
for x, i in zip(trans_X, idx):
collector.append(x[torch.arange(n_steps)[:, None], i])
X_imputed = torch.stack(collector)
X_imputed = X_imputed.permute((0, 2, 1))

# If there are values still missing, they are missing at the beginning of the time-series sequence.
if torch.isnan(X_imputed).any():
if first_step_imputation == "nan":
pass
elif first_step_imputation == "zero":
X_imputed = torch.nan_to_num(X_imputed, nan=0)
elif first_step_imputation == "backward":
# imputed by last observation carried backward (LOCB)
X_imputed_transpose = X_imputed.clone()
X_imputed_transpose = torch.flip(X_imputed_transpose, dims=[1])
X_LOCB = self._locf_torch(
X_imputed_transpose,
"zero",
)
X_imputed = torch.flip(X_LOCB, dims=[1])
elif first_step_imputation == "median":
bz, n_steps, n_features = X_imputed.shape
X_imputed_reshaped = X_imputed.clone().reshape(-1, n_features)
median_values = torch.nanmedian(X_imputed_reshaped, dim=0)
for i, v in enumerate(median_values.values):
X_imputed[:, :, i] = torch.nan_to_num(X_imputed[:, :, i], nan=v)
if torch.isnan(X_imputed).any() and self.keep_trying:
X_imputed = torch.nan_to_num(X_imputed, nan=0)

return X_imputed

def predict(
self,
test_set: Union[dict, str],
@@ -226,9 +111,9 @@ def predict(
X = np.asarray(X)

if isinstance(X, np.ndarray):
imputed_data = self._locf_numpy(X, self.first_step_imputation)
imputed_data = locf_numpy(X, self.first_step_imputation)
elif isinstance(X, torch.Tensor):
imputed_data = self._locf_torch(X, self.first_step_imputation)
imputed_data = locf_torch(X, self.first_step_imputation)
else:
raise TypeError(
"X must be type of list/np.ndarray/torch.Tensor, " f"but got {type(X)}"
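The class-based workflow is unchanged by the refactor: predict() now simply dispatches to locf_numpy or locf_torch depending on the input type. A hedged sketch of that workflow (the first_step_imputation constructor argument and the "imputation" result key follow the usual PyPOTS conventions and are assumed here rather than shown in this diff):

import numpy as np
from pypots.imputation.locf import LOCF

test_set = {"X": np.random.randn(8, 24, 5)}  # (n_samples, n_steps, n_features)
test_set["X"][:, 3:6, 1] = np.nan            # punch a hole to impute

locf = LOCF(first_step_imputation="median")  # assumed constructor argument
results = locf.predict(test_set)             # fit() is effectively a no-op for LOCF
imputed = results["imputation"]              # assumed result key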
6 changes: 6 additions & 0 deletions pypots/imputation/locf/modules/__init__.py
@@ -0,0 +1,6 @@
"""
"""

# Created by Wenjie Du <[email protected]>
# License: BSD-3-Clause
155 changes: 155 additions & 0 deletions pypots/imputation/locf/modules/core.py
@@ -0,0 +1,155 @@
"""
"""

# Created by Wenjie Du <[email protected]>
# License: BSD-3-Clause

import numpy as np
import torch


def locf_numpy(
X: np.ndarray,
first_step_imputation: str = "backward",
) -> np.ndarray:
"""Numpy implementation of LOCF.
Parameters
----------
X : np.ndarray,
Time series containing missing values (NaN) to be imputed.
first_step_imputation : str, default='backward'
With LOCF, the observed values are carried forward to impute the missing ones. But if the first value
is missing, there is no value to carry forward. This parameter is used to determine the strategy to
impute the missing values at the beginning of the time-series sequence after LOCF is applied.
It can be one of ['backward', 'zero', 'median', 'nan'].
If 'nan', the missing values at the sequence beginning will be left as NaNs.
If 'zero', the missing values at the sequence beginning will be imputed with 0.
If 'backward', the missing values at the beginning of the time-series sequence will be imputed with the
first observed value in the sequence, i.e. the first observed value will be carried backward to impute
the missing values at the beginning of the sequence. This method is also known as NOCB (Next Observation
Carried Backward). If 'median', the missing values at the sequence beginning will be imputed with the overall
median values of features in the dataset.
If `first_step_imputation` is not "nan" and missing values still remain after it is applied
(usually because an entire feature is missing), they will be filled with 0.
Returns
-------
X_imputed : array,
Imputed time series.
Notes
-----
This implementation is inspired by the following question on StackOverflow:
https://stackoverflow.com/questions/41190852/most-efficient-way-to-forward-fill-nan-values-in-numpy-array
"""
trans_X = X.transpose((0, 2, 1))
mask = np.isnan(trans_X)
n_samples, n_steps, n_features = mask.shape
idx = np.where(~mask, np.arange(n_features), 0)
idx = np.maximum.accumulate(idx, axis=2)

collector = []
for x, i in zip(trans_X, idx):
collector.append(x[np.arange(n_steps)[:, None], i])
X_imputed = np.asarray(collector)
X_imputed = X_imputed.transpose((0, 2, 1))

# If there are values still missing, they are missing at the beginning of the time-series sequence.
if np.isnan(X_imputed).any():
if first_step_imputation == "nan":
pass
elif first_step_imputation == "zero":
X_imputed = np.nan_to_num(X_imputed, nan=0)
elif first_step_imputation == "backward":
# imputed by last observation carried backward (LOCB)
X_imputed_transpose = np.copy(X_imputed)
X_imputed_transpose = np.flip(X_imputed_transpose, axis=1)
X_LOCB = locf_numpy(
X_imputed_transpose,
"zero",
)
X_imputed = np.flip(X_LOCB, axis=1)
elif first_step_imputation == "median":
bz, n_steps, n_features = X_imputed.shape
X_imputed_reshaped = np.copy(X_imputed).reshape(-1, n_features)
median_values = np.nanmedian(X_imputed_reshaped, axis=0)
for i, v in enumerate(median_values):
X_imputed[:, :, i] = np.nan_to_num(X_imputed[:, :, i], nan=v)
if np.isnan(X_imputed).any():
X_imputed = np.nan_to_num(X_imputed, nan=0)

return X_imputed


def locf_torch(
X: torch.Tensor,
first_step_imputation: str = "backward",
) -> torch.Tensor:
"""Torch implementation of LOCF.
Parameters
----------
X : tensor,
Time series containing missing values (NaN) to be imputed.
first_step_imputation : str, default='backward'
With LOCF, the observed values are carried forward to impute the missing ones. But if the first value
is missing, there is no value to carry forward. This parameter is used to determine the strategy to
impute the missing values at the beginning of the time-series sequence after LOCF is applied.
It can be one of ['backward', 'zero', 'median', 'nan'].
If 'nan', the missing values at the sequence beginning will be left as NaNs.
If 'zero', the missing values at the sequence beginning will be imputed with 0.
If 'backward', the missing values at the beginning of the time-series sequence will be imputed with the
first observed value in the sequence, i.e. the first observed value will be carried backward to impute
the missing values at the beginning of the sequence. This method is also known as NOCB (Next Observation
Carried Backward). If 'median', the missing values at the sequence beginning will be imputed with the overall
median values of features in the dataset.
If `first_step_imputation` is not "nan" and missing values still remain after it is applied
(usually because an entire feature is missing), they will be filled with 0.
Returns
-------
X_imputed : tensor,
Imputed time series.
"""
trans_X = X.permute((0, 2, 1))
mask = torch.isnan(trans_X)
n_samples, n_steps, n_features = mask.shape
idx = torch.where(~mask, torch.arange(n_features, device=mask.device), 0)
idx = np.maximum.accumulate(idx, axis=2)

collector = []
for x, i in zip(trans_X, idx):
collector.append(x[torch.arange(n_steps)[:, None], i])
X_imputed = torch.stack(collector)
X_imputed = X_imputed.permute((0, 2, 1))

# If there are values still missing, they are missing at the beginning of the time-series sequence.
if torch.isnan(X_imputed).any():
if first_step_imputation == "nan":
pass
elif first_step_imputation == "zero":
X_imputed = torch.nan_to_num(X_imputed, nan=0)
elif first_step_imputation == "backward":
# imputed by last observation carried backward (LOCB)
X_imputed_transpose = X_imputed.clone()
X_imputed_transpose = torch.flip(X_imputed_transpose, dims=[1])
X_LOCB = locf_torch(
X_imputed_transpose,
"zero",
)
X_imputed = torch.flip(X_LOCB, dims=[1])
elif first_step_imputation == "median":
bz, n_steps, n_features = X_imputed.shape
X_imputed_reshaped = X_imputed.clone().reshape(-1, n_features)
median_values = torch.nanmedian(X_imputed_reshaped, dim=0)
for i, v in enumerate(median_values.values):
X_imputed[:, :, i] = torch.nan_to_num(X_imputed[:, :, i], nan=v)
if torch.isnan(X_imputed).any():
X_imputed = torch.nan_to_num(X_imputed, nan=0)

return X_imputed
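The NumPy and Torch implementations above follow the same algorithm, so a quick consistency check is a reasonable way to exercise the new module. A sketch assuming CPU tensors (the Torch path still calls np.maximum.accumulate internally, so CPU is the safe case); the toy data is illustrative, not from the repository's tests:

import numpy as np
import torch
from pypots.imputation.locf import locf_numpy, locf_torch

rng = np.random.default_rng(0)
X = rng.standard_normal((4, 10, 3)).astype(np.float32)
X[rng.random(X.shape) < 0.2] = np.nan  # roughly 20% missing values

out_np = locf_numpy(X.copy(), first_step_imputation="backward")
out_torch = locf_torch(torch.from_numpy(X.copy()), first_step_imputation="backward")

assert not np.isnan(out_np).any()              # no NaNs should survive the default strategy
assert np.allclose(out_np, out_torch.numpy())  # both backends should agree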
