feat: enable SAITS to return its latent attention weights with predict()
WenjieDu committed Nov 18, 2023
1 parent 2e5c42b commit f6d4e37
Showing 3 changed files with 99 additions and 17 deletions.
59 changes: 59 additions & 0 deletions pypots/imputation/saits/model.py
@@ -295,7 +295,37 @@ def predict(
test_set: Union[dict, str],
file_type: str = "h5py",
diagonal_attention_mask: bool = True,
return_latent_vars: bool = False,
) -> dict:
"""Make predictions for the input data with the trained model.
Parameters
----------
test_set : dict or str
The dataset for model testing, should be a dictionary including the key 'X',
or a path string locating a data file supported by PyPOTS (e.g. an HDF5 file).
If it is a dict, X should be array-like of shape [n_samples, sequence length (time steps), n_features],
which is the time-series data to be imputed and can contain missing values.
If it is a path string, the path should point to a data file, e.g. an HDF5 file, which contains
key-value pairs like a dict and must include the key 'X'.
file_type : str
The type of the given file if test_set is a path string.
diagonal_attention_mask : bool
Whether to apply a diagonal attention mask to the self-attention mechanism in the testing stage.
return_latent_vars : bool
Whether to return the latent variables in SAITS, e.g. the attention weights of the two DMSA blocks and
the weight matrix from the combining block.
Returns
-------
result_dict : dict
The dictionary containing the imputation results and the latent variables if requested.
"""
# Step 1: wrap the input data with classes Dataset and DataLoader
self.model.eval() # set the model as eval status to freeze it.
test_set = BaseDataset(test_set, return_labels=False, file_type=file_type)
@@ -306,6 +336,9 @@ def predict(
num_workers=self.num_workers,
)
imputation_collector = []
first_DMSA_attn_weights_collector = []
second_DMSA_attn_weights_collector = []
combining_weights_collector = []

# Step 2: process the data with the model
with torch.no_grad():
@@ -317,11 +350,37 @@ def predict(
imputed_data = results["imputed_data"]
imputation_collector.append(imputed_data)

if return_latent_vars:
first_DMSA_attn_weights = (
results["first_DMSA_attn_weights"].cpu().numpy()
)
second_DMSA_attn_weights = (
results["second_DMSA_attn_weights"].cpu().numpy()
)
combining_weights = results["combining_weights"].cpu().numpy()

first_DMSA_attn_weights_collector.append(first_DMSA_attn_weights)
second_DMSA_attn_weights_collector.append(second_DMSA_attn_weights)
combining_weights_collector.append(combining_weights)

# Step 3: output collection and return
imputation = torch.cat(imputation_collector).cpu().detach().numpy()
result_dict = {
"imputation": imputation,
}

if return_latent_vars:
latent_var_collector = {
"first_DMSA_attn_weights": np.concatenate(
first_DMSA_attn_weights_collector
),
"second_DMSA_attn_weights": np.concatenate(
second_DMSA_attn_weights_collector
),
"combining_weights": np.concatenate(combining_weights_collector),
}
result_dict["latent_vars"] = latent_var_collector

return result_dict

def impute(
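For context, here is a minimal usage sketch of the new flag. It assumes `saits` is an already-fitted pypots.imputation.SAITS model and `test_set` is a dict with the key "X"; the variable names below are illustrative and not part of this commit.

# Minimal sketch, assuming `saits` has already been fitted on data shaped like test_set["X"].
results = saits.predict(test_set, return_latent_vars=True)

imputation = results["imputation"]               # imputed series with no NaNs left
latent = results["latent_vars"]                  # only present when return_latent_vars=True
first_attn = latent["first_DMSA_attn_weights"]   # attention weights from the first DMSA block
second_attn = latent["second_DMSA_attn_weights"] # attention weights from the second DMSA block
eta = latent["combining_weights"]                # weights from the combining block (term eta)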
45 changes: 31 additions & 14 deletions pypots/imputation/saits/modules/core.py
@@ -96,7 +96,7 @@ def _process(
self,
inputs: dict,
diagonal_attention_mask: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, list]:
) -> Tuple[torch.Tensor, list, list]:
X, masks = inputs["X"], inputs["missing_mask"]

# first DMSA block
@@ -105,8 +105,11 @@ def _process(
enc_output = self.dropout(
self.position_enc(input_X_for_first)
) # namely, term e in the math equation
first_DMSA_attn_weights = None
for encoder_layer in self.layer_stack_for_first_block:
enc_output, _ = encoder_layer(enc_output, diagonal_attention_mask)
enc_output, first_DMSA_attn_weights = encoder_layer(
enc_output, diagonal_attention_mask
)

X_tilde_1 = self.reduce_dim_z(enc_output)
X_prime = masks * X + (1 - masks) * X_tilde_1
@@ -117,30 +120,39 @@ def _process(
enc_output = self.position_enc(
input_X_for_second
) # namely term alpha in math algo
attn_weights = None
second_DMSA_attn_weights = None
for encoder_layer in self.layer_stack_for_second_block:
enc_output, attn_weights = encoder_layer(enc_output)
enc_output, second_DMSA_attn_weights = encoder_layer(
enc_output, diagonal_attention_mask
)

X_tilde_2 = self.reduce_dim_gamma(F.relu(self.reduce_dim_beta(enc_output)))

# attention-weighted combine
attn_weights = attn_weights.squeeze(dim=1) # namely term A_hat in Eq.
if len(attn_weights.shape) == 4:
copy_second_DMSA_weights = second_DMSA_attn_weights.clone()
copy_second_DMSA_weights = copy_second_DMSA_weights.squeeze(
dim=1
) # namely term A_hat in Eq.
if len(copy_second_DMSA_weights.shape) == 4:
# if having more than 1 head, then average attention weights from all heads
attn_weights = torch.transpose(attn_weights, 1, 3)
attn_weights = attn_weights.mean(dim=3)
attn_weights = torch.transpose(attn_weights, 1, 2)
copy_second_DMSA_weights = torch.transpose(copy_second_DMSA_weights, 1, 3)
copy_second_DMSA_weights = copy_second_DMSA_weights.mean(dim=3)
copy_second_DMSA_weights = torch.transpose(copy_second_DMSA_weights, 1, 2)

# namely term eta
combining_weights = torch.sigmoid(
self.weight_combine(torch.cat([masks, attn_weights], dim=2))
self.weight_combine(torch.cat([masks, copy_second_DMSA_weights], dim=2))
)
# combine X_tilde_1 and X_tilde_2
X_tilde_3 = (1 - combining_weights) * X_tilde_2 + combining_weights * X_tilde_1
# replace non-missing part with original data
X_c = masks * X + (1 - masks) * X_tilde_3

return X_c, [X_tilde_1, X_tilde_2, X_tilde_3]
return (
X_c,
[X_tilde_1, X_tilde_2, X_tilde_3],
[first_DMSA_attn_weights, second_DMSA_attn_weights, combining_weights],
)

def forward(
self, inputs: dict, diagonal_attention_mask: bool = False, training: bool = True
@@ -156,9 +168,11 @@ def forward(
else:
diagonal_attention_mask = None

imputed_data, [X_tilde_1, X_tilde_2, X_tilde_3] = self._process(
inputs, diagonal_attention_mask
)
(
imputed_data,
[X_tilde_1, X_tilde_2, X_tilde_3],
[first_DMSA_attn_weights, second_DMSA_attn_weights, combining_weights],
) = self._process(inputs, diagonal_attention_mask)

if not training:
# if not in training mode, return the imputation result only
@@ -180,6 +194,9 @@ def forward(
loss = self.ORT_weight * ORT_loss + self.MIT_weight * MIT_loss

results = {
"first_DMSA_attn_weights": first_DMSA_attn_weights,
"second_DMSA_attn_weights": second_DMSA_attn_weights,
"combining_weights": combining_weights,
"imputed_data": imputed_data,
"ORT_loss": ORT_loss,
"MIT_loss": MIT_loss,
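As a side note, here is a standalone sketch of the head-averaging applied above to the second DMSA block's attention weights before they enter the combining block; the shapes are illustrative assumptions, not values from this repository.

import torch

# illustrative shapes: 2 samples, 4 heads, 24 time steps
attn = torch.rand(2, 4, 24, 24)        # stand-in for second_DMSA_attn_weights

weights = attn.clone().squeeze(dim=1)  # no-op here because the head dim is larger than 1
if len(weights.shape) == 4:
    # more than 1 head: average the attention maps over the head dimension
    weights = torch.transpose(weights, 1, 3)
    weights = weights.mean(dim=3)
    weights = torch.transpose(weights, 1, 2)

print(weights.shape)                   # torch.Size([2, 24, 24]), one averaged map per sample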
12 changes: 9 additions & 3 deletions tests/imputation/saits.py
@@ -63,12 +63,18 @@ def test_0_fit(self):

@pytest.mark.xdist_group(name="imputation-saits")
def test_1_impute(self):
imputed_X = self.saits.impute(TEST_SET)
imputation_results = self.saits.predict(TEST_SET, return_latent_vars=True)
assert not np.isnan(
imputed_X
imputation_results["imputation"]
).any(), "Output still has missing values after running impute()."
assert (
"latent_vars" in imputation_results.keys()
), "Latent variables are not returned thought `return_latent_vars` is set as True."

test_MAE = cal_mae(
imputed_X, DATA["test_X_intact"], DATA["test_X_indicating_mask"]
imputation_results["imputation"],
DATA["test_X_intact"],
DATA["test_X_indicating_mask"],
)
logger.info(f"SAITS test_MAE: {test_MAE}")

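For reference, a simplified stand-in for the masked-MAE metric used in the test above: cal_mae in PyPOTS evaluates the error only on the artificially held-out positions. This sketch is illustrative, not the library implementation.

import numpy as np

def masked_mae(predictions, targets, mask):
    # mean absolute error over positions where mask == 1, i.e. the held-out values
    return np.sum(np.abs(predictions - targets) * mask) / (np.sum(mask) + 1e-12)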
