Merge pull request #300 from WenjieDu/dev
Roll back the delta calculation of M-RNN to be the same as GRU-D's
WenjieDu authored Jan 17, 2024
2 parents f01de1d + 5e5b7db commit 65a4ced
Showing 1 changed file with 5 additions and 52 deletions.
57 changes: 5 additions & 52 deletions pypots/imputation/mrnn/data.py
@@ -11,54 +11,7 @@
from pygrinder import fill_and_get_mask_torch

from ...data.base import BaseDataset


def mrnn_parse_delta_torch(missing_mask: torch.Tensor) -> torch.Tensor:
    """Generate the time-gap matrix from the missing mask. This implementation is the same as the official
    MRNN implementation in TensorFlow (https://github.com/jsyoon0823/MRNN), but it differs from the description
    in the MRNN paper, which matches the formulation from GRU-D.

    In the PyPOTS team's experiments, we found that this implementation is important to the training stability
    and the performance of MRNN, mainly because this version makes the first step of the deltas start from 1
    rather than from 0 as in the original description.

    Parameters
    ----------
    missing_mask : shape of [n_steps, n_features] or [n_samples, n_steps, n_features]
        Binary masks indicating missing data (0 means missing values, 1 means observed values).

    Returns
    -------
    delta :
        The delta matrix indicating the time gaps between observed values,
        with the same shape as missing_mask.
    """

    def cal_delta_for_single_sample(mask: torch.Tensor) -> torch.Tensor:
        """Calculate a single sample's delta. The sample's shape is [n_steps, n_features]."""
        # unlike the GRU-D description, the first step in the delta matrix is all 1, not all 0
        d = [torch.ones(1, n_features, device=device)]

        for step in range(1, n_steps):
            d.append(
                torch.ones(1, n_features, device=device) + (1 - mask[step - 1]) * d[-1]
            )
        d = torch.concat(d, dim=0)
        return d

    device = missing_mask.device
    if len(missing_mask.shape) == 2:
        n_steps, n_features = missing_mask.shape
        delta = cal_delta_for_single_sample(missing_mask)
    else:
        n_samples, n_steps, n_features = missing_mask.shape
        delta_collector = []
        for m_mask in missing_mask:
            delta = cal_delta_for_single_sample(m_mask)
            delta_collector.append(delta.unsqueeze(0))
        delta = torch.concat(delta_collector, dim=0)

    return delta
from ...data.utils import _parse_delta_torch


class DatasetForMRNN(BaseDataset):
@@ -105,10 +58,10 @@ def __init__(
forward_missing_mask = self.missing_mask
forward_X = self.X

forward_delta = mrnn_parse_delta_torch(forward_missing_mask)
forward_delta = _parse_delta_torch(forward_missing_mask)
backward_X = torch.flip(forward_X, dims=[1])
backward_missing_mask = torch.flip(forward_missing_mask, dims=[1])
backward_delta = mrnn_parse_delta_torch(backward_missing_mask)
backward_delta = _parse_delta_torch(backward_missing_mask)

self.processed_data = {
"forward": {
@@ -195,14 +148,14 @@ def _fetch_data_from_file(self, idx: int) -> Iterable:
forward = {
"X": X,
"missing_mask": missing_mask,
"deltas": mrnn_parse_delta_torch(missing_mask),
"deltas": _parse_delta_torch(missing_mask),
}

backward = {
"X": torch.flip(forward["X"], dims=[0]),
"missing_mask": torch.flip(forward["missing_mask"], dims=[0]),
}
backward["deltas"] = mrnn_parse_delta_torch(backward["missing_mask"])
backward["deltas"] = _parse_delta_torch(backward["missing_mask"])

sample = [
torch.tensor(idx),
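For context, below is a minimal sketch of the GRU-D-style delta (time-gap) calculation that `_parse_delta_torch` presumably implements, based on the removed docstring's description: the first step's delta starts from 0, and the gap keeps accumulating while the previous step is missing. The helper name `grud_style_delta` and the single-sample scope are illustrative assumptions, not the actual PyPOTS code.

import torch


def grud_style_delta(missing_mask: torch.Tensor) -> torch.Tensor:
    """Illustrative GRU-D-style time-gap matrix for one sample of shape [n_steps, n_features].

    missing_mask uses 1 for observed and 0 for missing values. The first step's delta is 0;
    afterwards the gap grows by one step at a time and resets only where the previous step was observed.
    """
    n_steps, n_features = missing_mask.shape
    deltas = [torch.zeros(1, n_features)]  # per the GRU-D description, the first step starts from 0
    for step in range(1, n_steps):
        deltas.append(
            torch.ones(1, n_features) + (1 - missing_mask[step - 1]) * deltas[-1]
        )
    return torch.cat(deltas, dim=0)


if __name__ == "__main__":
    # one feature, observed at steps 0 and 3, missing at steps 1 and 2
    mask = torch.tensor([[1.0], [0.0], [0.0], [1.0]])
    print(grud_style_delta(mask).squeeze())  # tensor([0., 1., 2., 3.])
    # for this mask, the removed MRNN-official variant above would instead give [1., 1., 2., 3.]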
