diff --git a/fuse/dl/lightning/pl_module.py b/fuse/dl/lightning/pl_module.py
index ee44e8aa..f5b438a2 100644
--- a/fuse/dl/lightning/pl_module.py
+++ b/fuse/dl/lightning/pl_module.py
@@ -44,7 +44,12 @@ def __init__(
                 List[Tuple[str, OrderedDict[str, MetricBase]]],
             ]
         ] = None,
-        test_metrics: Optional[OrderedDict[str, MetricBase]] = None,
+        test_metrics: Optional[
+            Union[
+                OrderedDict[str, MetricBase],
+                List[Tuple[str, OrderedDict[str, MetricBase]]],
+            ]
+        ] = None,
         optimizers_and_lr_schs: Any = None,
         callbacks: Optional[Sequence[pl.Callback]] = None,
         best_epoch_source: Optional[Union[Dict, List[Dict]]] = None,
@@ -67,7 +72,9 @@ def __init__(
         :param validation_metrics: ordereddict of FuseMedML style metrics - used for validation set (must be different instances of metrics (from train_metrics!)
                                    In case of multiple validation dataloaders, validation_metrics should be list of tuples (that keeps the same dataloaders list order),
                                    Each tuple built from validation dataloader name and corresponding metrics dict.
-        :param test_metrics: dict of FuseMedML style metrics - used for test set (must be different instances of metrics (from train_metrics and validation_metrics!)
+        :param test_metrics: ordereddict of FuseMedML style metrics - used for test set (must be different instances of the metrics used in train_metrics!). If None, validation_metrics will also be used for the test set.
+                             In case of multiple test dataloaders, test_metrics should be a list of tuples (kept in the same order as the dataloaders list),
+                             each tuple built from the test dataloader name and its corresponding metrics dict.
         :param optimizers_and_lr_schs: see pl.LightningModule.configure_optimizers return value for all options
         :param callbacks: see pl.LightningModule.configure_callbacks return value for details
         :param best_epoch_source: Create list of pl.callbacks that saves checkpoints using (pl.callbacks.ModelCheckpoint) and print per epoch summary (fuse.dl.lightning.pl_epoch_summary.ModelEpochSummary).
@@ -138,12 +145,16 @@ def __init__(
         self._validation_metrics = (
             validation_metrics if validation_metrics is not None else {}
         )
-
+        if test_metrics is None:
+            self._test_metrics = self._validation_metrics
+        else:
+            self._test_metrics = test_metrics
         # convert all use-cases to the same format that supports multiple val dataloaders: List[Tuple[str, OrderedDict[str, MetricBase]]]
         if isinstance(self._validation_metrics, dict):
             self._validation_metrics = [(None, self._validation_metrics)]
-        self._test_metrics = test_metrics if test_metrics is not None else {}
+        if isinstance(self._test_metrics, dict):
+            self._test_metrics = [(None, self._test_metrics)]
 
         if log_unit not in [None, "optimizer_step", "epoch"]:
             raise Exception(f"Error: unexpected log_unit {log_unit}")
@@ -158,11 +169,12 @@ def __init__(
         self._prediction_keys = None
         self._sep = tensorboard_sep
 
+        self._training_step_outputs = []
+
         self._validation_step_outputs = {
             i: [] for i, _ in enumerate(self._validation_metrics)
         }
-        self._training_step_outputs = []
-        self._test_step_outputs = []
+        self._test_step_outputs = {i: [] for i, _ in enumerate(self._test_metrics)}
 
     ## forward
     def forward(self, batch_dict: NDict) -> NDict:
@@ -204,17 +216,27 @@ def validation_step(
             {"losses": batch_dict["losses"]}
         )
 
-    def test_step(self, batch_dict: NDict, batch_idx: int) -> None:
+    def test_step(
+        self, batch_dict: NDict, batch_idx: int, dataloader_idx: int = 0
+    ) -> None:
         # add step number to batch_dict
         batch_dict["global_step"] = self.global_step
         # run forward function and store the outputs in batch_dict["model"]
         batch_dict = self.forward(batch_dict)
 
         # given the batch_dict and FuseMedML style losses - compute the losses, return the total loss (ignored) and save losses values in batch_dict["losses"]
-        _ = step_losses(self._losses, batch_dict)
+        if self._validation_losses is not None:
+            losses = self._validation_losses[dataloader_idx][1]
+        else:
+            losses = self._losses
+
+        _ = step_losses(losses, batch_dict)
 
         # given the batch_dict and FuseMedML style metrics - collect the required values to compute the metrics on epoch_end
-        step_metrics(self._test_metrics, batch_dict)
+        step_metrics(self._test_metrics[dataloader_idx][1], batch_dict)
 
         # aggregate losses
-        self._test_step_outputs.append({"losses": batch_dict["losses"]})
+        if losses:  # if there are losses, collect the results
+            self._test_step_outputs[dataloader_idx].append(
+                {"losses": batch_dict["losses"]}
+            )
 
     def predict_step(self, batch_dict: NDict, batch_idx: int) -> dict:
         if self._prediction_keys is None:
@@ -267,20 +289,25 @@ def on_validation_epoch_end(self) -> None:
         }
 
     def on_test_epoch_end(self) -> None:
-        step_outputs = self._test_step_outputs
+        step_outputs_lst = self._test_step_outputs
         # for the logs to be at each epoch, not each step
         if self._log_unit == "epoch":
-            self.log("step", self.current_epoch, on_epoch=True, sync_dist=True)
-        # calc average epoch loss and log it
-        epoch_end_compute_and_log_losses(
-            self, "test", [e["losses"] for e in step_outputs], sep=self._sep
-        )
-        # evaluate and log it
-        epoch_end_compute_and_log_metrics(
-            self, "test", self._test_metrics, sep=self._sep
-        )
+            self.log("step", float(self.current_epoch), on_epoch=True, sync_dist=True)
+        for dataloader_idx, step_outputs in step_outputs_lst.items():
+            if len(self._test_metrics) == 1:
+                prefix = "test"
+            else:
+                prefix = f"test.{self._test_metrics[dataloader_idx][0]}"
+            # calc average epoch loss and log it
+            epoch_end_compute_and_log_losses(
+                self, prefix, [e["losses"] for e in step_outputs], sep=self._sep
+            )
+            # evaluate and log it
+            epoch_end_compute_and_log_metrics(
+                self, prefix, self._test_metrics[dataloader_idx][1], sep=self._sep
+            )
         # reset state
-        self._test_step_outputs.clear()
+        self._test_step_outputs = {i: [] for i, _ in enumerate(self._test_metrics)}
 
     # configuration
     def configure_callbacks(self) -> Sequence[pl.Callback]:
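
For reference, the constructor now funnels both accepted shapes of `test_metrics` into the same `List[Tuple[str, OrderedDict[str, MetricBase]]]` form already used for `validation_metrics`, so `test_step` and `on_test_epoch_end` can index `self._test_metrics[dataloader_idx][1]` uniformly. Below is a minimal runnable sketch of that normalization; `normalize_metrics` is a hypothetical helper written only for illustration (the patch does this inline in `__init__`), and a plain string stands in for a real `MetricBase` instance:

```python
from collections import OrderedDict
from typing import Any, List, Optional, Tuple


def normalize_metrics(metrics: Any) -> List[Tuple[Optional[str], Any]]:
    # Mirrors the __init__ logic: a plain (Ordered)Dict becomes a single
    # unnamed entry at index 0; a list of (name, dict) tuples passes through.
    if isinstance(metrics, dict):
        return [(None, metrics)]
    return metrics


single = normalize_metrics(OrderedDict(auc="<MetricBase instance>"))
multi = normalize_metrics(
    [
        ("internal", OrderedDict(auc="<MetricBase instance>")),
        ("external", OrderedDict(auc="<MetricBase instance>")),
    ]
)
assert single[0][0] is None  # one implicit test dataloader at index 0
assert multi[1][0] == "external"  # per-dataloader entries keep their names
```

Without the `isinstance(self._test_metrics, dict)` branch added above, a plain `OrderedDict` of test metrics would reach `test_step` unconverted and `self._test_metrics[dataloader_idx][1]` would raise a `KeyError`, which is why the conversion mirrors the `validation_metrics` handling.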
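The logging prefix rule in `on_test_epoch_end` is easy to verify in isolation: with a single (possibly unnamed) test dataloader everything is logged under plain `test`, while with several dataloaders each one is logged under `test.<dataloader name>`. A standalone sketch of that rule, copied from the branch in the patch:

```python
from typing import Any, List, Optional, Tuple


def log_prefix(test_metrics: List[Tuple[Optional[str], Any]], dataloader_idx: int) -> str:
    # Same branch as on_test_epoch_end: one entry -> "test",
    # multiple entries -> "test.<name of the dataloader at this index>".
    if len(test_metrics) == 1:
        return "test"
    return f"test.{test_metrics[dataloader_idx][0]}"


assert log_prefix([(None, {})], 0) == "test"
assert log_prefix([("internal", {}), ("external", {})], 0) == "test.internal"
assert log_prefix([("internal", {}), ("external", {})], 1) == "test.external"
```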
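Finally, `_test_step_outputs` is now a dict keyed by `dataloader_idx` (PyTorch Lightning passes that index to `test_step` when `trainer.test` receives more than one dataloader), and at epoch end it is rebuilt rather than cleared in place so a subsequent test run starts with empty buckets. A toy simulation of that bookkeeping with made-up loss values:

```python
test_metrics = [("internal", {}), ("external", {})]
test_step_outputs = {i: [] for i, _ in enumerate(test_metrics)}

# test_step: each batch's losses land in the bucket of its dataloader
for dataloader_idx, losses in [(0, {"total": 0.7}), (0, {"total": 0.5}), (1, {"total": 0.4})]:
    test_step_outputs[dataloader_idx].append({"losses": losses})

# on_test_epoch_end: one aggregation pass per dataloader, each under its own prefix
for dataloader_idx, step_outputs in test_step_outputs.items():
    avg = sum(e["losses"]["total"] for e in step_outputs) / len(step_outputs)
    print(f"test.{test_metrics[dataloader_idx][0]}.losses.total = {avg:.2f}")

# reset: rebuild the dict so the next test run starts fresh
test_step_outputs = {i: [] for i, _ in enumerate(test_metrics)}
```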