From f674a3c25ea4d7affe251554dbe896b8a66d79eb Mon Sep 17 00:00:00 2001 From: Wenjie Du Date: Fri, 27 Sep 2024 14:30:49 +0800 Subject: [PATCH] refactor: make FITS able to apply customized loss func; --- pypots/imputation/fits/core.py | 4 +-- pypots/imputation/fits/model.py | 53 ++++++++++++++++----------------- 2 files changed, 27 insertions(+), 30 deletions(-) diff --git a/pypots/imputation/fits/core.py b/pypots/imputation/fits/core.py index 701ec4ca..ba3f2661 100644 --- a/pypots/imputation/fits/core.py +++ b/pypots/imputation/fits/core.py @@ -46,7 +46,7 @@ def __init__( self.output_projection = nn.Linear(n_features, n_features) self.saits_loss_func = SaitsLoss(ORT_weight, MIT_weight) - def forward(self, inputs: dict, training: bool = True) -> dict: + def forward(self, inputs: dict) -> dict: X, missing_mask = inputs["X"], inputs["missing_mask"] if self.apply_nonstationary_norm: @@ -75,7 +75,7 @@ def forward(self, inputs: dict, training: bool = True) -> dict: } # if in training mode, return results with losses - if training: + if self.training: X_ori, indicating_mask = inputs["X_ori"], inputs["indicating_mask"] loss, ORT_loss, MIT_loss = self.saits_loss_func(reconstruction, X_ori, missing_mask, indicating_mask) results["ORT_loss"] = ORT_loss diff --git a/pypots/imputation/fits/model.py b/pypots/imputation/fits/model.py index 2664da26..b5c9dc7a 100644 --- a/pypots/imputation/fits/model.py +++ b/pypots/imputation/fits/model.py @@ -33,26 +33,11 @@ class FITS(BaseNNImputer): n_features : The number of features in the time-series data sample. - n_layers : - The number of layers in the FITS model. + cut_freq : + The cut-off frequency for the Fourier transformation. - d_model : - The dimension of the model. - - n_heads : - The number of heads in each layer of FITS. - - d_ffn : - The dimension of the feed-forward network. - - factor : - The factor of the auto correlation mechanism for the FITS model. - - moving_avg_window_size : - The window size of moving average. - - dropout : - The dropout rate for the model. + individual : + Whether to use individual Fourier transformation for each feature. ORT_weight : The weight for the ORT loss, the same as SAITS. @@ -71,6 +56,14 @@ class FITS(BaseNNImputer): stopped when the model does not perform better after that number of epochs. Leaving it default as None will disable the early-stopping. + train_loss_func: + The customized loss function designed by users for training the model. + If not given, will use the default loss as claimed in the original paper. + + val_metric_func: + The customized metric function designed by users for validating the model. + If not given, will use the default MSE metric. + optimizer : The optimizer for model training. If not given, will use a default Adam optimizer. @@ -115,6 +108,8 @@ def __init__( batch_size: int = 32, epochs: int = 100, patience: int = None, + train_loss_func: Optional[dict] = None, + val_metric_func: Optional[dict] = None, optimizer: Optional[Optimizer] = Adam(), num_workers: int = 0, device: Optional[Union[str, torch.device, list]] = None, @@ -123,14 +118,16 @@ def __init__( verbose: bool = True, ): super().__init__( - batch_size, - epochs, - patience, - num_workers, - device, - saving_path, - model_saving_strategy, - verbose, + batch_size=batch_size, + epochs=epochs, + patience=patience, + train_loss_func=train_loss_func, + val_metric_func=val_metric_func, + num_workers=num_workers, + device=device, + saving_path=saving_path, + model_saving_strategy=model_saving_strategy, + verbose=verbose, ) self.n_steps = n_steps @@ -272,7 +269,7 @@ def predict( with torch.no_grad(): for idx, data in enumerate(test_loader): inputs = self._assemble_input_for_testing(data) - results = self.model.forward(inputs, training=False) + results = self.model.forward(inputs) imputation_collector.append(results["imputed_data"]) # Step 3: output collection and return