Skip to content

Commit

Permalink
Merge pull request #556 from AugustJW/main
Browse files Browse the repository at this point in the history
Add TRMF imputation method
  • Loading branch information
WenjieDu authored Dec 16, 2024
2 parents a27468c + 07dd05d commit e330678
Show file tree
Hide file tree
Showing 3 changed files with 347 additions and 0 deletions.
2 changes: 2 additions & 0 deletions pypots/imputation/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
from .timemixer import TimeMixer
from .moderntcn import ModernTCN
from .segrnn import SegRNN
from .trmf import TRMF

# naive imputation methods
from .locf import LOCF
Expand Down Expand Up @@ -89,4 +90,5 @@
"TEFN",
"CSAI",
"SegRNN",
"TRMF"
]
19 changes: 19 additions & 0 deletions pypots/imputation/trmf/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
"""
The package including the modules of TRMF.
Refer to the paper
Yu, H. F., Rao, N., & Dhillon, I. S.
Temporal regularized matrix factorization for high-dimensional time series prediction.
In Advances in neural information processing systems 2016.
<https://www.cs.utexas.edu/~rofuyu/papers/tr-mf-nips.pdf>`_
"""

# Created by Jun Wang <[email protected]> and Wenjie Du <[email protected]>
# License: BSD-3-Clause

from .model import TRMF

__all__ = [
"TRMF",
]
326 changes: 326 additions & 0 deletions pypots/imputation/trmf/model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,326 @@
"""
The implementation of TRMF for the partially-observed time-series imputation task, which is mainly based on the implementation of TRMF in https://github.com/SemenovAlex/trmf/blob/master/trmf.py
"""

# Created by Jun Wang <[email protected]> and Wenjie Du <[email protected]>
# License: BSD-3-Clause


import warnings
from typing import Union, Optional

import h5py
import numpy as np
import torch

from ..base import BaseImputer


class TRMF:
"""Temporal Regularized Matrix Factorization (TRMF) imputation method.
Parameters
----------
lags : array-like, shape (n_lags,)
Set of lag indices to use in model.
K : int
Length of latent embedding dimension
lambda_f : float
Regularization parameter used for matrix F.
lambda_x : float
Regularization parameter used for matrix X.
lambda_w : float
Regularization parameter used for matrix W.
alpha : float
Regularization parameter used for make the sum of lag coefficient close to 1.
That helps to avoid big deviations when forecasting.
eta : float
Regularization parameter used for X when undercovering autoregressive dependencies.
max_iter : int
Number of iterations of updating matrices F, X and W.
F_step : float
Step of gradient descent when updating matrix F.
X_step : float
Step of gradient descent when updating matrix X.
W_step : float
Step of gradient descent when updating matrix W.
Attributes
----------
F : ndarray, shape (n_timeseries, K)
Latent embedding of timeseries.
X : ndarray, shape (K, n_timepoints)
Latent embedding of timepoints.
W : ndarray, shape (K, n_lags)
Matrix of autoregressive coefficients.
"""

def __init__(self, lags, L, K, lambda_f, lambda_x, lambda_w, alpha, eta, max_iter=1000,
F_step=0.0001, X_step=0.0001, W_step=0.0001
):
super(TRMF, self).__init__()

self.lags = lags
self.L = L
self.K = K
self.lambda_f = lambda_f
self.lambda_x = lambda_x
self.lambda_w = lambda_w
self.alpha = alpha
self.eta = eta
self.max_iter = max_iter
self.F_step = F_step
self.X_step = X_step
self.W_step = W_step

self.W = None
self.F = None
self.X = None


def fit(self,
train_set: Union[dict, str],
val_set: Optional[Union[dict, str]] = None,
file_type: str = "hdf5",
resume: bool = False,
) -> None:
"""Fit the TRMF model according to the given training data.
Model fits through sequential updating three matrices:
- matrix self.F;
- matrix self.X;
- matrix self.W.
Each matrix updated with gradient descent.
Parameters
----------
train_set : ndarray, shape (n_timeseries, n_timepoints)
Training data.
val_set : ndarray, shape (n_timeseries, n_timepoints)
Validation data.
file_type :
The type of the given file if train_set and val_set are path strings.
resume : bool
Used to continue fitting.
Returns
-------
self : object
Returns self.
"""

if isinstance(train_set, str):
with h5py.File(train_set, "r") as f:
X = f["X"][:]
else:
X = train_set["X"]

assert len(X.shape) == 2, (
f"Input X should have 2 dimensions [n_samples, n_features], "
f"but the actual shape of X: {X.shape}"
)
if isinstance(X, list):
X = np.asarray(X)

if not resume:
self.Y = X.copy()
mask = np.array((~np.isnan(self.Y)).astype(int))
self.mask = mask
self.Y[self.mask == 0] = 0.
self.N, self.T = self.Y.shape
self.W = np.random.randn(self.K, self.L) / self.L
self.F = np.random.randn(self.N, self.K)
self.X = np.random.randn(self.K, self.T)

for _ in range(self.max_iter):
self._update_F(step=self.F_step)
self._update_X(step=self.X_step)
self._update_W(step=self.W_step)


def impute_missings(self):
"""Impute each missing element in timeseries.
Model uses matrix X and F to get all missing elements.
Parameters
----------
Returns
-------
data : ndarray, shape (n_timeseries, T)
Predictions.
"""
data = self.Y
data[self.mask == 0] = np.dot(self.F, self.X)[self.mask == 0]
result_dict = {
"imputation": data,
}
return result_dict


def impute(
self,
) -> np.ndarray:
result_dict = self.impute_missings()
return result_dict["imputation"]


def _update_F(self, step, n_iter=1):
"""Gradient descent of matrix F.
n_iter steps of gradient descent of matrix F.
Parameters
----------
step : float
Step of gradient descent when updating matrix.
n_iter : int
Number of gradient steps to be made.
Returns
-------
self : objects
Returns self.
"""

for _ in range(n_iter):
self.F -= step * self._grad_F()


def _update_X(self, step, n_iter=1):
"""Gradient descent of matrix X.
n_iter steps of gradient descent of matrix X.
Parameters
----------
step : float
Step of gradient descent when updating matrix.
n_iter : int
Number of gradient steps to be made.
Returns
-------
self : objects
Returns self.
"""

for _ in range(n_iter):
self.X -= step * self._grad_X()


def _update_W(self, step, n_iter=1):
"""Gradient descent of matrix W.
n_iter steps of gradient descent of matrix W.
Parameters
----------
step : float
Step of gradient descent when updating matrix.
n_iter : int
Number of gradient steps to be made.
Returns
-------
self : objects
Returns self.
"""

for _ in range(n_iter):
self.W -= step * self._grad_W()


def _grad_F(self):
"""Gradient of matrix F.
Evaluating gradient of matrix F.
Parameters
----------
Returns
-------
self : objects
Returns self.
"""

return - 2 * np.dot((self.Y - np.dot(self.F, self.X)) * self.mask, self.X.T) + 2 * self.lambda_f * self.F


def _grad_X(self):
"""Gradient of matrix X.
Evaluating gradient of matrix X.
Parameters
----------
Returns
-------
self : objects
Returns self.
"""

for l in range(self.L):
lag = self.lags[l]
W_l = self.W[:, l].repeat(self.T, axis=0).reshape(self.K, self.T)
X_l = self.X * W_l
z_1 = self.X - np.roll(X_l, lag, axis=1)
z_1[:, :max(self.lags)] = 0.
z_2 = - (np.roll(self.X, -lag, axis=1) - X_l) * W_l
z_2[:, -lag:] = 0.

grad_T_x = z_1 + z_2
return - 2 * np.dot(self.F.T, self.mask * (self.Y - np.dot(self.F, self.X))) + self.lambda_x * grad_T_x + self.eta * self.X


def _grad_W(self):
"""Gradient of matrix W.
Evaluating gradient of matrix W.
Parameters
----------
Returns
-------
self : objects
Returns self.
"""

grad = np.zeros((self.K, self.L))
for l in range(self.L):
lag = self.lags[l]
W_l = self.W[:, l].repeat(self.T, axis=0).reshape(self.K, self.T)
X_l = self.X * W_l
z_1 = self.X - np.roll(X_l, lag, axis=1)
z_1[:, :max(self.lags)] = 0.
z_2 = - (z_1 * np.roll(self.X, lag, axis=1)).sum(axis=1)
grad[:, l] = z_2
return grad + self.W * 2 * self.lambda_w / self.lambda_x -\
self.alpha * 2 * (1 - self.W.sum(axis=1)).repeat(self.L).reshape(self.W.shape)

0 comments on commit e330678

Please sign in to comment.