diff --git a/pypots/imputation/mrnn/data.py b/pypots/imputation/mrnn/data.py index e92d3c6e..b8fcdbf4 100644 --- a/pypots/imputation/mrnn/data.py +++ b/pypots/imputation/mrnn/data.py @@ -11,54 +11,7 @@ from pygrinder import fill_and_get_mask_torch from ...data.base import BaseDataset - - -def mrnn_parse_delta_torch(missing_mask: torch.Tensor) -> torch.Tensor: - """Generate the time-gap matrix from the missing mask, this implementation is the same with the MRNN official - implementation in tensorflow https://github.com/jsyoon0823/MRNN, but that is different from the description in the - MRNN paper which is the same with the one from GRUD. - - In PyPOTS team's experiments, we find that this implementation is important to the training stability and - the performance of MRNN, we think this is mainly because this version make the first step of deltas start from 1, - rather than from 0 in the original description. - - Parameters - ---------- - missing_mask : shape of [n_steps, n_features] or [n_samples, n_steps, n_features] - Binary masks indicate missing data (0 means missing values, 1 means observed values). - - Returns - ------- - delta : - The delta matrix indicates the time gaps between observed values. - With the same shape of missing_mask. - """ - - def cal_delta_for_single_sample(mask: torch.Tensor) -> torch.Tensor: - """calculate single sample's delta. The sample's shape is [n_steps, n_features].""" - # the first step in the delta matrix is all 0 - d = [torch.ones(1, n_features, device=device)] - - for step in range(1, n_steps): - d.append( - torch.ones(1, n_features, device=device) + (1 - mask[step - 1]) * d[-1] - ) - d = torch.concat(d, dim=0) - return d - - device = missing_mask.device - if len(missing_mask.shape) == 2: - n_steps, n_features = missing_mask.shape - delta = cal_delta_for_single_sample(missing_mask) - else: - n_samples, n_steps, n_features = missing_mask.shape - delta_collector = [] - for m_mask in missing_mask: - delta = cal_delta_for_single_sample(m_mask) - delta_collector.append(delta.unsqueeze(0)) - delta = torch.concat(delta_collector, dim=0) - - return delta +from ...data.utils import _parse_delta_torch class DatasetForMRNN(BaseDataset): @@ -105,10 +58,10 @@ def __init__( forward_missing_mask = self.missing_mask forward_X = self.X - forward_delta = mrnn_parse_delta_torch(forward_missing_mask) + forward_delta = _parse_delta_torch(forward_missing_mask) backward_X = torch.flip(forward_X, dims=[1]) backward_missing_mask = torch.flip(forward_missing_mask, dims=[1]) - backward_delta = mrnn_parse_delta_torch(backward_missing_mask) + backward_delta = _parse_delta_torch(backward_missing_mask) self.processed_data = { "forward": { @@ -195,14 +148,14 @@ def _fetch_data_from_file(self, idx: int) -> Iterable: forward = { "X": X, "missing_mask": missing_mask, - "deltas": mrnn_parse_delta_torch(missing_mask), + "deltas": _parse_delta_torch(missing_mask), } backward = { "X": torch.flip(forward["X"], dims=[0]), "missing_mask": torch.flip(forward["missing_mask"], dims=[0]), } - backward["deltas"] = mrnn_parse_delta_torch(backward["missing_mask"]) + backward["deltas"] = _parse_delta_torch(backward["missing_mask"]) sample = [ torch.tensor(idx),