diff --git a/pypots/clustering/crli/model.py b/pypots/clustering/crli/model.py
index ab13ce7a..c8f99455 100644
--- a/pypots/clustering/crli/model.py
+++ b/pypots/clustering/crli/model.py
@@ -254,14 +254,42 @@ def _train_model(
                         self._save_log_into_tb_file(
                             training_step, "training", loss_results
                         )
+
                 mean_epoch_train_D_loss = np.mean(epoch_train_loss_D_collector)
                 mean_epoch_train_G_loss = np.mean(epoch_train_loss_G_collector)
-                logger.info(
-                    f"epoch {epoch}: "
-                    f"training loss_generator {mean_epoch_train_G_loss:.4f}, "
-                    f"train loss_discriminator {mean_epoch_train_D_loss:.4f}"
-                )
-                mean_loss = mean_epoch_train_G_loss
+
+                if val_loader is not None:
+                    self.model.eval()
+                    epoch_val_loss_G_collector = []
+                    with torch.no_grad():
+                        for idx, data in enumerate(val_loader):
+                            inputs = self._assemble_input_for_validating(data)
+                            results = self.model.forward(inputs, return_loss=True)
+                            epoch_val_loss_G_collector.append(
+                                results["generation_loss"].sum().item()
+                            )
+                    mean_val_G_loss = np.mean(epoch_val_loss_G_collector)
+                    # save validating loss logs into the tensorboard file for every epoch if in need
+                    if self.summary_writer is not None:
+                        val_loss_dict = {
+                            "generation_loss": mean_val_G_loss,
+                        }
+                        self._save_log_into_tb_file(epoch, "validating", val_loss_dict)
+                    logger.info(
+                        f"epoch {epoch}: "
+                        f"training loss_generator {mean_epoch_train_G_loss:.4f}, "
+                        f"training loss_discriminator {mean_epoch_train_D_loss:.4f}, "
+                        f"validating loss_generator {mean_val_G_loss:.4f}"
+                    )
+                    mean_loss = mean_val_G_loss
+                else:
+
+                    logger.info(
+                        f"epoch {epoch}: "
+                        f"training loss_generator {mean_epoch_train_G_loss:.4f}, "
+                        f"training loss_discriminator {mean_epoch_train_D_loss:.4f}"
+                    )
+                    mean_loss = mean_epoch_train_G_loss
 
                 if mean_loss < self.best_loss:
                     self.best_loss = mean_loss
@@ -314,8 +342,18 @@ def fit(
             num_workers=self.num_workers,
         )
 
+        val_loader = None
+        if val_set is not None:
+            val_set = DatasetForCRLI(val_set, return_labels=False, file_type=file_type)
+            val_loader = DataLoader(
+                val_set,
+                batch_size=self.batch_size,
+                shuffle=False,
+                num_workers=self.num_workers,
+            )
+
         # Step 2: train the model and freeze it
-        self._train_model(training_loader)
+        self._train_model(training_loader, val_loader)
         self.model.load_state_dict(self.best_model_dict)
         self.model.eval()  # set the model as eval status to freeze it.
@@ -342,9 +380,9 @@ def predict(
         with torch.no_grad():
             for idx, data in enumerate(test_loader):
                 inputs = self._assemble_input_for_testing(data)
-                inputs = self.model.forward(inputs, training=False)
+                inputs = self.model.forward(inputs, return_loss=False)
                 clustering_latent_collector.append(inputs["fcn_latent"])
-                imputation_collector.append(inputs["imputation"])
+                imputation_collector.append(inputs["imputation_latent"])
 
         imputation = torch.cat(imputation_collector).cpu().detach().numpy()
         clustering_latent = (
@@ -353,7 +391,7 @@
         clustering = self.model.kmeans.fit_predict(clustering_latent)
         latent_collector = {
             "clustering_latent": clustering_latent,
-            "imputation": imputation,
+            "imputation_latent": imputation,
         }
 
         result_dict = {
diff --git a/pypots/clustering/crli/modules/core.py b/pypots/clustering/crli/modules/core.py
index da653cde..cbca6356 100644
--- a/pypots/clustering/crli/modules/core.py
+++ b/pypots/clustering/crli/modules/core.py
@@ -46,46 +46,36 @@ def __init__(
             n_clusters=n_clusters,
             n_init=10,  # FutureWarning: The default value of `n_init` will change from 10 to 'auto' in 1.4. Set the
             # value of `n_init` explicitly to suppress the warning.
-        )  # TODO: implement KMean with torch for gpu acceleration
+        )
         self.n_clusters = n_clusters
         self.lambda_kmeans = lambda_kmeans
         self.device = device
 
-    def cluster(self, inputs: dict, training_object: str = "generator") -> dict:
-        # concat final states from generator and input it as the initial state of decoder
-        imputation, imputed_X, generator_fb_hidden_states = self.generator(inputs)
-        inputs["imputation"] = imputation
-        inputs["imputed_X"] = imputed_X
-        inputs["generator_fb_hidden_states"] = generator_fb_hidden_states
-        if training_object == "discriminator":
-            discrimination = self.discriminator(inputs)
-            inputs["discrimination"] = discrimination
-            return inputs  # if only train discriminator, then no need to run decoder
-
-        reconstruction, fcn_latent = self.decoder(inputs)
-        inputs["reconstruction"] = reconstruction
-        inputs["fcn_latent"] = fcn_latent
-        return inputs
-
     def forward(
         self,
         inputs: dict,
         training_object: str = "generator",
-        training: bool = True,
+        return_loss: bool = True,
    ) -> dict:
-        assert training_object in [
-            "generator",
-            "discriminator",
-        ], 'training_object should be "generator" or "discriminator"'
-
         X = inputs["X"]
         missing_mask = inputs["missing_mask"]
         batch_size, n_steps, n_features = X.shape
         losses = {}
-        inputs = self.cluster(inputs, training_object)
-        if not training:
-            # if only run clustering, then no need to calculate loss
+
+        # concat final states from generator and input it as the initial state of decoder
+        imputation_latent, generator_fb_hidden_states = self.generator(inputs)
+        inputs["imputation_latent"] = imputation_latent
+        inputs["generator_fb_hidden_states"] = generator_fb_hidden_states
+        discrimination = self.discriminator(inputs)
+        inputs["discrimination"] = discrimination
+
+        reconstruction, fcn_latent = self.decoder(inputs)
+        inputs["reconstruction"] = reconstruction
+        inputs["fcn_latent"] = fcn_latent
+
+        # return results directly, skip loss calculation to reduce inference time
+        if not return_loss:
             return inputs
 
         if training_object == "discriminator":
@@ -98,7 +88,7 @@ def forward(
             l_G = F.binary_cross_entropy_with_logits(
                 inputs["discrimination"], 1 - missing_mask, weight=1 - missing_mask
             )
-            l_pre = cal_mse(inputs["imputation"], X, missing_mask)
+            l_pre = cal_mse(inputs["imputation_latent"], X, missing_mask)
             l_rec = cal_mse(inputs["reconstruction"], X, missing_mask)
             HTH = torch.matmul(inputs["fcn_latent"], inputs["fcn_latent"].permute(1, 0))
             term_F = torch.nn.init.orthogonal_(
diff --git a/pypots/clustering/crli/modules/submodules.py b/pypots/clustering/crli/modules/submodules.py
index f6837647..59b155b9 100644
--- a/pypots/clustering/crli/modules/submodules.py
+++ b/pypots/clustering/crli/modules/submodules.py
@@ -124,18 +124,15 @@ def __init__(
         self.f_rnn = MultiRNNCell(cell_type, n_layers, n_features, d_hidden, device)
         self.b_rnn = MultiRNNCell(cell_type, n_layers, n_features, d_hidden, device)
 
-    def forward(self, inputs: dict) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+    def forward(self, inputs: dict) -> Tuple[torch.Tensor, torch.Tensor]:
         f_outputs, f_final_hidden_state = self.f_rnn(inputs)
         b_outputs, b_final_hidden_state = self.b_rnn(inputs)
         b_outputs = reverse_tensor(b_outputs)  # reverse the output of the backward rnn
-        imputation = (f_outputs + b_outputs) / 2
-        imputed_X = inputs["X"] * inputs["missing_mask"] + imputation * (
-            1 - inputs["missing_mask"]
-        )
+        imputation_latent = (f_outputs + b_outputs) / 2
         fb_final_hidden_states = torch.concat(
             [f_final_hidden_state, b_final_hidden_state], dim=-1
         )
-        return imputation, imputed_X, fb_final_hidden_states
+        return imputation_latent, fb_final_hidden_states
 
 
 class Discriminator(nn.Module):
@@ -161,7 +158,10 @@ def __init__(
         self.output_layer = nn.Linear(32, d_input)
 
     def forward(self, inputs: dict) -> torch.Tensor:
-        imputed_X = inputs["imputed_X"]
+        imputed_X = (inputs["X"] * inputs["missing_mask"]) + (
+            inputs["imputation_latent"] * (1 - inputs["missing_mask"])
+        )
+
         bz, n_steps, _ = imputed_X.shape
         hidden_states = [
             torch.zeros((bz, 32), device=self.device),
diff --git a/pypots/clustering/vader/model.py b/pypots/clustering/vader/model.py
index 4e31b412..fff13643 100644
--- a/pypots/clustering/vader/model.py
+++ b/pypots/clustering/vader/model.py
@@ -340,6 +340,7 @@ def _train_model(
     def fit(
         self,
         train_set: Union[dict, str],
+        val_set: Optional[Union[dict, str]] = None,
         file_type: str = "h5py",
     ) -> None:
         # Step 1: wrap the input data with classes Dataset and DataLoader
@@ -353,8 +354,18 @@ def fit(
             num_workers=self.num_workers,
         )
 
+        val_loader = None
+        if val_set is not None:
+            val_set = DatasetForVaDER(val_set, return_labels=False, file_type=file_type)
+            val_loader = DataLoader(
+                val_set,
+                batch_size=self.batch_size,
+                shuffle=False,
+                num_workers=self.num_workers,
+            )
+
         # Step 2: train the model and freeze it
-        self._train_model(training_loader)
+        self._train_model(training_loader, val_loader)
         self.model.load_state_dict(self.best_model_dict)
         self.model.eval()  # set the model as eval status to freeze it.
diff --git a/pypots/clustering/vader/modules/core.py b/pypots/clustering/vader/modules/core.py
index 1e33ba54..9f461d80 100644
--- a/pypots/clustering/vader/modules/core.py
+++ b/pypots/clustering/vader/modules/core.py
@@ -172,7 +172,6 @@ def forward(
             mu_tilde,
             stddev_tilde,
         ) = self.get_results(X, missing_mask)
-        imputed_X = X_reconstructed * (1 - missing_mask) + X * missing_mask
 
         if not training and not pretrain:
             results = {
@@ -182,7 +181,7 @@ def forward(
                 "var": var_c,
                 "phi": phi_c,
                 "z": z,
-                "imputed_X": imputed_X,
+                "imputation_latent": X_reconstructed,
             }
             # if only run clustering, then no need to calculate loss
             return results
diff --git a/tests/clustering/crli.py b/tests/clustering/crli.py
index 2385d1e5..99524753 100644
--- a/tests/clustering/crli.py
+++ b/tests/clustering/crli.py
@@ -21,6 +21,7 @@
 from tests.clustering.config import (
     EPOCHS,
     TRAIN_SET,
+    VAL_SET,
     TEST_SET,
     RESULT_SAVING_DIR_FOR_CLUSTERING,
 )
@@ -74,9 +75,9 @@ class TestCRLI(unittest.TestCase):
     @pytest.mark.xdist_group(name="clustering-crli")
     def test_0_fit(self):
         logger.info("Training CRLI-GRU...")
-        self.crli_gru.fit(TRAIN_SET)
+        self.crli_gru.fit(TRAIN_SET, VAL_SET)
         logger.info("Training CRLI-LSTM...")
-        self.crli_lstm.fit(TRAIN_SET)
+        self.crli_lstm.fit(TRAIN_SET, VAL_SET)
 
     @pytest.mark.xdist_group(name="clustering-crli")
     def test_1_parameters(self):
diff --git a/tests/clustering/vader.py b/tests/clustering/vader.py
index a76b61e8..42bcda00 100644
--- a/tests/clustering/vader.py
+++ b/tests/clustering/vader.py
@@ -22,6 +22,7 @@
 from tests.clustering.config import (
     EPOCHS,
     TRAIN_SET,
+    VAL_SET,
     TEST_SET,
     RESULT_SAVING_DIR_FOR_CLUSTERING,
 )
@@ -58,7 +59,7 @@ class TestVaDER(unittest.TestCase):
 
     @pytest.mark.xdist_group(name="clustering-vader")
     def test_0_fit(self):
-        self.vader.fit(TRAIN_SET)
+        self.vader.fit(TRAIN_SET, VAL_SET)
 
     @pytest.mark.xdist_group(name="clustering-vader")
     def test_1_cluster(self):
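Usage sketch (not part of the patch): after this change, `fit()` on the clustering models accepts an optional validation set, and for CRLI the mean validation `generation_loss` replaces the training generator loss as the model-selection criterion. The snippet below is illustrative only: the constructor arguments are assumed from how the test suite builds CRLI, the hyperparameter values are placeholders, and `train_set`/`val_set`/`test_set` stand in for `{"X": ...}` dicts or h5py file paths (per the `file_type` parameter seen in the diff).

```python
from pypots.clustering import CRLI

# Placeholder hyperparameters; n_steps/n_features/n_clusters must match your data.
model = CRLI(
    n_steps=48,
    n_features=37,
    n_clusters=2,
    n_generator_layers=2,
    rnn_hidden_size=128,
    epochs=10,
)

# New in this patch: an optional second argument to fit(). When given, the
# epoch with the lowest mean validation generation_loss supplies
# best_model_dict; without it, the training generator loss is used as before.
model.fit(train_set, val_set)

# predict() now keys the raw generator output as "imputation_latent"
# (previously "imputation") in the returned latent collector.
results = model.predict(test_set)
```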