From 07dd05de65e77da099007f1721c61dde64978e05 Mon Sep 17 00:00:00 2001
From: Jun Wang <jwangfx@connect.ust.hk>
Date: Sun, 8 Dec 2024 22:09:00 +0800
Subject: [PATCH] Add TRMF for imputation

---
 pypots/imputation/__init__.py      |   2 +
 pypots/imputation/trmf/__init__.py |  19 ++
 pypots/imputation/trmf/model.py    | 326 +++++++++++++++++++++++++++++
 3 files changed, 347 insertions(+)
 create mode 100644 pypots/imputation/trmf/__init__.py
 create mode 100644 pypots/imputation/trmf/model.py

diff --git a/pypots/imputation/__init__.py b/pypots/imputation/__init__.py
index 9ce1d867..c232941b 100644
--- a/pypots/imputation/__init__.py
+++ b/pypots/imputation/__init__.py
@@ -39,6 +39,7 @@
 from .timemixer import TimeMixer
 from .moderntcn import ModernTCN
 from .segrnn import SegRNN
+from .trmf import TRMF
 
 # naive imputation methods
 from .locf import LOCF
@@ -89,4 +90,5 @@
     "TEFN",
     "CSAI",
     "SegRNN",
+    "TRMF"
 ]
diff --git a/pypots/imputation/trmf/__init__.py b/pypots/imputation/trmf/__init__.py
new file mode 100644
index 00000000..9ab66658
--- /dev/null
+++ b/pypots/imputation/trmf/__init__.py
@@ -0,0 +1,19 @@
+"""
+The package including the modules of TRMF.
+
+Refer to the paper
+Yu, H. F., Rao, N., & Dhillon, I. S.
+Temporal regularized matrix factorization for high-dimensional time series prediction. 
+In Advances in neural information processing systems 2016.
+<https://www.cs.utexas.edu/~rofuyu/papers/tr-mf-nips.pdf>`_
+
+"""
+
+# Created by Jun Wang <jwangfx@connect.ust.hk> and Wenjie Du <wenjay.du@gmail.com>
+# License: BSD-3-Clause
+
+from .model import TRMF
+
+__all__ = [
+    "TRMF",
+]
diff --git a/pypots/imputation/trmf/model.py b/pypots/imputation/trmf/model.py
new file mode 100644
index 00000000..727de273
--- /dev/null
+++ b/pypots/imputation/trmf/model.py
@@ -0,0 +1,326 @@
+"""
+The implementation of TRMF for the partially-observed time-series imputation task, which is mainly based on the implementation of TRMF in https://github.com/SemenovAlex/trmf/blob/master/trmf.py
+
+"""
+
+# Created by Jun Wang <jwangfx@connect.ust.hk> and Wenjie Du <wenjay.du@gmail.com>
+# License: BSD-3-Clause
+
+
+from typing import Union, Optional
+
+import h5py
+import numpy as np
+
+
+class TRMF:
+    """Temporal Regularized Matrix Factorization (TRMF) imputation method.
+
+    Parameters
+    ----------
+
+    lags : array-like, shape (n_lags,)
+        Set of lag indices to use in the model.
+
+    L : int
+        Number of lags, i.e. len(lags).
+
+    K : int
+        Dimension of the latent embedding.
+
+    lambda_f : float
+        Regularization parameter used for matrix F.
+
+    lambda_x : float
+        Regularization parameter used for matrix X.
+
+    lambda_w : float
+        Regularization parameter used for matrix W.
+
+    alpha : float
+        Regularization parameter used to make the sum of the lag coefficients close to 1,
+        which helps to avoid big deviations when forecasting.
+
+    eta : float
+        Regularization parameter used for X when uncovering autoregressive dependencies.
+
+    max_iter : int
+        Number of iterations of updating matrices F, X and W.
+
+    F_step : float
+        Step size of gradient descent when updating matrix F.
+
+    X_step : float
+        Step size of gradient descent when updating matrix X.
+
+    W_step : float
+        Step size of gradient descent when updating matrix W.
+
+
+    Attributes
+    ----------
+
+    F : ndarray, shape (n_timeseries, K)
+        Latent embedding of timeseries.
+
+    X : ndarray, shape (K, n_timepoints)
+        Latent embedding of timepoints.
+
+    W : ndarray, shape (K, n_lags)
+        Matrix of autoregressive coefficients.
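+
+    Examples
+    --------
+    A minimal usage sketch on toy data (the values below are illustrative,
+    not taken from the paper):
+
+    >>> import numpy as np
+    >>> rng = np.random.default_rng(0)
+    >>> data = rng.standard_normal((10, 100))  # 10 series, 100 time points
+    >>> data[rng.random((10, 100)) < 0.2] = np.nan  # make 20% of entries missing
+    >>> model = TRMF(lags=[1, 2], L=2, K=4, lambda_f=1.0, lambda_x=1.0,
+    ...              lambda_w=1.0, alpha=1000.0, eta=1.0, max_iter=100)
+    >>> model.fit({"X": data})
+    >>> imputed = model.impute()  # shape (10, 100), missing entries filled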
+    """
+
+    def __init__(
+        self, lags, L, K, lambda_f, lambda_x, lambda_w, alpha, eta,
+        max_iter=1000, F_step=0.0001, X_step=0.0001, W_step=0.0001,
+    ):
+        super().__init__()
+
+        self.lags = lags
+        self.L = L
+        self.K = K
+        self.lambda_f = lambda_f
+        self.lambda_x = lambda_x
+        self.lambda_w = lambda_w
+        self.alpha = alpha
+        self.eta = eta
+        self.max_iter = max_iter
+        self.F_step = F_step
+        self.X_step = X_step
+        self.W_step = W_step
+
+        self.W = None
+        self.F = None
+        self.X = None
+
+
+    def fit(
+        self,
+        train_set: Union[dict, str],
+        val_set: Optional[Union[dict, str]] = None,
+        file_type: str = "hdf5",
+        resume: bool = False,
+    ) -> None:
+        """Fit the TRMF model according to the given training data.
+
+        Model fits through sequential updating three matrices:
+            -   matrix self.F;
+            -   matrix self.X;
+            -   matrix self.W.
+            
+        Each matrix updated with gradient descent.
+
+        Parameters
+        ----------
+        train_set : ndarray, shape (n_timeseries, n_timepoints)
+            Training data.
+
+        val_set : ndarray, shape (n_timeseries, n_timepoints)
+            Validation data.
+
+        file_type :
+            The type of the given file if train_set and val_set are path strings.
+
+        resume : bool
+            Used to continue fitting.
+
+        Returns
+        -------
+        self : object
+            Returns self.
+        """
+
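+        # Load the training data either from an HDF5 file path or an in-memory dict.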
+        if isinstance(train_set, str):
+            with h5py.File(train_set, "r") as f:
+                X = f["X"][:]
+        else:
+            X = train_set["X"]
+
+        if isinstance(X, list):
+            X = np.asarray(X)
+        assert len(X.shape) == 2, (
+            f"Input X should have 2 dimensions [n_timeseries, n_timepoints], "
+            f"but the actual shape of X: {X.shape}"
+        )
+
+        if not resume:
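+            # Build the observation mask: 1 = observed, 0 = missing (NaN).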
+            self.Y = X.copy()
+            self.mask = (~np.isnan(self.Y)).astype(int)
+            self.Y[self.mask == 0] = 0.
+            self.N, self.T = self.Y.shape
+            self.W = np.random.randn(self.K, self.L) / self.L
+            self.F = np.random.randn(self.N, self.K)
+            self.X = np.random.randn(self.K, self.T)
+
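+        # Alternate single gradient steps on F, X, and W (block coordinate descent).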
+        for _ in range(self.max_iter):
+            self._update_F(step=self.F_step)
+            self._update_X(step=self.X_step)
+            self._update_W(step=self.W_step)
+
+
+    def impute_missings(self):
+        """Impute each missing element in timeseries.
+
+        Model uses matrix X and F to get all missing elements.
+
+        Parameters
+        ----------
+
+        Returns
+        -------
+        data : ndarray, shape (n_timeseries, T)
+            Predictions.
+        """
+        data = self.Y.copy()  # work on a copy so self.Y is not modified in place
+        data[self.mask == 0] = np.dot(self.F, self.X)[self.mask == 0]
+        result_dict = {
+            "imputation": data,
+        }
+        return result_dict
+
+
+    def impute(self) -> np.ndarray:
+        result_dict = self.impute_missings()
+        return result_dict["imputation"]
+
+
+    def _update_F(self, step, n_iter=1):
+        """Gradient descent of matrix F.
+
+        n_iter steps of gradient descent of matrix F.
+
+        Parameters
+        ----------
+        step : float
+            Step of gradient descent when updating matrix.
+
+        n_iter : int
+            Number of gradient steps to be made.
+
+        Returns
+        -------
+        self : objects
+            Returns self.
+        """
+
+        for _ in range(n_iter):
+            self.F -= step * self._grad_F()
+
+
+    def _update_X(self, step, n_iter=1):
+        """Gradient descent of matrix X.
+
+        n_iter steps of gradient descent of matrix X.
+
+        Parameters
+        ----------
+        step : float
+            Step of gradient descent when updating matrix.
+
+        n_iter : int
+            Number of gradient steps to be made.
+
+        Returns
+        -------
+        self : objects
+            Returns self.
+        """
+
+        for _ in range(n_iter):
+            self.X -= step * self._grad_X()
+
+
+    def _update_W(self, step, n_iter=1):
+        """Gradient descent of matrix W.
+
+        n_iter steps of gradient descent of matrix W.
+
+        Parameters
+        ----------
+        step : float
+            Step of gradient descent when updating matrix.
+
+        n_iter : int
+            Number of gradient steps to be made.
+
+        Returns
+        -------
+        self : objects
+            Returns self.
+        """
+
+        for _ in range(n_iter):
+            self.W -= step * self._grad_W()
+
+
+    def _grad_F(self):
+        """Gradient of matrix F.
+
+        Evaluating gradient of matrix F.
+
+        Parameters
+        ----------
+
+        Returns
+        -------
+        self : objects
+            Returns self.
+        """
+
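+        # d/dF of ||mask * (Y - F X)||_F^2 + lambda_f * ||F||_F^2
+        #   = -2 * (mask * (Y - F X)) @ X.T + 2 * lambda_f * F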
+        return (
+            -2 * np.dot((self.Y - np.dot(self.F, self.X)) * self.mask, self.X.T)
+            + 2 * self.lambda_f * self.F
+        )
+
+
+    def _grad_X(self):
+        """Gradient of matrix X.
+
+        Evaluating gradient of matrix X.
+
+        Parameters
+        ----------
+
+        Returns
+        -------
+        self : objects
+            Returns self.
+        """
+
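+        # Accumulate, over all lags, the gradient of the temporal AR regularizer
+        # T(X | W) = sum_t || x_t - sum_l W[:, l] * x_{t - lags[l]} ||^2 w.r.t. X.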
+        grad_T_x = np.zeros_like(self.X)
+        for l in range(self.L):
+            lag = self.lags[l]
+            W_l = self.W[:, l].repeat(self.T, axis=0).reshape(self.K, self.T)
+            X_l = self.X * W_l
+            z_1 = self.X - np.roll(X_l, lag, axis=1)
+            z_1[:, :max(self.lags)] = 0.
+            z_2 = -(np.roll(self.X, -lag, axis=1) - X_l) * W_l
+            z_2[:, -lag:] = 0.
+            # accumulate this lag's contribution to the temporal gradient
+            grad_T_x += z_1 + z_2
+
+        return (
+            -2 * np.dot(self.F.T, self.mask * (self.Y - np.dot(self.F, self.X)))
+            + self.lambda_x * grad_T_x
+            + self.eta * self.X
+        )
+
+
+    def _grad_W(self):
+        """Gradient of matrix W.
+
+        Evaluating gradient of matrix W.
+
+        Parameters
+        ----------
+
+        Returns
+        -------
+        self : objects
+            Returns self.
+        """
+
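+        # Per-lag gradient of the AR regularizer w.r.t. the lag coefficients,
+        # plus the ridge penalty on W and the penalty pulling each row sum
+        # of W towards 1.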
+        grad = np.zeros((self.K, self.L))
+        for l in range(self.L):
+            lag = self.lags[l]
+            W_l = self.W[:, l].repeat(self.T, axis=0).reshape(self.K, self.T)
+            X_l = self.X * W_l
+            z_1 = self.X - np.roll(X_l, lag, axis=1)
+            z_1[:, :max(self.lags)] = 0.
+            z_2 = - (z_1 * np.roll(self.X, lag, axis=1)).sum(axis=1)
+            grad[:, l] = z_2
+        return (
+            grad
+            + self.W * 2 * self.lambda_w / self.lambda_x
+            - self.alpha * 2 * (1 - self.W.sum(axis=1)).repeat(self.L).reshape(self.W.shape)
+        )