Skip to content

Commit

Permalink
Merge pull request #298 from WenjieDu/dev
Browse files Browse the repository at this point in the history
Fixing the issue in time-decay matrix calculation and simplify the code
  • Loading branch information
WenjieDu authored Jan 16, 2024
2 parents 59e75d1 + 68f8f3e commit f01de1d
Show file tree
Hide file tree
Showing 4 changed files with 24 additions and 24 deletions.
5 changes: 4 additions & 1 deletion pypots/data/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,9 +257,12 @@ def _open_file_handle(self) -> h5py.File:
raise ImportError(
"h5py is missing and cannot be imported. Please install it first."
)
except FileNotFoundError as e:
raise FileNotFoundError(f"{e}")
except OSError as e:
raise TypeError(
f"{e} This probably is caused by file type error. "
f"{e}\n"
f"Check out the above error log. This probably is caused by file type error. "
f"Please confirm that the given file {data_file_path} is an h5 file."
)
except Exception as e:
Expand Down
26 changes: 12 additions & 14 deletions pypots/data/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,14 +56,13 @@ def _parse_delta_torch(missing_mask: torch.Tensor) -> torch.Tensor:

def cal_delta_for_single_sample(mask: torch.Tensor) -> torch.Tensor:
"""calculate single sample's delta. The sample's shape is [n_steps, n_features]."""
d = []
for step in range(n_steps):
if step == 0:
d.append(torch.zeros(1, n_features, device=device))
else:
d.append(
torch.ones(1, n_features, device=device) + (1 - mask[step]) * d[-1]
)
# the first step in the delta matrix is all 0
d = [torch.zeros(1, n_features, device=device)]

for step in range(1, n_steps):
d.append(
torch.ones(1, n_features, device=device) + (1 - mask[step - 1]) * d[-1]
)
d = torch.concat(d, dim=0)
return d

Expand Down Expand Up @@ -108,12 +107,11 @@ def _parse_delta_numpy(missing_mask: np.ndarray) -> np.ndarray:

def cal_delta_for_single_sample(mask: np.ndarray) -> np.ndarray:
"""calculate single sample's delta. The sample's shape is [n_steps, n_features]."""
d = []
for step in range(seq_len):
if step == 0:
d.append(np.zeros(n_features))
else:
d.append(np.ones(n_features) + (1 - mask[step]) * d[-1])
# the first step in the delta matrix is all 0
d = [np.zeros(n_features)]

for step in range(1, seq_len):
d.append(np.ones(n_features) + (1 - mask[step - 1]) * d[-1])
d = np.asarray(d)
return d

Expand Down
15 changes: 7 additions & 8 deletions pypots/imputation/mrnn/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,14 +36,13 @@ def mrnn_parse_delta_torch(missing_mask: torch.Tensor) -> torch.Tensor:

def cal_delta_for_single_sample(mask: torch.Tensor) -> torch.Tensor:
"""calculate single sample's delta. The sample's shape is [n_steps, n_features]."""
d = []
for step in range(n_steps):
if step == 0:
d.append(torch.ones(1, n_features, device=device))
else:
d.append(
torch.ones(1, n_features, device=device) + (1 - mask[step]) * d[-1]
)
# the first step in the delta matrix is all 0
d = [torch.ones(1, n_features, device=device)]

for step in range(1, n_steps):
d.append(
torch.ones(1, n_features, device=device) + (1 - mask[step - 1]) * d[-1]
)
d = torch.concat(d, dim=0)
return d

Expand Down
2 changes: 1 addition & 1 deletion pypots/nn/modules/transformer/attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,7 @@ def forward(

# keep useful variables
batch_size, n_steps = q.size(0), q.size(1)
residual = q
residual = v

# now separate the last dimension of q, k, v into different heads -> [batch_size, n_steps, n_heads, d_k or d_v]
q = self.w_qs(q).view(batch_size, n_steps, self.n_heads, self.d_k)
Expand Down

0 comments on commit f01de1d

Please sign in to comment.