-
-
Notifications
You must be signed in to change notification settings - Fork 120
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #344 from WenjieDu/dev
Release v0.4, apply SAITS embedding strategy to the newly added models, and update README
- Loading branch information
Showing
14 changed files
with
227 additions
and
106 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -5,6 +5,7 @@ | |
# Created by Wenjie Du <[email protected]> | ||
# License: BSD-3-Clause | ||
|
||
import torch | ||
import torch.nn as nn | ||
|
||
from .submodules import ( | ||
|
@@ -38,7 +39,7 @@ def __init__( | |
self.seq_len = n_steps | ||
self.n_layers = n_layers | ||
self.enc_embedding = DataEmbedding( | ||
n_features, | ||
n_features * 2, | ||
d_model, | ||
dropout=dropout, | ||
with_pos=False, | ||
|
@@ -63,28 +64,35 @@ def __init__( | |
) | ||
|
||
# for the imputation task, the output dim is the same as input dim | ||
self.projection = nn.Linear(d_model, n_features) | ||
self.output_projection = nn.Linear(d_model, n_features) | ||
|
||
def forward(self, inputs: dict, training: bool = True) -> dict: | ||
X, masks = inputs["X"], inputs["missing_mask"] | ||
|
||
# embedding | ||
enc_out = self.enc_embedding(X) # [B,T,C] | ||
# WDU: the original Autoformer paper isn't proposed for imputation task. Hence the model doesn't take | ||
# the missing mask into account, which means, in the process, the model doesn't know which part of | ||
# the input data is missing, and this may hurt the model's imputation performance. Therefore, I add the | ||
# embedding layers to project the concatenation of features and masks into a hidden space, as well as | ||
# the output layers to project back from the hidden space to the original space. | ||
|
||
# the same as SAITS, concatenate the time series data and the missing mask for embedding | ||
input_X = torch.cat([X, masks], dim=2) | ||
enc_out = self.enc_embedding(input_X) | ||
|
||
# Autoformer encoder processing | ||
enc_out, attns = self.encoder(enc_out) | ||
|
||
# project back the original data space | ||
dec_out = self.projection(enc_out) | ||
output = self.output_projection(enc_out) | ||
|
||
imputed_data = masks * X + (1 - masks) * dec_out | ||
imputed_data = masks * X + (1 - masks) * output | ||
results = { | ||
"imputed_data": imputed_data, | ||
} | ||
|
||
if training: | ||
# `loss` is always the item for backward propagating to update the model | ||
loss = calc_mse(dec_out, inputs["X_ori"], inputs["indicating_mask"]) | ||
loss = calc_mse(output, inputs["X_ori"], inputs["indicating_mask"]) | ||
results["loss"] = loss | ||
|
||
return results |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.