Skip to content

Commit

Permalink
refactor: make FITS able to apply customized loss func;
Browse files Browse the repository at this point in the history
  • Loading branch information
WenjieDu committed Sep 27, 2024
1 parent 92773f1 commit f674a3c
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 30 deletions.
4 changes: 2 additions & 2 deletions pypots/imputation/fits/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def __init__(
self.output_projection = nn.Linear(n_features, n_features)
self.saits_loss_func = SaitsLoss(ORT_weight, MIT_weight)

def forward(self, inputs: dict, training: bool = True) -> dict:
def forward(self, inputs: dict) -> dict:
X, missing_mask = inputs["X"], inputs["missing_mask"]

if self.apply_nonstationary_norm:
Expand Down Expand Up @@ -75,7 +75,7 @@ def forward(self, inputs: dict, training: bool = True) -> dict:
}

# if in training mode, return results with losses
if training:
if self.training:
X_ori, indicating_mask = inputs["X_ori"], inputs["indicating_mask"]
loss, ORT_loss, MIT_loss = self.saits_loss_func(reconstruction, X_ori, missing_mask, indicating_mask)
results["ORT_loss"] = ORT_loss
Expand Down
53 changes: 25 additions & 28 deletions pypots/imputation/fits/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,26 +33,11 @@ class FITS(BaseNNImputer):
n_features :
The number of features in the time-series data sample.
n_layers :
The number of layers in the FITS model.
cut_freq :
The cut-off frequency for the Fourier transformation.
d_model :
The dimension of the model.
n_heads :
The number of heads in each layer of FITS.
d_ffn :
The dimension of the feed-forward network.
factor :
The factor of the auto correlation mechanism for the FITS model.
moving_avg_window_size :
The window size of moving average.
dropout :
The dropout rate for the model.
individual :
Whether to use individual Fourier transformation for each feature.
ORT_weight :
The weight for the ORT loss, the same as SAITS.
Expand All @@ -71,6 +56,14 @@ class FITS(BaseNNImputer):
stopped when the model does not perform better after that number of epochs.
Leaving it default as None will disable the early-stopping.
train_loss_func:
The customized loss function designed by users for training the model.
If not given, will use the default loss as claimed in the original paper.
val_metric_func:
The customized metric function designed by users for validating the model.
If not given, will use the default MSE metric.
optimizer :
The optimizer for model training.
If not given, will use a default Adam optimizer.
Expand Down Expand Up @@ -115,6 +108,8 @@ def __init__(
batch_size: int = 32,
epochs: int = 100,
patience: int = None,
train_loss_func: Optional[dict] = None,
val_metric_func: Optional[dict] = None,
optimizer: Optional[Optimizer] = Adam(),
num_workers: int = 0,
device: Optional[Union[str, torch.device, list]] = None,
Expand All @@ -123,14 +118,16 @@ def __init__(
verbose: bool = True,
):
super().__init__(
batch_size,
epochs,
patience,
num_workers,
device,
saving_path,
model_saving_strategy,
verbose,
batch_size=batch_size,
epochs=epochs,
patience=patience,
train_loss_func=train_loss_func,
val_metric_func=val_metric_func,
num_workers=num_workers,
device=device,
saving_path=saving_path,
model_saving_strategy=model_saving_strategy,
verbose=verbose,
)

self.n_steps = n_steps
Expand Down Expand Up @@ -272,7 +269,7 @@ def predict(
with torch.no_grad():
for idx, data in enumerate(test_loader):
inputs = self._assemble_input_for_testing(data)
results = self.model.forward(inputs, training=False)
results = self.model.forward(inputs)
imputation_collector.append(results["imputed_data"])

# Step 3: output collection and return
Expand Down

0 comments on commit f674a3c

Please sign in to comment.